本文介绍基于Bert预训练模型Fine-tuning实现文本分类任务
1.项目简介
在小米AI实验室NLP平台组内容过滤的工作,记录于此,已删去敏感信息,本文介绍基于Bert预训练模型Fine-tuning实现文本分类任务。
2.运行环境
1 | pip install -r requirements.txt |
Tensorflow>=1.11.0
Python2和Python3均可运行
3.准备数据
- train.tsv:训练集,对应do_train参数
- dev.tsv:验证集,对应do_eval参数
- test.tsv:测试集,对应do_predict参数
格式均为Label+Tab+Text,如下所示1
2
3
42 宝妈********
2 装修污染惹不注意男性**出问题
1 曹操家选媳妇********
0 一对夫妻打架 妻子被丈夫打得身上都是灰
5.模型训练及模型服务导出
1 | python run_classifier.py \ |
- 参数说明
参数 | 说明 |
---|---|
task_name | Processor的名字 |
label_list | 指定分类标签 |
do_train | 是否做Fine-tuning,需要train.tsv |
do_eval | 是否运行验证集,需要dev.tsv |
do_serve | 是否导出模型服务 |
data_dir | 数据集的文件夹路径 |
vocab_file | 字典地址,使用bert预训练模型 |
bert_config_file | 配置文件,使用bert预训练模型 |
init_checkpoint | 检查点,使用bert预训练模型 |
max_seq_length | seq长度,设置为64 |
train_batch_size | batch size,设置为32 |
learning_rate | 学习率,设置为2e-5 |
num_train_epochs | epoch设置为3.0 |
output_dir | Fine-tuning后模型的保存地址 |
export_dir | 模型服务的保存地址 |
6.预测及结果评估
1 | python run_classifier.py \ |
- 参数说明
参数 | 说明 |
---|---|
task_name | Processor的名字 |
label_list | 指定分类标签 |
do_predict | 是否运行测试集,需要test.tsv |
data_dir | 数据集的文件夹路径 |
vocab_file | 字典地址,使用bert预训练模型 |
bert_config_file | 配置文件,使用bert预训练模型 |
init_checkpoint | 检查点,修改为Fine-tuning的模型地址 |
max_seq_length | seq长度,设置为64 |
output_dir | 预测结果的保存地址,结果保存在test_results.tsv |
- 结果评估
预测结果保存在output_dir/test_results.tsv
使用sklearn计算test.tsv的评价指标