This repository was archived by the owner on Apr 8, 2025. It is now read-only.

Commit 47dfa4f

Multi Task Learning implementation and example (#778)
* Start working on #408
* Add example file for simple MTL
* Add task_name to formatted predictions
* Should not have been added.
1 parent 84f67e0 commit 47dfa4f

File tree

examples/mtl01_tclass_tclass.py
farm/infer.py
farm/modeling/adaptive_model.py
farm/modeling/prediction_head.py

4 files changed: +208 −13 lines


examples/mtl01_tclass_tclass.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
#!/usr/bin/env python
# coding: utf-8

import logging
import torch
import farm
from farm.modeling.tokenization import Tokenizer
from farm.data_handler.processor import TextClassificationProcessor
from farm.data_handler.data_silo import DataSilo
from farm.modeling.language_model import LanguageModel
from farm.modeling.prediction_head import TextClassificationHead
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from farm.infer import Inferencer
from farm.eval import Evaluator

print("Pytorch version:", torch.__version__)
print("CUDA library in pytorch:", torch.version.cuda)
print("FARM version:", farm.__version__)

# Optional MLflow tracking:
# from farm.utils import MLFlowLogger
# logger = MLFlowLogger(tracking_uri="mlflowlog01")
# logger.init_experiment(experiment_name="farm_building_blocks", run_name="tutorial")

logging.basicConfig(level="INFO")
logger = logging.getLogger(name="mtl01-train")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: {}".format(device))

LANG_MODEL = "bert-base-german-cased"
BATCH_SIZE = 32
MAX_SEQ_LEN = 128
EMBEDS_DROPOUT_PROB = 0.1
LEARNING_RATE = 3e-5
MAX_N_EPOCHS = 6
N_GPU = 1
EVAL_EVERY = 70
DATA_DIR = "../data/germeval18"
PREDICT = "both"  # "coarse", "fine" or "both"
DO_ROUND_ROBIN = False  # round-robin training of the heads?

logger.info("Loading Tokenizer")
tokenizer = Tokenizer.load(
    pretrained_model_name_or_path=LANG_MODEL,
    do_lower_case=False)

LABEL_LIST_COARSE = ["OTHER", "OFFENSE"]
LABEL_LIST_FINE = ["OTHER", "ABUSE", "INSULT", "PROFANITY"]

metrics_fine = "f1_macro"
metrics_coarse = "f1_macro"

processor = TextClassificationProcessor(tokenizer=tokenizer,
                                        max_seq_len=MAX_SEQ_LEN,
                                        data_dir=DATA_DIR,
                                        dev_split=0.1,
                                        text_column_name="text",
                                        )

# Register one task and one prediction head per label granularity.
prediction_heads = []
if PREDICT in ("coarse", "both"):
    processor.add_task(name="coarse",
                       task_type="classification",
                       label_list=LABEL_LIST_COARSE,
                       metric=metrics_coarse,
                       text_column_name="text",
                       label_column_name="coarse_label")
    prediction_head_coarse = TextClassificationHead(
        num_labels=len(LABEL_LIST_COARSE),
        task_name="coarse",
        class_weights=None)
    prediction_heads.append(prediction_head_coarse)
if PREDICT in ("fine", "both"):
    processor.add_task(name="fine",
                       task_type="classification",
                       label_list=LABEL_LIST_FINE,
                       metric=metrics_fine,
                       text_column_name="text",
                       label_column_name="fine_label")
    prediction_head_fine = TextClassificationHead(
        num_labels=len(LABEL_LIST_FINE),
        task_name="fine",
        class_weights=None)
    prediction_heads.append(prediction_head_fine)
# processor.save("mtl01-model")

data_silo = DataSilo(
    processor=processor,
    batch_size=BATCH_SIZE)

language_model = LanguageModel.load(LANG_MODEL)


def loss_round_robin(tensors, global_step, batch=None):
    """Alternate between the heads' losses from one step to the next."""
    if global_step % 2:
        return tensors[0]
    else:
        return tensors[1]


if PREDICT == "both" and DO_ROUND_ROBIN:
    loss_fn = loss_round_robin
else:
    loss_fn = None  # default aggregation: sum of the per-head losses

model = AdaptiveModel(
    language_model=language_model,
    prediction_heads=prediction_heads,
    embeds_dropout_prob=EMBEDS_DROPOUT_PROB,
    lm_output_types=["per_sequence"] * len(prediction_heads),  # one per head
    loss_aggregation_fn=loss_fn,
    device=device)

model, optimizer, lr_schedule = initialize_optimizer(
    model=model,
    device=device,
    learning_rate=LEARNING_RATE,
    n_batches=len(data_silo.loaders["train"]),
    n_epochs=MAX_N_EPOCHS)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    data_silo=data_silo,
    epochs=MAX_N_EPOCHS,
    n_gpu=N_GPU,
    lr_schedule=lr_schedule,
    evaluate_every=EVAL_EVERY,
    device=device,
)

logger.info("Starting training")
model = trainer.train()
# model.save("mtl01-model")

inferencer = Inferencer(model=model,
                        processor=processor,
                        batch_size=4, gpu=True,
                        # TODO: how to mix for multihead?
                        task_type="classification"
                        )
basic_texts = [
    {"text": "Some text you want to classify"},
    {"text": "A second sample"},
]

ret = inferencer.inference_from_dicts(basic_texts)
logger.info(f"Result of inference: {ret}")

logger.info("Evaluating on the training set...")
evaluator = Evaluator(
    data_loader=data_silo.get_data_loader("train"),
    tasks=processor.tasks,
    device=device)

result = evaluator.eval(
    inferencer.model,
    return_preds_and_labels=True)

evaluator.log_results(
    result,
    "Train",
    steps=len(data_silo.get_data_loader("train")))

inferencer.close_multiprocessing_pool()
logger.info("PROCESSING FINISHED")

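The example's DO_ROUND_ROBIN flag switches between FARM's default loss aggregation (with loss_aggregation_fn=None the per-head losses are summed) and alternating heads from step to step. A minimal standalone sketch of that schedule, using dummy scalar tensors in place of the real per-head losses:

import torch

# Same schedule as loss_round_robin in the example above: odd global
# steps return the first head's loss, even steps the second head's.
def loss_round_robin(tensors, global_step, batch=None):
    if global_step % 2:
        return tensors[0]
    else:
        return tensors[1]

losses = [torch.tensor(1.0), torch.tensor(2.0)]  # dummy per-head losses
print([loss_round_robin(losses, step).item() for step in range(4)])
# -> [2.0, 1.0, 2.0, 1.0]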
farm/infer.py

Lines changed: 2 additions & 2 deletions
@@ -580,9 +580,9 @@ def _get_predictions(self, dataset, tensor_names, baskets):

             # get logits
             with torch.no_grad():
-                logits = self.model.forward(**batch)[0]
+                logits = self.model.forward(**batch)
                 preds = self.model.formatted_preds(
-                    logits=[logits],
+                    logits=logits,
                     samples=batch_samples,
                     tokenizer=self.processor.tokenizer,
                     return_class_probs=self.return_class_probs,

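This two-line fix is what makes multi-head inference work: AdaptiveModel.forward returns one logits tensor per prediction head, so indexing [0] and re-wrapping it dropped every head but the first. A rough illustration with plain tensors (the shapes are assumptions matching the two-head example above, not actual FARM output):

import torch

# Stand-ins for forward()'s return value with a coarse (2-label) and a
# fine (4-label) head on a batch of 4 texts.
logits = [torch.zeros(4, 2), torch.zeros(4, 4)]

old_style = [logits[0]]  # pre-fix: only the first head survived
new_style = logits       # post-fix: all heads are passed through

print(len(old_style), len(new_style))  # -> 1 2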
farm/modeling/adaptive_model.py

Lines changed: 12 additions & 7 deletions
@@ -110,14 +110,19 @@ def formatted_preds(self, logits, **kwargs):
         # This case is triggered by Natural Questions
         else:
             preds_final = [list() for _ in range(n_heads)]
-            preds = kwargs["preds"]
-            preds_for_heads = stack(preds)
-            logits_for_heads = [None] * n_heads
-
-            samples = [s for b in kwargs["baskets"] for s in b.samples]
-            kwargs["samples"] = samples
+            preds = kwargs.get("preds")
+            if preds is not None:
+                preds_for_heads = stack(preds)
+                logits_for_heads = [None] * n_heads
+                del kwargs["preds"]
+            else:
+                preds_for_heads = [None] * n_heads
+                logits_for_heads = logits
+            preds_final = [list() for _ in range(n_heads)]

-            del kwargs["preds"]
+            if not "samples" in kwargs:
+                samples = [s for b in kwargs["baskets"] for s in b.samples]
+                kwargs["samples"] = samples

             for i, (head, preds_for_head, logits_for_head) in enumerate(zip(self.prediction_heads, preds_for_heads, logits_for_heads)):
                 preds = head.formatted_preds(logits=logits_for_head, preds=preds_for_head, **kwargs)
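After this change, formatted_preds serves two callers: the Natural Questions path, which supplies ready-made preds, and the new multi-head inference path, which supplies only the per-head logits. A toy mimic of the dispatch (plain functions, not FARM's classes) to show which head ends up with what:

# Toy version of the new dispatch logic in AdaptiveModel.formatted_preds.
def dispatch(logits, n_heads, preds=None):
    if preds is not None:          # Natural Questions path
        preds_for_heads = preds
        logits_for_heads = [None] * n_heads
    else:                          # multi-head inference path
        preds_for_heads = [None] * n_heads
        logits_for_heads = logits
    return list(zip(preds_for_heads, logits_for_heads))

print(dispatch(["coarse_logits", "fine_logits"], n_heads=2))
# -> [(None, 'coarse_logits'), (None, 'fine_logits')]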

farm/modeling/prediction_head.py

Lines changed: 12 additions & 4 deletions
@@ -235,7 +235,9 @@ def formatted_preds(self, logits, samples, **kwargs):
         preds = self.logits_to_preds(logits)
         contexts = [sample.clear_text["text"] for sample in samples]

-        res = {"task": "regression", "predictions": []}
+        res = {"task": "regression",
+               "task_name": self.task_name,
+               "predictions": []}
         for pred, context in zip(preds, contexts):
             res["predictions"].append(
                 {
@@ -410,7 +412,9 @@ def formatted_preds(self, logits=None, preds=None, samples=None, return_class_pr
         if len(contexts_b) != 0:
             contexts = ["|".join([a, b]) for a,b in zip(contexts, contexts_b)]

-        res = {"task": "text_classification", "predictions": []}
+        res = {"task": "text_classification",
+               "task_name": self.task_name,
+               "predictions": []}
         for pred, prob, context in zip(preds, probs, contexts):
             if not return_class_probs:
                 pred_dict = {
@@ -526,7 +530,9 @@ def formatted_preds(self, logits, samples, **kwargs):
         probs = self.logits_to_probs(logits)
         contexts = [sample.clear_text["text"] for sample in samples]

-        res = {"task": "text_classification", "predictions": []}
+        res = {"task": "text_classification",
+               "task_name": self.task_name,
+               "predictions": []}
         for pred, prob, context in zip(preds, probs, contexts):
             res["predictions"].append(
                 {
@@ -693,7 +699,9 @@ def formatted_preds(self, logits, initial_mask, samples, return_class_probs=Fals

         # align back with original input by getting the original word spans
         spans = [s.tokenized["word_spans"] for s in samples]
-        res = {"task": "ner", "predictions": []}
+        res = {"task": "ner",
+               "task_name": self.task_name,
+               "predictions": []}
         for preds_seq, probs_seq, sample, spans_seq in zip(
             preds, probs, samples, spans
         ):
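With task_name in every head's result dict, callers can tell the coarse and fine outputs apart. A hypothetical consumer of the Inferencer output; the keys inside "predictions" are illustrative, not FARM's exact schema:

# Dummy output shaped like the new formatted_preds dicts, one per head.
results = [
    {"task": "text_classification", "task_name": "coarse",
     "predictions": [{"label": "OFFENSE"}]},
    {"task": "text_classification", "task_name": "fine",
     "predictions": [{"label": "INSULT"}]},
]

# Route each head's predictions by the task that produced them.
by_task = {r["task_name"]: r["predictions"] for r in results}
print(by_task["coarse"][0]["label"])  # -> OFFENSE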
