Commit 20bd6ad

:bugs: fix some bugs
1 parent 7abf012 commit 20bd6ad

10 files changed: +304 -26 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+model/

README.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 <p>
 
 **PyTree** implements tree-structured neural networks in PyTorch.
-The package provides highly generic recursive neural network implementations as well as efficient batching methods.
+The package provides highly generic tree-structured neural network implementations as well as efficient batching methods.
 
 ## Installation
 

examples/README.md

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,10 @@ python pytree/examples/run_sick.py \
     --num_train_epochs 20
 ```
 
+```
+CUDA_VISIBLE_DEVICES=2 python examples/run_sick.py --glove_file_path /data/asimouli/GLOVE/glove.6B.300d.txt --do_train --do_eval --output_dir './model' --dataset_name 'sick' --remove_unused_columns False --learning_rate 0.05 --per_device_train_batch_size 25 --num_train_epochs 15 --weight_decay 1e-4 --lr_scheduler_type constant --do_predict --overwrite_cache True --overwrite_output_dir
+```
+
 ## References
 
 > <div id="tai-2015">Kai Sheng Tai, Richard Socher, Christopher D. Manning <a href=https://aclanthology.org/P15-1150>Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks.</a> ACL (1) 2015: 1556-1566</div>

examples/main.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
+import logging
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import datasets
+from datasets import load_dataset, load_metric
+
+import transformers
+from transformers import (
+    Trainer,
+    EvalPrediction,
+    HfArgumentParser,
+    TrainingArguments,
+    set_seed,
+    default_data_collator,
+)
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version
+from transformers.utils.versions import require_version
+# from utils_qa import postprocess_qa_predictions
+
+from pytree import (
+    NaryConfig,
+    NaryTree,
+    ChildSumConfig,
+    ChildSumTree,
+    GloveTokenizer,
+    Similarity,
+    SimilarityConfig
+)
+from pytree.data import prepare_input_from_constituency_tree, prepare_input_from_dependency_tree
+from pytree.data.utils import build_tree_ids_n_ary
+
+from supar import Parser
+import torch
+import numpy as np
+import math
+from sklearn.metrics import mean_squared_error
+from scipy.stats import pearsonr, spearmanr
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.11.0")  # 4.12.0.dev0
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
+
+logger = logging.getLogger(__name__)
+
+class SickTrainer(Trainer):
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
+        """
+        self.optimizer = torch.optim.Adagrad(self.model.parameters(), lr=0.025, weight_decay=self.args.weight_decay)
+
+        # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+        #     self.optimizer = OSS(
+        #         params=optimizer_grouped_parameters,
+        #         optim=optimizer_cls,
+        #         **optimizer_kwargs,
+        #     )
+        # else:
+        #     self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+        # if is_sagemaker_mp_enabled():
+        #     self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
+
+
+con = Parser.load('crf-con-en')
+glove_tokenizer = GloveTokenizer(glove_file_path='/data/asimouli/GLOVE/glove.6B.300d.txt', vocab_size=10000)
+
+config = NaryConfig()
+encoder = NaryTree(config)
+encoder.embeddings.load_pretrained_embeddings(
+    torch.tensor(glove_tokenizer.embeddings_arr))
+config_similarity = SimilarityConfig()
+model = Similarity(encoder, config_similarity)
+
+raw_datasets = load_dataset('sick')
+column_names = raw_datasets["train"].column_names
+
+def map_label_to_target(label, num_classes):
+    target = [0] * num_classes  # torch.zeros(1, num_classes, dtype=torch.float)
+    ceil = int(math.ceil(label))
+    floor = int(math.floor(label))
+    if ceil == floor:
+        target[floor - 1] = 1
+    else:
+        target[floor - 1] = ceil - label
+        target[ceil - 1] = label - floor
+    return target
+
+def prepare_train_features(examples):
+    examples['input_ids_A'] = []
+    examples['input_ids_B'] = []
+    examples['head_idx_A'] = []
+    examples['head_idx_B'] = []
+    examples['labels'] = []
+
+    for sent_A in examples['sentence_A']:
+        con_tree_A = str(con.predict(sent_A.split(), verbose=False)[0])
+        input_ids_A, head_idx_A = prepare_input_from_constituency_tree(con_tree_A)
+        input_ids_A = glove_tokenizer.convert_tokens_to_ids(input_ids_A)
+        examples['input_ids_A'].append(input_ids_A)
+        examples['head_idx_A'].append(head_idx_A)
+
+    for sent_B in examples['sentence_B']:
+        con_tree_B = str(con.predict(sent_B.split(), verbose=False)[0])
+        input_ids_B, head_idx_B = prepare_input_from_constituency_tree(con_tree_B)
+        input_ids_B = glove_tokenizer.convert_tokens_to_ids(input_ids_B)
+        examples['input_ids_B'].append(input_ids_B)
+        examples['head_idx_B'].append(head_idx_B)
+
+    for rel_score in examples['relatedness_score']:
+        examples['labels'].append(map_label_to_target(rel_score, 5))
+
+    return examples
+
+training_args = TrainingArguments(
+    learning_rate=0.025,
+    per_device_train_batch_size=25,
+    num_train_epochs=20,
+    weight_decay=1e-4,
+    lr_scheduler_type='constant',
+    output_dir="/home/asimouli/PhD/PyTree/pytree_remote/model",
+    do_train=True,
+    do_eval=True,
+    remove_unused_columns=False)
+
+train_examples = raw_datasets["train"]
+with training_args.main_process_first(desc="train dataset map pre-processing"):
+    train_dataset = train_examples.map(
+        prepare_train_features,
+        batched=True,
+        num_proc=None,
+        remove_columns=None,
+        load_from_cache_file=True,
+        desc="Running parser on train dataset",
+    )
+
+# # Validation preprocessing
+
+eval_examples = raw_datasets["validation"]
+eval_dataset = eval_examples.map(
+    prepare_train_features,
+    batched=True,
+    num_proc=None,
+    remove_columns=None,  # column_names,
+    desc="Running parser on validation dataset",
+)
+
+def data_collator_with_padding(features, pad_ids=0, columns=None):
+    batch = {}
+    first = features[0]
+    if columns is None:
+        columns = ["head_idx_A", "head_idx_B", "input_ids_A", "input_ids_B"]
+    feature_max_len = {k: max([len(f[k]) for f in features]) for k in first.keys() if k in columns or len(columns) == 0}
+    for k, v in first.items():
+        if k in columns or len(columns) == 0:
+            feature_padded = [list([int(ff) for ff in f[k]]) + [0] * (feature_max_len[k] - len(f[k])) for f in features]
+            batch[k] = feature_padded  # [f[k] for f in features]
+    tree_ids_A, tree_ids_r_A, tree_ids_l_A = build_tree_ids_n_ary(batch['head_idx_A'])
+    tree_ids_B, tree_ids_r_B, tree_ids_l_B = build_tree_ids_n_ary(batch['head_idx_B'])
+    batch['input_ids_A'] = torch.tensor(batch['input_ids_A'])
+    batch['input_ids_B'] = torch.tensor(batch['input_ids_B'])
+    batch['tree_ids_A'] = torch.tensor(tree_ids_A)
+    batch['tree_ids_B'] = torch.tensor(tree_ids_B)
+    batch['tree_ids_r_A'] = torch.tensor(tree_ids_r_A)
+    batch['tree_ids_r_B'] = torch.tensor(tree_ids_r_B)
+    batch['tree_ids_l_A'] = torch.tensor(tree_ids_l_A)
+    batch['tree_ids_l_B'] = torch.tensor(tree_ids_l_B)
+    batch['labels'] = torch.tensor([f['labels'] for f in features])
+    return batch
+
+data_collator = data_collator_with_padding
+
+def compute_metrics(eval_prediction):
+    prediction = np.matmul(np.exp(eval_prediction.predictions), np.arange(1, 5 + 1))
+    target = np.matmul(eval_prediction.label_ids, np.arange(1, 5 + 1))
+    results_relatedness = {
+        'pearson': pearsonr(prediction, target)[0] * 100,
+        'spearman': spearmanr(prediction, target)[0] * 100,
+        'mse': mean_squared_error(prediction, target) * 100
+    }
+    return results_relatedness
+
+trainer = SickTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
+    optimizers=("Adagrad", None),
+)
+
+# Training
+
+train_result = trainer.train(resume_from_checkpoint=None)
+trainer.save_model()  # Saves the tokenizer too for easy upload
+
+metrics = train_result.metrics
+max_train_samples = len(train_dataset)
+metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+trainer.log_metrics("train", metrics)
+trainer.save_metrics("train", metrics)
+trainer.save_state()
+
+
+logger.info("*** Evaluate ***")
+metrics = trainer.evaluate()
+
+max_eval_samples = len(eval_dataset)
+metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
+trainer.log_metrics("eval", metrics)
+trainer.save_metrics("eval", metrics)

examples/run_sick.py

Lines changed: 32 additions & 5 deletions
@@ -67,6 +67,33 @@
 
 logger = logging.getLogger(__name__)
 
+class SickTrainer(Trainer):
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
+        """
+        self.optimizer = torch.optim.Adagrad(self.model.parameters(),
+                                             lr=self.args.learning_rate,
+                                             weight_decay=self.args.weight_decay)
+
+        # if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+        #     self.optimizer = OSS(
+        #         params=optimizer_grouped_parameters,
+        #         optim=optimizer_cls,
+        #         **optimizer_kwargs,
+        #     )
+        # else:
+        #     self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+        # if is_sagemaker_mp_enabled():
+        #     self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
+
 
 @dataclass
 class ModelArguments:
@@ -278,7 +305,7 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         raw_datasets = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False,
         )
     else:
         data_files = {}
@@ -303,7 +330,7 @@ def main():
     # download model & vocab.
     # dep = Parser.load('biaffine-dep-en')
     con = Parser.load('crf-con-en')
-    glove_tokenizer = GloveTokenizer(glove_file_path=data_args.glove_file_path, vocab_size=10000)
+    glove_tokenizer = GloveTokenizer(glove_file_path=data_args.glove_file_path, vocab_size=100000)
     # config = ChildSumConfig()
     # encoder = ChildSumTree(config)
     config = NaryConfig()
@@ -577,15 +604,15 @@ def compute_metrics(eval_prediction):
        # return metric.compute(predictions=p.predictions, references=p.label_ids)
 
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = SickTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_dataset=predict_dataset if training_args.do_eval else None,  # eval_dataset
         # eval_examples=eval_examples if training_args.do_eval else None,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
-        # optimizers=(torch.optim.Adagrad(model.parameters(), weight_decay=1e-4), None),
+        optimizers=("Adagrad", None),
     )
     # trainer = QuestionAnsweringTrainer(
     #     model=model,
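
Editor's note (not part of the commit): with this change the Adagrad hyper-parameters come from `TrainingArguments` (the README command sets `--learning_rate 0.05` and `--weight_decay 1e-4`) rather than being hard-coded. A throwaway sketch of the optimizer the override builds, using a placeholder linear module instead of the real `Similarity` model:

```python
import torch

toy_model = torch.nn.Linear(300, 5)  # stand-in for the Similarity model
optimizer = torch.optim.Adagrad(toy_model.parameters(), lr=0.05, weight_decay=1e-4)
print(optimizer.defaults["lr"], optimizer.defaults["weight_decay"])  # 0.05 0.0001
```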

pytree/data/utils.py

Lines changed: 14 additions & 5 deletions
@@ -82,11 +82,14 @@ def get_nodes_detph(head_idx):
     return depth
 
 
+# def pad_tree_ids(tree_ids, depth):
+#     tree_depth = tree_ids.shape[0]
+#     padding = np.zeros((max(depth - tree_depth, 0), tree_ids.shape[1]), dtype=tree_ids.dtype)
+#     return np.concatenate((tree_ids, padding), axis=0)
 def pad_tree_ids(tree_ids, depth):
     tree_depth = tree_ids.shape[0]
     padding = np.zeros((max(depth - tree_depth, 0), tree_ids.shape[1]), dtype=tree_ids.dtype)
-    return np.concatenate((tree_ids, padding), axis=0)
-
+    return np.concatenate((padding, tree_ids), axis=0)
 
 # def build_tree_ids(head_idx):
 #     if isinstance(head_idx[0], list):
@@ -125,12 +128,18 @@ def build_tree_ids_n_ary(head_idx):
                np.array([pad_tree_ids(t[2], depth) for t in tree_ids])
     tree_ids = []
     node_idx = [get_root(head_idx)]
+    # while len(node_idx) > 0:
+    #     node_idx = get_childrens(node_idx, head_idx)
+    #     tree_step = [h_idx if idx in node_idx else 0 for idx, h_idx in enumerate(head_idx)]
+    #     tree_ids.append(tree_step)
+    # tree_ids = tree_ids[:-1]
+    # tree_ids.append(range(0, len(head_idx)))
     while len(node_idx) > 0:
         node_idx = get_childrens(node_idx, head_idx)
         tree_step = [h_idx if idx in node_idx else 0 for idx, h_idx in enumerate(head_idx)]
-        tree_ids.append(tree_step)
-    tree_ids = tree_ids[:-1]
-    tree_ids.append(range(0, len(head_idx)))
+        tree_ids.insert(0, tree_step)
+    tree_ids = tree_ids[1:]
+    tree_ids.insert(0, range(0, len(head_idx)))
     tree_ids_r = [[t if (i % 2 == 0) else 0 for (i, t) in enumerate(ti)] for ti in tree_ids]
     tree_ids_d = [[t if (i % 2 == 1) else 0 for (i, t) in enumerate(ti)] for ti in tree_ids]
     return np.array(tree_ids), np.array(tree_ids_r), np.array(tree_ids_d)
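
Editor's note (not part of the commit): the fix flips where `pad_tree_ids` places its zero rows, so trees shallower than the batch depth are now padded at the top instead of the bottom. A small stand-alone check with a toy array:

```python
import numpy as np

def pad_tree_ids(tree_ids, depth):
    # behaviour after this commit: zero rows are prepended
    tree_depth = tree_ids.shape[0]
    padding = np.zeros((max(depth - tree_depth, 0), tree_ids.shape[1]), dtype=tree_ids.dtype)
    return np.concatenate((padding, tree_ids), axis=0)

steps = np.array([[0, 1, 2],
                  [1, 0, 0]])
print(pad_tree_ids(steps, 4))
# [[0 0 0]
#  [0 0 0]
#  [0 1 2]
#  [1 0 0]]
```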
