基于Bert的低俗内容过滤

本文介绍基于Bert预训练模型Fine-tuning实现文本分类任务

1.项目简介

在小米AI实验室NLP平台组内容过滤的工作,记录于此,已删去敏感信息,本文介绍基于Bert预训练模型Fine-tuning实现文本分类任务。

2.运行环境

1
pip install -r requirements.txt

Tensorflow>=1.11.0

Python2和Python3均可运行

3.准备数据

  • train.tsv:训练集,对应do_train参数
  • dev.tsv:验证集,对应do_eval参数
  • test.tsv:测试集,对应do_predict参数

格式均为Label+Tab+Text,如下所示

1
2
3
4
2	宝妈********
2 装修污染惹不注意男性**出问题
1 曹操家选媳妇********
0 一对夫妻打架 妻子被丈夫打得身上都是灰

5.模型训练及模型服务导出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
python run_classifier.py \
--task_name=multi \
--label_list=0/1/2 \
--do_train=true \
--do_eval=true \
--do_serve=true \
--data_dir=/porn_text
--vocab_file=/chinese_L-12_H-768_A-12/vocab.txt \
--bert_config_file=/chinese_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=/chinese_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=64 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=/porn_text/output \
--export_dir=/porn_text/output/export
  • 参数说明
参数 说明
task_name Processor的名字
label_list 指定分类标签
do_train 是否做Fine-tuning,需要train.tsv
do_eval 是否运行验证集,需要dev.tsv
do_serve 是否导出模型服务
data_dir 数据集的文件夹路径
vocab_file 字典地址,使用bert预训练模型
bert_config_file 配置文件,使用bert预训练模型
init_checkpoint 检查点,使用bert预训练模型
max_seq_length seq长度,设置为64
train_batch_size batch size,设置为32
learning_rate 学习率,设置为2e-5
num_train_epochs epoch设置为3.0
output_dir Fine-tuning后模型的保存地址
export_dir 模型服务的保存地址

6.预测及结果评估

1
2
3
4
5
6
7
8
9
10
python run_classifier.py \
--task_name=multi \
--do_predict=true \
--label_list=0/1/2 \
--data_dir=/porn_text \
--vocab_file=/chinese_L-12_H-768_A-12/vocab.txt \
--bert_config_file=/chinese_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=/porn_text/output \
--max_seq_length=64 \
--output_dir=/porn_text/predict_result
  • 参数说明
参数 说明
task_name Processor的名字
label_list 指定分类标签
do_predict 是否运行测试集,需要test.tsv
data_dir 数据集的文件夹路径
vocab_file 字典地址,使用bert预训练模型
bert_config_file 配置文件,使用bert预训练模型
init_checkpoint 检查点,修改为Fine-tuning的模型地址
max_seq_length seq长度,设置为64
output_dir 预测结果的保存地址,结果保存在test_results.tsv
  • 结果评估

预测结果保存在output_dir/test_results.tsv

使用sklearn计算test.tsv的评价指标

7.模型服务部署

参考 https://www.tensorflow.org/tfx/guide/serving