model: original model that is used for quantization or distillation, this is the student model w.r.t. knowledge distillation.
optimizer (for model): optimizer for optimizating the model.
teacher model: the teacher model. In pure_distillation, this is the larger model. In model quantization, this is the original uncompressed model or a larger model.
optimizer_param_group (the parameter_group that is used for initialize the parameter group)。
SNIP
1.项目介绍
SNIP 为一款集成了“精调+蒸馏+模型加速”一体的模型上线助力工具。用户只需要提供经过 post-train 以后的预训练模型和业务精调数据,并且修改 3~8 行训练代码, SNIP就会自动输出加速后的模型。
通过对多种量化与蒸馏算法的深度优化,我们相比实现了更高的加速效果,同时在多业务上都成功适配了行之有效的无损加速手段。
当前初步在AI、视频、文章相关场景取得了不错的效果。
工具包基于python3.8开发,支持作为一个python包进行使用,同时本身作为一个工具包,对环境的依赖程度低。
注意:目前Auto48X只支持调用自己的modeling文件(目前只支持BertForSequenceClassification,如果需要扩展支持请联系负责人)。用户需要自己intialize:
model: original model that is used for quantization or distillation, this is the student model w.r.t. knowledge distillation.
optimizer (for model): optimizer for optimizating the model.
teacher model: the teacher model. In pure_distillation, this is the larger model. In model quantization, this is the original uncompressed model or a larger model.
optimizer_param_group (the parameter_group that is used for initialize the parameter group)。
2.功能特性
轻量级
工具包作为Python 包,执行简单安装后,通过简单导入即可使用。
易使用
用户只需要修改七行用户脚本代码即可训练量化后的模型
适用于神经网络量化与蒸馏训练场景
3.快速上手
Build SNIP:
目前Auto48X支持pip instal 使用(setuptools==50.3.2)以及直接import项目文件夹使用两种方式。
requires:
python >= 3.8
transformers==4.19.0
torch better be 1.11.0
-——————————————————————————————————————
EXAMPLE
见 Auto48/example/ 路径下脚本。
BUILD STEP
1)配置Auto48X所需参数。
Auto48X默认参数配置使用 Auto48X/config/auto48_default.json 文件配置。用户如需自定义Auto48超参,可通过args和json文件配置参数,args优先级高于json文件。
修改parser用来接收Auto48X所需参数(分别放在解析parser之前和之后):
2)初始化Teacher model:
3)初始化Auto48引擎:
4)运行forward:
注意这里forward不能直接model(batch)运行,需要用engine.engine_forward(batch)来运行。返回的第一个值是正常的BERT的返回类。(这个只需要在training的时候调用,evaluate的时候正常model(batch)就好了)
5)添加蒸馏:
注意这里num_input是每个sample的token数,可以直接用
6)进行backward并更新模型:
-——————————————————————————————————————
*如果是从老版本Auto48X迁移过来的用户,一共有三处需要修改的地方:
1)outputs, _, _ = engine.engine_forward(batch) 修改成 outputs = engine.engine_forward(batch)。
2)model.backward(loss)修改成engine.backward(loss)。
3)model.step()修改成engine.step()。
4)若直接使用model,model的输出是一个dict数据
-——————————————————————————————————————
SNIP启动教程:
Deepspeed :
Python3 :
使用DDP :
不使用DDP:
-——————————————————————————————————————
4.常见问题
Auto48X参数指引(具体定义可以参照auto48_utils.py):
注意事项:
(1)–qat、–pure_distillation、–pure_qat_eval、–pure_qat_eval 至少需要指定一项。
(2)命令台输入的args优先级高于指定的json文件。
(3)–pure_qat_eval模式需要模型先过Auto48Init后再load_from_ckpt。
(4)使用fp16训练需要安装apex包。
参数列表:
Mode
All hyper-parameter
–pure_qat_eval模式
1)配置Auto48所需参数
2)生成模型
蒸馏细节:
对于不同的用户脚本,支持数据集无label蒸馏和有label蒸馏。
Auto48X 中 engine.add_knowledge_distillation(loss, num_input) 内置了对student和teacher的output进行loss计算。 可以通过Input中loss设置为None时代表用户的loss无法从数据集label获得。