Commit 7784d7a (1 parent: 20bd6ad)

🎉 working sick example

File tree

11 files changed: +745 −199 lines


.gitignore

Lines changed: 3 additions & 1 deletion
@@ -1 +1,3 @@
-model/
+model/
+*.pyc
+__pycache__/

README.md

Lines changed: 26 additions & 16 deletions
@@ -25,13 +25,19 @@ Data must be in the `str` format as detailed in the example below:
 from pytree.data import prepare_input_from_constituency_tree
 
 parse_tree_example = '(TOP (S (NP (_ I)) (VP (_ saw) (NP (_ Sarah)) (PP (_ with) (NP (_ a) (_ telescope)))) (_ .)))'
-input_test, head_idx_test = prepare_input_from_constituency_tree(parse_tree_example)
+input_test, head_idx_test, head_idx_r_test, head_idx_l_test = prepare_input_from_constituency_tree(parse_tree_example)
 
 print(input_test)
 # ['[CLS]', 'I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.', '[S]', '[S]', '[VP]', '[VP]', '[PP]', '[NP]']
 
 print(head_idx_test)
-# [0, 8, 10, 10, 11, 12, 12, 7, 0, 7, 8, 9, 9, 11]
+# [0, 9, 11, 11, 12, 13, 13, 8, 0, 8, 9, 10, 10, 12]
+
+print(head_idx_r_test)
+# [0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0]
+
+print(head_idx_l_test)
+# [0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1]
 ```
 
 ### Prepare dependency tree data
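The new head indices differ from the old ones by a uniform one-position offset, apparently switching from 0-based to 1-based head positions so that 0 unambiguously marks a root, and each non-root node now carries exactly one of the right/left flags. A quick pure-Python check (the helper name `shift_heads` is ours, not part of pytree):

```python
def shift_heads(head_idx, offset=1):
    # 0 marks a root and stays 0; every other parent pointer moves by `offset`
    return [h + offset if h != 0 else 0 for h in head_idx]

old_heads = [0, 8, 10, 10, 11, 12, 12, 7, 0, 7, 8, 9, 9, 11]
new_heads = [0, 9, 11, 11, 12, 13, 13, 8, 0, 8, 9, 10, 10, 12]
print(shift_heads(old_heads) == new_heads)  # True

# Every non-root node is flagged as exactly one of right/left child:
head_idx_r = [0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0]
head_idx_l = [0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1]
print(all(r + l == 1
          for r, l, h in zip(head_idx_r, head_idx_l, new_heads)
          if h != 0))  # True
```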
@@ -68,17 +74,19 @@ from pytree.data.glove_tokenizer import GloveTokenizer
 glove_tokenizer = GloveTokenizer(glove_file_path='./glove.6B.300d.txt', vocab_size=10000)
 input_test = glove_tokenizer.convert_tokens_to_ids(input_test)
 print(input_test)
-# [1, 1, 824, 1, 19, 9, 1, 4, 1, 1, 1, 1, 1, 1]
+# [1, 1, 824, 1, 19, 9, 1, 4]
 ```
 
 Then prepare the data:
 
 ```python
-tree_ids_test, tree_ids_test_r, tree_ids_test_l = build_tree_ids_n_ary(head_idx_test)
+from pytree.data.utils import build_tree_ids_n_ary
+
+tree_ids_test, tree_ids_test_r, tree_ids_test_l = build_tree_ids_n_ary(head_idx_test, head_idx_r_test, head_idx_l_test)
 inputs = {'input_ids': torch.tensor(input_test).unsqueeze(0),
-          'packed_tree': torch.tensor(tree_ids_test).unsqueeze(0),
-          'packed_tree_r': torch.tensor(tree_ids_test_r).unsqueeze(0),
-          'packed_tree_l': torch.tensor(tree_ids_test_l).unsqueeze(0)}
+          'tree_ids': torch.tensor(tree_ids_test).unsqueeze(0),
+          'tree_ids_r': torch.tensor(tree_ids_test_r).unsqueeze(0),
+          'tree_ids_l': torch.tensor(tree_ids_test_l).unsqueeze(0)}
 ```
 
 And apply the model:
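`build_tree_ids_n_ary` packs the head arrays into tensors the encoder consumes; the exact packing is internal to pytree, but the general idea behind such tree ids (evaluate the tree bottom-up, one whole level per batched step) can be sketched with a hypothetical helper on a toy head array:

```python
def bottom_up_levels(head_idx):
    """Group nodes into levels that can be processed in parallel, deepest first.

    head_idx[i] is the 1-based position of node i+1's head; 0 marks a root.
    Hypothetical illustration only, not pytree's actual encoding.
    """
    n = len(head_idx)
    depth = [0] * n

    def node_depth(i):  # i is 0-based
        if depth[i] == 0:
            h = head_idx[i]
            depth[i] = 1 if h == 0 else node_depth(h - 1) + 1
        return depth[i]

    for i in range(n):
        node_depth(i)
    # Deepest nodes come first: children must be computed before their heads.
    return [[i + 1 for i in range(n) if depth[i] == d]
            for d in range(max(depth), 0, -1)]

# Toy tree: node 3 is the root, nodes 1 and 2 hang off node 3, node 4 off node 1.
print(bottom_up_levels([3, 3, 0, 1]))
# [[4], [1, 2], [3]]
```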
@@ -89,17 +97,19 @@ from pytree.models import NaryConfig, NaryTree
 config = NaryConfig()
 tree_encoder = NaryTree(config)
 
-tree_encoder(inputs)
+(h, c), h_root = tree_encoder(inputs)
+print(h)
 # tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
-#          [ 0.0012,  0.0015, -0.0026,  ..., -0.0001,  0.0002, -0.0043],
-#          [ 0.0022,  0.0024, -0.0035,  ..., -0.0002,  0.0003, -0.0058],
+#          [ 0.0113, -0.0066,  0.0089,  ...,  0.0064,  0.0076, -0.0048],
+#          [ 0.0110, -0.0073,  0.0110,  ...,  0.0070,  0.0046, -0.0049],
 #          ...,
-#          [ 0.0028,  0.0023, -0.0035,  ..., -0.0002,  0.0003, -0.0057],
-#          [ 0.0020,  0.0016, -0.0023,  ..., -0.0001,  0.0002, -0.0036],
-#          [ 0.0019,  0.0015, -0.0024,  ..., -0.0001,  0.0002, -0.0039]]],
-#         grad_fn=<MaskedScatterBackward>)
+#          [ 0.0254, -0.0138,  0.0224,  ...,  0.0131,  0.0148, -0.0143],
+#          [ 0.0346, -0.0172,  0.0281,  ...,  0.0140,  0.0198, -0.0267],
+#          [ 0.0247, -0.0126,  0.0201,  ...,  0.0116,  0.0162, -0.0184]]],
+#         grad_fn=<SWhereBackward>)
 
-print(tree_encoder(inputs).shape)
-# tree_encoder(inputs).shape
+print(h_root.shape)
+# torch.Size([150])
 ```
 
+We also provide a full demonstration with the SICK dataset and batched processing in the [examples folder](https://github.com/AntoineSimoulin/pytree/tree/main/examples).

examples/README.md

Lines changed: 58 additions & 7 deletions
@@ -1,21 +1,72 @@
 Implementation of ([Tai et al., 2015](#tai-2015))
 
+For the Constituency TreeLSTM, you can run the following script:
 
 ```bash
-python pytree/examples/run_sick.py \
-  --glove_file_path ./glove.6B.300d.txt \
+python examples/run_sick_n_ary.py \
+  --glove_file_path glove.840B.300d.txt \
   --do_train \
   --do_eval \
+  --do_predict \
   --output_dir './model' \
   --dataset_name 'sick' \
-  --remove_unused_columns False \
-  --learning_rate 0.025 \
-  --per_device_train_batch_size 25 \
-  --num_train_epochs 20
+  --remove_unused_columns false \
+  --learning_rate 0.05 \
+  --per_device_train_batch_size 25 \
+  --num_train_epochs 10 \
+  --weight_decay 1e-4 \
+  --lr_scheduler_type constant \
+  --overwrite_cache false \
+  --overwrite_output_dir \
+  --evaluation_strategy epoch
 ```
 
+You should get the following results:
+
+```bash
+***** predict metrics *****
+  predict_samples         = 4906
+  test_loss               = 0.6236
+  test_mse                = 31.8074
+  test_pearson            = 83.2404
+  test_runtime            = 0:00:13.02
+  test_samples_per_second = 376.716
+  test_spearman           = 77.1604
+  test_steps_per_second   = 47.147
 ```
-CUDA_VISIBLE_DEVICES=2 python examples/run_sick.py --glove_file_path /data/asimouli/GLOVE/glove.6B.300d.txt --do_train --do_eval --output_dir './model' --dataset_name 'sick' --remove_unused_columns False --learning_rate 0.05 --per_device_train_batch_size 25 --num_train_epochs 15 --weight_decay 1e-4 --lr_scheduler_type constant --do_predict --overwrite_cache True --overwrite_output_dir
+
+For the Dependency TreeLSTM, you can run the following script:
+
+```bash
+python examples/run_sick_child_sum.py \
+  --glove_file_path glove.840B.300d.txt \
+  --do_train \
+  --do_eval \
+  --do_predict \
+  --output_dir './model' \
+  --dataset_name 'sick' \
+  --remove_unused_columns false \
+  --learning_rate 0.05 \
+  --per_device_train_batch_size 25 \
+  --num_train_epochs 5 \
+  --weight_decay 1e-4 \
+  --lr_scheduler_type constant \
+  --overwrite_cache true \
+  --overwrite_output_dir
+```
+
+You should get the following results:
+
+```bash
+***** predict metrics *****
+  predict_samples         = 4906
+  test_loss               = 0.5228
+  test_mse                = 26.4252
+  test_pearson            = 86.3953
+  test_runtime            = 0:00:05.74
+  test_samples_per_second = 854.158
+  test_spearman           = 80.3738
+  test_steps_per_second   = 106.9
 ```
 
 ## References
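The `test_pearson` and `test_spearman` figures above are the standard SICK relatedness metrics, correlation coefficients scaled by 100. A minimal, dependency-free sketch of the Pearson computation on toy scores (the helper name is ours):

```python
import math

def pearson(xs, ys):
    # Sample Pearson correlation between gold and predicted score lists.
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

gold = [1.0, 2.0, 3.0, 4.0, 5.0]   # toy relatedness labels
pred = [1.2, 1.9, 3.4, 3.8, 5.1]   # toy model outputs
print(100 * pearson(gold, pred))   # close to 100 for near-perfect predictions
```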
