|
2 | 2 | GPLinker_pytorch
|
3 | 3 |
|
4 | 4 | # 介绍
|
5 |
| -这是pytorch版本的GPLinker,该代码执行效率可能有点慢,主要瓶颈应该在datagenerator部分,之后可能会使用dataloader加载数据。 |
6 |
| -本仓库主要参考了[苏神博客](https://kexue.fm/archives/8888)和[他的keras版本代码](https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction_gplinker.py) |
| 5 | +这是pytorch版本的`GPLinker`代码以及`TPLinker_Plus`代码。 |
| 6 | +- `GPLinker`主要参考了[苏神博客](https://kexue.fm/archives/8888)和[他的keras版本代码](https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction_gplinker.py) |
| 7 | +- `TPLinker_Plus`主要参考了[原版代码](https://github.com/131250208/TPlinker-joint-extraction/tree/master/tplinker_plus) |
| 8 | +- 其中`TPLinker_Plus`代码在模型部分可能有点区别。 |
| 9 | + |
| 10 | +# 更新 |
| 11 | +- 2022/02/25 现已在Dev分支更新最新的huggingface全家桶版本的代码,main分支是之前旧的代码(执行效率慢) |
7 | 12 |
|
8 | 13 | # 依赖
|
9 |
| -- transformers |
10 |
| -- torch |
| 14 | +所需的依赖如下: |
| 15 | +- fastcore==1.3.29 |
| 16 | +- datasets==1.18.3 |
| 17 | +- transformers>=4.16.2 |
| 18 | +- accelerate==0.5.1 |
| 19 | +- chinesebert==0.2.1 |
| 20 | +安装依赖requirements.txt |
| 21 | +```bash |
| 22 | +pip install -r requirements.txt |
| 23 | +``` |
| 24 | +# 准备数据 |
| 25 | +从 http://ai.baidu.com/broad/download?dataset=sked 下载数据。 |
| 26 | +将`train_data.json`和`dev_data.json`压缩成`spo.zip`文件,并且放入`data`文件夹。 |
| 27 | +当前`data/spo.zip`文件是本人提供精简后的数据集,其中`train_data.json`只有1000条数据,`dev_data.json`只有100条数据。 |
11 | 28 |
|
12 | 29 | # 运行
|
13 | 30 | ```bash
|
14 |
| -python run.py |
15 |
| -``` |
16 |
| -可修改文件中的参数。 |
17 |
| -```python |
18 |
| -efficient = False # 是否使用EfficientGlobalpointer |
19 |
| -epochs = 20 |
20 |
| -maxlen = 128 |
21 |
| -batch_size = 16 |
22 |
| -weight_decay = 0.01 |
23 |
| -lr = 3e-5 |
24 |
| -dict_path = "./chinese-roberta-wwm-ext/vocab.txt" # 预训练模型vocab.txt路径 |
25 |
| -model_name_or_path = "hfl/chinese-roberta-wwm-ext" # 预训练模型权重路径 |
| 31 | +accelerate launch train.py \ |
| 32 | + --model_type bert \ |
| 33 | + --pretrained_model_name_or_path bert-base-chinese \ |
| 34 | + --method gplinker \ |
| 35 | + --logging_steps 200 \ |
| 36 | + --num_train_epochs 20 \ |
| 37 | + --learning_rate 3e-5 \ |
| 38 | + --num_warmup_steps_or_radios 0.1 \ |
| 39 | + --gradient_accumulation_steps 1 \ |
| 40 | + --per_device_train_batch_size 16 \ |
| 41 | + --per_device_eval_batch_size 32 \ |
| 42 | + --seed 42 \ |
| 43 | + --save_steps 10804 \ |
| 44 | + --output_dir ./outputs \ |
| 45 | + --max_length 128 \ |
| 46 | + --topk 1 \ |
| 47 | + --num_workers 6 |
26 | 48 | ```
|
| 49 | +其中使用到参数介绍如下: |
| 50 | +- `model_type`: 表示模型架构类型,像`bert-base-chinese`、`hfl/chinese-roberta-wwm-ext`模型都是基于`bert`架构,`junnyu/roformer_chinese_char_base`是基于`roformer`架构,可选择`["bert", "roformer", "chinesebert"]`。 |
| 51 | +- `pretrained_model_name_or_path`: 表示加载的预训练模型权重,可以是本地目录,也可以是`huggingface.co`的路径。 |
| 52 | +- `method`: 表示使用的方法, 可选择`["gplinker", "tplinker_plus"]` |
| 53 | +- `logging_steps`: 日志打印的间隔,默认为`200`。 |
| 54 | +- `num_train_epochs`: 训练轮数,默认为`20`。 |
| 55 | +- `learning_rate`: 学习率,默认为`3e-5`。 |
| 56 | +- `num_warmup_steps_or_radios`: `warmup`步数或者比率,当为`浮点类型`时候表示的是`radio`,当为`整型`时候表示的是`step`,默认为`0.1`。 |
| 57 | +- `gradient_accumulation_steps`: 梯度累计的步数,默认为`1`。 |
| 58 | +- `per_device_train_batch_size`: 训练的batch_size,默认为`16`。 |
| 59 | +- `per_device_eval_batch_size`: 评估的batch_size,默认为`32`。 |
| 60 | +- `seed`: 随机种子,以便于复现,默认为`42`。 |
| 61 | +- `save_steps`: 保存步数,每隔多少步保存模型。 |
| 62 | +- `output_dir`: 模型输出路径。 |
| 63 | +- `max_length`: 句子的最大长度,当大于这个长度时候,`tokenizer`会进行截断处理。 |
| 64 | +- `topk`: 保存`topk`个数模型,默认为`1`。 |
| 65 | +- `num_workers`: `dataloader`的`num_workers`参数,`linux`系统下发现`GPU`使用率不高的时候可以尝试设置这个参数大于`0`,而`windows`下最好设置为`0`,不然会报错。 |
| 66 | + |
27 | 67 |
|
28 | 68 | # 结果
|
29 |
| -```bash |
30 |
| -#Epoch 1 -- f1 : 0.6860628101678229, precision : 0.8071369146660334, recall : 0.5965740639068857 |
31 |
| -================================================== |
32 |
| -#Epoch 2 -- f1 : 0.7958765733219821, precision : 0.8358021409757634, recall : 0.759591523004283 |
33 |
| -================================================== |
34 |
| -#Epoch 3 -- f1 : 0.8108524322855335, precision : 0.836991841410104, recall : 0.7862962556275397 |
35 |
| -================================================== |
36 |
| -#Epoch 4 -- f1 : 0.8135480049539608, precision : 0.798203719357566, recall : 0.8294937959811138 |
37 |
| -================================================== |
38 |
| -#Epoch 5 -- f1 : 0.8235228089080464, precision : 0.8422611530778595, recall : 0.8056000878445156 |
39 |
| -================================================== |
40 |
| -#Epoch 6 -- f1 : 0.825081844489883, precision : 0.8398234919479578, recall : 0.8108487976281985 |
41 |
| -================================================== |
42 |
| -``` |
| 69 | +Tips: `gplinker`在`RTX3090`条件下要训练`5-6h`。 |
| 70 | +| method | pretrained_model_name_or_path | f1 | precision | recall | |
| 71 | +| -------- | ----------------------------- | ------------------ | ------------------ | ------------------ | |
| 72 | +| gplinker | hfl/chinese-roberta-wwm-ext | 0.8214065255731926 | 0.8250077498782166 | 0.8178366038895478 | |
| 73 | +| gplinker | bert-base-chinese | 0.8198087178424598 | 0.8146470447994109 | 0.8250362175688137 | |
| 74 | + |
| 75 | +# Tensorboard日志 |
| 76 | +<p align="center"> |
| 77 | + <img src="figure/tensorboard_log.jpg" width="100%" /> |
| 78 | +</p> |
0 commit comments