Skip to content

Commit b5839fd

Browse files
committed
update new code
1 parent ff679b1 commit b5839fd

28 files changed

+3608
-21841
lines changed

.gitignore

Lines changed: 0 additions & 131 deletions
This file was deleted.

README.md

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,77 @@
22
GPLinker_pytorch
33

44
# 介绍
5-
这是pytorch版本的GPLinker,该代码执行效率可能有点慢,主要瓶颈应该在datagenerator部分,之后可能会使用dataloader加载数据。
6-
本仓库主要参考了[苏神博客](https://kexue.fm/archives/8888)[他的keras版本代码](https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction_gplinker.py)
5+
这是pytorch版本的`GPLinker`代码以及`TPLinker_Plus`代码。
6+
- `GPLinker`主要参考了[苏神博客](https://kexue.fm/archives/8888)[他的keras版本代码](https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction_gplinker.py)
7+
- `TPLinker_Plus`主要参考了[原版代码](https://github.com/131250208/TPlinker-joint-extraction/tree/master/tplinker_plus)
8+
- 其中`TPLinker_Plus`代码在模型部分可能有点区别。
9+
10+
# 更新
11+
- 2022/02/25 现已在Dev分支更新最新的huggingface全家桶版本的代码,main分支是之前旧的代码(执行效率慢)
712

813
# 依赖
9-
- transformers
10-
- torch
14+
所需的依赖如下:
15+
- fastcore==1.3.29
16+
- datasets==1.18.3
17+
- transformers>=4.16.2
18+
- accelerate==0.5.1
19+
- chinesebert==0.2.1
20+
安装依赖requirements.txt
21+
```bash
22+
pip install -r requirements.txt
23+
```
24+
# 准备数据
25+
http://ai.baidu.com/broad/download?dataset=sked 下载数据。
26+
`train_data.json``dev_data.json`压缩成`spo.zip`文件,并且放入`data`文件夹。
27+
当前`data/spo.zip`文件是本人提供精简后的数据集,其中`train_data.json`只有1000条数据,`dev_data.json`只有100条数据。
1128

1229
# 运行
1330
```bash
14-
python run.py
15-
```
16-
可修改文件中的参数。
17-
```python
18-
efficient = False # 是否使用EfficientGlobalpointer
19-
epochs = 20
20-
maxlen = 128
21-
batch_size = 16
22-
weight_decay = 0.01
23-
lr = 3e-5
24-
dict_path = "./chinese-roberta-wwm-ext/vocab.txt" # 预训练模型vocab.txt路径
25-
model_name_or_path = "hfl/chinese-roberta-wwm-ext" # 预训练模型权重路径
31+
accelerate launch train.py \
32+
--model_type bert \
33+
--pretrained_model_name_or_path bert-base-chinese \
34+
--method gplinker \
35+
--logging_steps 200 \
36+
--num_train_epochs 20 \
37+
--learning_rate 3e-5 \
38+
--num_warmup_steps_or_radios 0.1 \
39+
--gradient_accumulation_steps 1 \
40+
--per_device_train_batch_size 16 \
41+
--per_device_eval_batch_size 32 \
42+
--seed 42 \
43+
--save_steps 10804 \
44+
--output_dir ./outputs \
45+
--max_length 128 \
46+
--topk 1 \
47+
--num_workers 6
2648
```
49+
其中使用到参数介绍如下:
50+
- `model_type`: 表示模型架构类型,像`bert-base-chinese``hfl/chinese-roberta-wwm-ext`模型都是基于`bert`架构,`junnyu/roformer_chinese_char_base`是基于`roformer`架构,可选择`["bert", "roformer", "chinesebert"]`
51+
- `pretrained_model_name_or_path`: 表示加载的预训练模型权重,可以是本地目录,也可以是`huggingface.co`的路径。
52+
- `method`: 表示使用的方法, 可选择`["gplinker", "tplinker_plus"]`
53+
- `logging_steps`: 日志打印的间隔,默认为`200`
54+
- `num_train_epochs`: 训练轮数,默认为`20`
55+
- `learning_rate`: 学习率,默认为`3e-5`
56+
- `num_warmup_steps_or_radios`: `warmup`步数或者比率,当为`浮点类型`时候表示的是`radio`,当为`整型`时候表示的是`step`,默认为`0.1`
57+
- `gradient_accumulation_steps`: 梯度累计的步数,默认为`1`
58+
- `per_device_train_batch_size`: 训练的batch_size,默认为`16`
59+
- `per_device_eval_batch_size`: 评估的batch_size,默认为`32`
60+
- `seed`: 随机种子,以便于复现,默认为`42`
61+
- `save_steps`: 保存步数,每隔多少步保存模型。
62+
- `output_dir`: 模型输出路径。
63+
- `max_length`: 句子的最大长度,当大于这个长度时候,`tokenizer`会进行截断处理。
64+
- `topk`: 保存`topk`个数模型,默认为`1`
65+
- `num_workers`: `dataloader``num_workers`参数,`linux`系统下发现`GPU`使用率不高的时候可以尝试设置这个参数大于`0`,而`windows`下最好设置为`0`,不然会报错。
66+
2767

2868
# 结果
29-
```bash
30-
#Epoch 1 -- f1 : 0.6860628101678229, precision : 0.8071369146660334, recall : 0.5965740639068857
31-
==================================================
32-
#Epoch 2 -- f1 : 0.7958765733219821, precision : 0.8358021409757634, recall : 0.759591523004283
33-
==================================================
34-
#Epoch 3 -- f1 : 0.8108524322855335, precision : 0.836991841410104, recall : 0.7862962556275397
35-
==================================================
36-
#Epoch 4 -- f1 : 0.8135480049539608, precision : 0.798203719357566, recall : 0.8294937959811138
37-
==================================================
38-
#Epoch 5 -- f1 : 0.8235228089080464, precision : 0.8422611530778595, recall : 0.8056000878445156
39-
==================================================
40-
#Epoch 6 -- f1 : 0.825081844489883, precision : 0.8398234919479578, recall : 0.8108487976281985
41-
==================================================
42-
```
69+
Tips: `gplinker``RTX3090`条件下要训练`5-6h`
70+
| method | pretrained_model_name_or_path | f1 | precision | recall |
71+
| -------- | ----------------------------- | ------------------ | ------------------ | ------------------ |
72+
| gplinker | hfl/chinese-roberta-wwm-ext | 0.8214065255731926 | 0.8250077498782166 | 0.8178366038895478 |
73+
| gplinker | bert-base-chinese | 0.8198087178424598 | 0.8146470447994109 | 0.8250362175688137 |
74+
75+
# Tensorboard日志
76+
<p align="center">
77+
<img src="figure/tensorboard_log.jpg" width="100%" />
78+
</p>

0 commit comments

Comments
 (0)