本文介绍基于Bert预训练模型Fine-tuning实现文本分类任务
1.项目简介
在小米AI实验室NLP平台组内容过滤的工作,记录于此,已删去敏感信息,本文介绍基于Bert预训练模型Fine-tuning实现文本分类任务。
2.运行环境
1 | pip install -r requirements.txt |
Tensorflow>=1.11.0
Python2和Python3均可运行
3.准备数据
- train.tsv:训练集,对应do_train参数
- dev.tsv:验证集,对应do_eval参数
- test.tsv:测试集,对应do_predict参数
格式均为Label+Tab+Text,如下所示1
2
3
42 宝妈********
2 装修污染惹不注意男性**出问题
1 曹操家选媳妇********
0 一对夫妻打架 妻子被丈夫打得身上都是灰
5.模型训练及模型服务导出
1 | python run_classifier.py \ |
- 参数说明
| 参数 | 说明 |
|---|---|
| task_name | Processor的名字 |
| label_list | 指定分类标签 |
| do_train | 是否做Fine-tuning,需要train.tsv |
| do_eval | 是否运行验证集,需要dev.tsv |
| do_serve | 是否导出模型服务 |
| data_dir | 数据集的文件夹路径 |
| vocab_file | 字典地址,使用bert预训练模型 |
| bert_config_file | 配置文件,使用bert预训练模型 |
| init_checkpoint | 检查点,使用bert预训练模型 |
| max_seq_length | seq长度,设置为64 |
| train_batch_size | batch size,设置为32 |
| learning_rate | 学习率,设置为2e-5 |
| num_train_epochs | epoch设置为3.0 |
| output_dir | Fine-tuning后模型的保存地址 |
| export_dir | 模型服务的保存地址 |
6.预测及结果评估
1 | python run_classifier.py \ |
- 参数说明
| 参数 | 说明 |
|---|---|
| task_name | Processor的名字 |
| label_list | 指定分类标签 |
| do_predict | 是否运行测试集,需要test.tsv |
| data_dir | 数据集的文件夹路径 |
| vocab_file | 字典地址,使用bert预训练模型 |
| bert_config_file | 配置文件,使用bert预训练模型 |
| init_checkpoint | 检查点,修改为Fine-tuning的模型地址 |
| max_seq_length | seq长度,设置为64 |
| output_dir | 预测结果的保存地址,结果保存在test_results.tsv |
- 结果评估
预测结果保存在output_dir/test_results.tsv
使用sklearn计算test.tsv的评价指标