TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包, 融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架, 用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。
可以通过ACL Anthology或arXiv pre-print查看我们的论文。
Mar 1, 2021
BERT-EMD示例与自定义distiller
MNLI示例更新
Nov 11, 2020
版本更新至0.2.1:
灵活性提升:支持为教师模型和学生模型输入各自独立的batch,不再要求教师模型和学生模型的输入相同。可用于词表不同的模型之间(例如从RoBERTa到BERT)的蒸馏。
蒸馏加速:支持用户自定义传入教师模型的输出缓存,避免教师模型的重复前向计算,加速蒸馏过程。
以上特性的详细说明可参见 Feed Different batches to Student and Teacher, Feed Cached Values
增加了MultiTaskDistiller
对中间层匹配损失的支持。
Tensorboard中记录更详细的损失函数(KD loss, hard label loss, matching losses...)。
更新细节参见 releases。
Aug 27, 2020
哈工大讯飞联合实验室在通用自然语言理解评测GLUE中荣登榜首,查看GLUE榜单,新闻。
Aug 24, 2020
MultiTaskDistiller
以及训练循环中的若干bug。Jul 29, 2020
TrainingConfig
中传入相应的local_rank
以启用。详细设置参见TraningConfig
的说明。Jul 14, 2020
TrainingConfig
中设置fp16=True
启用。详细设置参见TraningConfig
的说明。TrainingConfig
中增加了data_parallel
选项,使得数据并行与混合精度训练可同时启用。Apr 26, 2020
Apr 22, 2020
Mar 17, 2020
Mar 11, 2020
TrainingConfig
和distiller的train
方法),细节参见 releases。Mar 2, 2020
章节 | 内容 |
---|---|
简介 | TextBrewer简介 |
安装 | 安装方法介绍 |
工作流程 | TextBrewer整体工作流程 |
快速开始 | 举例展示TextBrewer用法:BERT-base蒸馏至3层BERT |
蒸馏效果 | 中文、英文典型数据集上的蒸馏效果展示 |
核心概念 | TextBrewer中的核心概念介绍 |
FAQ | 常见问题解答 |
引用 | TextBrewer参考引用 |
已知问题 | 尚未解决的问题 |
关注我们 | - |
TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。
主要特点:
TextBrewer目前支持的知识蒸馏技术有:
TextBrewer的主要功能与模块分为3块:
用户需要准备:
在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。
详细的API可参见 完整文档。
pip install textbrewer
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer
Stage 1 : 蒸馏之前的准备工作:
Stage 2 : 使用TextBrewer蒸馏:
TrainingConfig
)和蒸馏配置(DistillationConfig
),初始化distiller
以蒸馏BERT-base到3层BERT为例展示TextBrewer用法。
在开始蒸馏之前准备:
teacher_model
(BERT-base),待训练学生模型student_model
(3-layer BERT)dataloader
,优化器optimizer
,学习率调节器类或者构造函数scheduler_class
和构造用的参数字典 scheduler_args
使用TextBrewer蒸馏:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
# 展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)
print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)
# 定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):
# model输出的第二、三个元素分别是logits和hidden states
return {'logits': model_outputs[1], 'hidden': model_outputs[2]}
# 蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(
intermediate_matches=[
{'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()
#初始化distiller
distiller = GeneralDistiller(
train_config=train_config, distill_config = distill_config,
model_T = teacher_model, model_S = student_model,
adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)
# 开始蒸馏
with distiller:
distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)
更多的示例可参见examples
文件夹:
我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。
我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。
Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
---|---|---|---|---|---|
BERT-base-cased (教师) | 12 | 768 | 3072 | 108M | 100% |
T6 (学生) | 6 | 768 | 3072 | 65M | 60% |
T3 (学生) | 3 | 768 | 3072 | 44M | 41% |
T3-small (学生) | 3 | 384 | 1536 | 17M | 16% |
T4-Tiny (学生) | 4 | 312 | 1200 | 14M | 13% |
T12-nano (学生) | 12 | 256 | 1024 | 17M | 16% |
BiGRU (学生) | - | 768 | - | 31M | 29% |
Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
---|---|---|---|---|---|
RoBERTa-wwm-ext (教师) | 12 | 768 | 3072 | 102M | 100% |
Electra-base (教师) | 12 | 768 | 3072 | 102M | 100% |
T3 (学生) | 3 | 768 | 3072 | 38M | 37% |
T3-small (学生) | 3 | 384 | 1536 | 14M | 14% |
T4-Tiny (学生) | 4 | 312 | 1200 | 11M | 11% |
Electra-small (学生) | 12 | 256 | 1024 | 12M | 12% |
distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
# 其他参数为默认值
不同的模型用的matches
我们采用了以下配置:
Model | matches |
---|---|
BiGRU | None |
T6 | L6_hidden_mse + L6_hidden_smmd |
T3 | L3_hidden_mse + L3_hidden_smmd |
T3-small | L3n_hidden_mse + L3_hidden_smmd |
T4-Tiny | L4t_hidden_mse + L4_hidden_smmd |
T12-nano | small_hidden_mse + small_hidden_smmd |
Electra-small | small_hidden_mse + small_hidden_smmd |
各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。
蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。
在英文实验中,我们使用了如下三个典型数据集。
Dataset | Task type | Metrics | #Train | #Dev | Note |
---|---|---|---|---|---|
MNLI | 文本分类 | m/mm Acc | 393K | 20K | 句对三分类任务 |
SQuAD 1.1 | 阅读理解 | EM/F1 | 88K | 11K | 篇章片段抽取型阅读理解 |
CoNLL-2003 | 序列标注 | F1 | 23K | 6K | 命名实体识别任务 |
我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。
Public results:
Model (public) | MNLI | SQuAD | CoNLL-2003 |
---|---|---|---|
DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
BERT6-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
BERT-of-Theseus (T6) | 82.4/ 82.1 | - | - |
BERT3-PKD (T3) | 76.7 / 76.3 | - | - |
TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |
Our results:
Model (ours) | MNLI | SQuAD | CoNLL-2003 |
---|---|---|---|
BERT-base-cased (教师) | 83.7 / 84.0 | 81.5 / 88.6 | 91.1 |
BiGRU | - | - | 85.3 |
T6 | 83.5 / 84.0 | 80.8 / 88.1 | 90.7 |
T3 | 81.8 / 82.7 | 76.4 / 84.9 | 87.5 |
T3-small | 81.3 / 81.7 | 72.3 / 81.4 | 78.6 |
T4-tiny | 82.0 / 82.6 | 75.2 / 84.0 | 89.1 |
T12-nano | 83.2 / 83.9 | 79.0 / 86.6 | 89.6 |
说明:
在中文实验中,我们使用了如下典型数据集。
Dataset | Task type | Metrics | #Train | #Dev | Note |
---|---|---|---|---|---|
XNLI | 文本分类 | Acc | 393K | 2.5K | MNLI的中文翻译版本,3分类任务 |
LCQMC | 文本分类 | Acc | 239K | 8.8K | 句对二分类任务,判断两个句子的语义是否相同 |
CMRC 2018 | 阅读理解 | EM/F1 | 10K | 3.4K | 篇章片段抽取型阅读理解 |
DRCD | 阅读理解 | EM/F1 | 27K | 3.5K | 繁体中文篇章片段抽取型阅读理解 |
MSRA NER | 序列标注 | F1 | 45K | 3.4K (测试集) | 中文命名实体识别 |
实验结果如下表所示。
Model | XNLI | LCQMC | CMRC 2018 | DRCD |
---|---|---|---|---|
RoBERTa-wwm-ext (教师) | 79.9 | 89.4 | 68.8 / 86.4 | 86.5 / 92.5 |
T3 | 78.4 | 89.0 | 66.4 / 84.2 | 78.2 / 86.4 |
T3-small | 76.0 | 88.1 | 58.0 / 79.3 | 75.8 / 84.8 |
T4-tiny | 76.2 | 88.4 | 61.8 / 81.8 | 77.3 / 86.1 |
Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
---|---|---|---|---|---|
Electra-base (教师) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
Electra-small | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |
说明:
TrainingConfig
和 DistillationConfig
:训练和蒸馏相关的配置。Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:
BasicDistiller
: 提供单模型单任务蒸馏方式。可用作测试或简单实验。GeneralDistiller
(常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用。MultiTeacherDistiller
: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配。MultiTaskDistiller
:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。BasicTrainer
:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型。蒸馏实验中,有两个组件需要由用户提供,分别是callback 和 adaptor :
回调函数。在每个checkpoint,保存模型后会被distiller
调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。
将模型的输入和输出转换为指定的格式,向distiller
解释模型的输入和输出,以便distiller
根据不同的策略进行不同的计算。在每个训练步,batch
和模型的输出model_outputs
会作为参数传递给adaptor
,adaptor
负责重新组织这些数据,返回一个字典。
更多细节可参见完整文档中的说明。
Q: 学生模型该如何初始化?
A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。
Q: 如何设置蒸馏的训练参数以达到一个较好的效果?
A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考。
Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?
A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?
A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
如果TextBrewer工具包对你的研究工作有所帮助,请在文献中引用我们的论文:
@InProceedings{textbrewer-acl2020-demo,
title = "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
author = "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
year = "2020",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.acl-demos.2",
pages = "9--16",
}
欢迎关注哈工大讯飞联合实验室官方微信公众号,了解最新的技术动态。
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
1. 开源生态
2. 协作、人、软件
3. 评估模型