
Commit e749023

Update train.py
1 parent e24880e commit e749023


1 file changed: +11 -11 lines changed


train.py

Lines changed: 11 additions & 11 deletions
@@ -92,9 +92,9 @@ def _get_train_sampler(self) :
         )

     def compute_loss(self, model, inputs, return_outputs=False):
-        for task_id in inputs['task_name']:
-            assert task_id==inputs['task_name'][0],f"Examples in the same batch should come from the same task, " \
-                                                   f"but task {task_id} and task {inputs['task_name'][0]} are found"
+        for task_id in inputs['task_id']:
+            assert task_id==inputs['task_id'][0],f"Examples in the same batch should come from the same task, " \
+                                                 f"but task {task_id} and task {inputs['task_id'][0]} are found"
         cur_results = {}
         for k in ['query', 'pos', 'neg']:
             cur_inputs = {
@@ -447,12 +447,12 @@ def main():
     def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         examples_raw = []
         for idx in range(0, total_n, real_batch_size):
-            local_task_name = old_examples_raw[idx]['task_name']
+            local_task_name = old_examples_raw[idx]['task_id']
             cur_batch = []
             include_batch = True
             for idx1 in range(idx, min(idx + real_batch_size, total_n)):
-                if not old_examples_raw[idx1]['task_name'] == local_task_name:
-                    print(f'one batch in task {old_examples_raw[idx1]["task_name"]} is skipped')
+                if not old_examples_raw[idx1]['task_id'] == local_task_name:
+                    print(f'one batch in task {old_examples_raw[idx1]["task_id"]} is skipped')
                     include_batch = False
                     break
                 else:
@@ -478,7 +478,7 @@ def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         train_examples_raw = train_examples_raw[:int(data_args.debug_mode)]

     def get_dataset(examples_raw):
-        examples = {'query':[],'pos':[],'neg':[],'task_name':[]}
+        examples = {'query':[],'pos':[],'neg':[],'task_id':[]}
         task_name_map = {}
         total_num = len(examples_raw)
         task_count = 0
@@ -492,10 +492,10 @@ def get_dataset(examples_raw):
                     cur_e[k][0] = ''
                 assert cur_e[k][0].startswith('Represent ') or cur_e[k][0]==''
                 examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k]))
-            if not cur_e['task_name'] in task_name_map:
-                task_name_map[cur_e['task_name']] = task_count
+            if not cur_e['task_id'] in task_name_map:
+                task_name_map[cur_e['task_id']] = task_count
                 task_count += 1
-            examples['task_name'].append(task_name_map[cur_e['task_name']])
+            examples['task_id'].append(task_name_map[cur_e['task_id']])
         return examples

     train_raw_datasets = DatasetDict({'train':Dataset.from_dict(get_dataset(train_examples_raw))})
@@ -530,7 +530,7 @@ def preprocess_function(examples):
                     all_tokenized[k] = all_tokenized[k].tolist()
             for k in keys:
                 all_tokenized[f'{key}_{k}'] = tokenized[k].tolist()
-        all_tokenized['task_name'] = examples['task_name']
+        all_tokenized['task_id'] = examples['task_id']
         return all_tokenized

     train_dataset = train_raw_datasets["train"]
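
Taken together, the hunks rename the batch key from 'task_name' to 'task_id' along the whole data path: get_dataset maps each task to an integer id, preprocess_function copies that id into the tokenized features, and compute_loss asserts that every example in a batch carries the same id. Below is a minimal standalone sketch of that mapping and check; the helper names and the 'nli'/'qa' task labels are illustrative only and are not part of the repository's code.

# Sketch only: illustrates the task-id mapping and the per-batch check
# that the renamed 'task_id' key supports. Helper names are hypothetical.

def build_task_ids(examples_raw):
    # Map each distinct task label to a small integer id (as get_dataset does).
    task_name_map = {}
    task_ids = []
    for example in examples_raw:
        name = example['task_id']
        if name not in task_name_map:
            task_name_map[name] = len(task_name_map)
        task_ids.append(task_name_map[name])
    return task_ids

def check_single_task_batch(batch_task_ids):
    # Mirrors the assertion in compute_loss: one task per batch.
    first = batch_task_ids[0]
    for task_id in batch_task_ids:
        assert task_id == first, (
            f"Examples in the same batch should come from the same task, "
            f"but task {task_id} and task {first} are found"
        )

# Usage: two tasks, batches of size 2 grouped per task ('nli'/'qa' are made up).
ids = build_task_ids([
    {'task_id': 'nli'}, {'task_id': 'nli'},
    {'task_id': 'qa'}, {'task_id': 'qa'},
])
check_single_task_batch(ids[:2])  # passes
check_single_task_batch(ids[2:])  # passes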
