@@ -25,13 +25,19 @@ Data must be in the `str` format as detailed in the example below:
from pytree.data import prepare_input_from_constituency_tree

parse_tree_example = '(TOP (S (NP (_ I)) (VP (_ saw) (NP (_ Sarah)) (PP (_ with) (NP (_ a) (_ telescope)))) (_ .)))'
- input_test, head_idx_test = prepare_input_from_constituency_tree(parse_tree_example)
+ input_test, head_idx_test, head_idx_r_test, head_idx_l_test = prepare_input_from_constituency_tree(parse_tree_example)

print(input_test)
# ['[CLS]', 'I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.', '[S]', '[S]', '[VP]', '[VP]', '[PP]', '[NP]']

print(head_idx_test)
- # [0, 8, 10, 10, 11, 12, 12, 7, 0, 7, 8, 9, 9, 11]
+ # [0, 9, 11, 11, 12, 13, 13, 8, 0, 8, 9, 10, 10, 12]
+
+ print(head_idx_r_test)
+ # [0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0]
+
+ print(head_idx_l_test)
+ # [0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1]
```
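
As a rough guide to these vectors (an interpretation of the printed values, not documented semantics): `head_idx_test[i]` seems to hold the position of node `i`'s parent in the flattened node list, with 0 marking a root, while `head_idx_r_test` and `head_idx_l_test` appear to flag whether a node fills the right or left child slot of that parent in the binarized tree. A minimal sketch that rebuilds the parent-to-children structure under that assumption:

```python
# Sketch: recover the parent -> children structure from the head indices.
# Assumption (inferred, not documented): head_idx[i] is the parent position
# of node i, 0 marks a root, and head_idx_l flags the left child slot.
from collections import defaultdict

head_idx_test = [0, 9, 11, 11, 12, 13, 13, 8, 0, 8, 9, 10, 10, 12]
head_idx_l_test = [0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1]

children = defaultdict(dict)
for pos, (parent, is_left) in enumerate(zip(head_idx_test, head_idx_l_test)):
    if parent != 0:
        children[parent]['left' if is_left else 'right'] = pos

# Each parent position now maps to its left/right child positions.
print(dict(children))
```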

### Prepare dependency tree data
@@ -68,17 +74,19 @@ from pytree.data.glove_tokenizer import GloveTokenizer
glove_tokenizer = GloveTokenizer(glove_file_path='./glove.6B.300d.txt', vocab_size=10000)
input_test = glove_tokenizer.convert_tokens_to_ids(input_test)
print(input_test)
- # [1, 1, 824, 1, 19, 9, 1, 4, 1, 1, 1, 1, 1, 1]
+ # [1, 1, 824, 1, 19, 9, 1, 4]
```
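
Only words within the top-`vocab_size` entries of the GloVe vocabulary receive dedicated ids here; everything else collapses to a shared index (1 in this output, presumably the unknown-token id), and the updated output keeps ids only for the 8 leaf tokens. A minimal sketch of this kind of lookup, using a hypothetical toy vocabulary instead of the real GloVe file:

```python
# Sketch of an unknown-aware token-to-id lookup with a toy vocabulary.
# The real mapping comes from the GloVe file; ids below are illustrative only.
UNK_ID = 1  # assumption: id 1 is the unknown token, matching the output above
toy_vocab = {'saw': 824, 'with': 19, 'a': 9, '.': 4}

def to_ids(tokens, vocab, unk_id=UNK_ID):
    return [vocab.get(token.lower(), unk_id) for token in tokens]

print(to_ids(['[CLS]', 'I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.'], toy_vocab))
# [1, 1, 824, 1, 19, 9, 1, 4]
```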

Then prepare the data:

```python
- tree_ids_test, tree_ids_test_r, tree_ids_test_l = build_tree_ids_n_ary(head_idx_test)
+ from pytree.data.utils import build_tree_ids_n_ary
+
+ tree_ids_test, tree_ids_test_r, tree_ids_test_l = build_tree_ids_n_ary(head_idx_test, head_idx_r_test, head_idx_l_test)
inputs = {'input_ids': torch.tensor(input_test).unsqueeze(0),
-           'packed_tree': torch.tensor(tree_ids_test).unsqueeze(0),
-           'packed_tree_r': torch.tensor(tree_ids_test_r).unsqueeze(0),
-           'packed_tree_l': torch.tensor(tree_ids_test_l).unsqueeze(0)}
+           'tree_ids': torch.tensor(tree_ids_test).unsqueeze(0),
+           'tree_ids_r': torch.tensor(tree_ids_test_r).unsqueeze(0),
+           'tree_ids_l': torch.tensor(tree_ids_test_l).unsqueeze(0)}
```
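
Each `unsqueeze(0)` simply adds a batch dimension of size 1. Batching several sentences would require padding the id and tree tensors to a common length first; a minimal sketch for the input ids, using a hypothetical second sentence and assuming 0 is usable as a padding id:

```python
import torch

# Two id sequences of different lengths (the second one is hypothetical).
seq_a = [1, 1, 824, 1, 19, 9, 1, 4]
seq_b = [1, 824, 4]

# Right-pad to the longest sequence, assuming 0 works as a padding id.
max_len = max(len(seq_a), len(seq_b))
batch = torch.tensor([s + [0] * (max_len - len(s)) for s in (seq_a, seq_b)])
print(batch.shape)  # torch.Size([2, 8])
```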

And apply the model:
@@ -89,17 +97,19 @@ from pytree.models import NaryConfig, NaryTree
config = NaryConfig()
tree_encoder = NaryTree(config)

- tree_encoder(inputs)
+ (h, c), h_root = tree_encoder(inputs)
+ print(h)
# tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
- #          [ 0.0012,  0.0015, -0.0026, ..., -0.0001,  0.0002, -0.0043],
- #          [ 0.0022,  0.0024, -0.0035, ..., -0.0002,  0.0003, -0.0058],
+ #          [ 0.0113, -0.0066,  0.0089, ...,  0.0064,  0.0076, -0.0048],
+ #          [ 0.0110, -0.0073,  0.0110, ...,  0.0070,  0.0046, -0.0049],
# ...,
- #          [ 0.0028,  0.0023, -0.0035, ..., -0.0002,  0.0003, -0.0057],
- #          [ 0.0020,  0.0016, -0.0023, ..., -0.0001,  0.0002, -0.0036],
- #          [ 0.0019,  0.0015, -0.0024, ..., -0.0001,  0.0002, -0.0039]]],
- #         grad_fn=<MaskedScatterBackward>)
+ #          [ 0.0254, -0.0138,  0.0224, ...,  0.0131,  0.0148, -0.0143],
+ #          [ 0.0346, -0.0172,  0.0281, ...,  0.0140,  0.0198, -0.0267],
+ #          [ 0.0247, -0.0126,  0.0201, ...,  0.0116,  0.0162, -0.0184]]],
+ #         grad_fn=<SWhereBackward>)

- print(tree_encoder(inputs).shape)
- # tree_encoder(inputs).shape
+ print(h_root.shape)
+ # torch.Size([150])
```
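
Since `h_root` is a fixed-size summary of the whole sentence, it can feed a standard classification head. A minimal sketch, assuming a hidden size of 150 (matching the printed shape) and a hypothetical 3-class task:

```python
import torch
import torch.nn as nn

hidden_size = 150  # assumption: matches the torch.Size([150]) printed above
num_classes = 3    # hypothetical downstream task

classifier = nn.Linear(hidden_size, num_classes)

h_root = torch.randn(hidden_size)  # stand-in for the encoder's root state
logits = classifier(h_root)
print(logits.shape)  # torch.Size([3])
```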

+ We also provide a full demonstration with the SICK dataset and batched processing in the [examples folder](https://github.com/AntoineSimoulin/pytree/tree/main/examples).