@@ -92,9 +92,9 @@ def _get_train_sampler(self):
         )
 
     def compute_loss(self, model, inputs, return_outputs=False):
-        for task_id in inputs['task_name']:
-            assert task_id == inputs['task_name'][0], f"Examples in the same batch should come from the same task, " \
-                                                      f"but task {task_id} and task {inputs['task_name'][0]} are found"
+        for task_id in inputs['task_id']:
+            assert task_id == inputs['task_id'][0], f"Examples in the same batch should come from the same task, " \
+                                                    f"but task {task_id} and task {inputs['task_id'][0]} are found"
         cur_results = {}
         for k in ['query', 'pos', 'neg']:
             cur_inputs = {
@@ -447,12 +447,12 @@ def main():
     def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         examples_raw = []
         for idx in range(0, total_n, real_batch_size):
-            local_task_name = old_examples_raw[idx]['task_name']
+            local_task_name = old_examples_raw[idx]['task_id']
             cur_batch = []
             include_batch = True
             for idx1 in range(idx, min(idx + real_batch_size, total_n)):
-                if not old_examples_raw[idx1]['task_name'] == local_task_name:
-                    print(f'one batch in task {old_examples_raw[idx1]["task_name"]} is skipped')
+                if not old_examples_raw[idx1]['task_id'] == local_task_name:
+                    print(f'one batch in task {old_examples_raw[idx1]["task_id"]} is skipped')
                     include_batch = False
                     break
                 else:
@@ -478,7 +478,7 @@ def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         train_examples_raw = train_examples_raw[:int(data_args.debug_mode)]
 
     def get_dataset(examples_raw):
-        examples = {'query':[],'pos':[],'neg':[],'task_name':[]}
+        examples = {'query':[],'pos':[],'neg':[],'task_id':[]}
         task_name_map = {}
         total_num = len(examples_raw)
         task_count = 0
@@ -492,10 +492,10 @@ def get_dataset(examples_raw):
                     cur_e[k][0] = ''
                 assert cur_e[k][0].startswith('Represent ') or cur_e[k][0] == ''
                 examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k]))
-            if not cur_e['task_name'] in task_name_map:
-                task_name_map[cur_e['task_name']] = task_count
+            if not cur_e['task_id'] in task_name_map:
+                task_name_map[cur_e['task_id']] = task_count
                 task_count += 1
-            examples['task_name'].append(task_name_map[cur_e['task_name']])
+            examples['task_id'].append(task_name_map[cur_e['task_id']])
         return examples
 
     train_raw_datasets = DatasetDict({'train': Dataset.from_dict(get_dataset(train_examples_raw))})
@@ -530,7 +530,7 @@ def preprocess_function(examples):
             all_tokenized[k] = all_tokenized[k].tolist()
         for k in keys:
             all_tokenized[f'{key}_{k}'] = tokenized[k].tolist()
-        all_tokenized['task_name'] = examples['task_name']
+        all_tokenized['task_id'] = examples['task_id']
         return all_tokenized
 
     train_dataset = train_raw_datasets["train"]
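
Note: a minimal sketch (hypothetical helper, not part of this commit) of the invariant the renamed 'task_id' key carries through the pipeline: get_dataset maps each task to an integer id, get_examples_raw keeps only windows whose examples share that id, and compute_loss asserts the same condition again at training time.

# Hypothetical, self-contained illustration of the single-task-batch invariant;
# the helper name and toy data below are assumptions, not the repository's API.
def make_single_task_batches(examples, batch_size):
    """Keep only windows of `batch_size` examples that share one 'task_id'."""
    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        task_id = chunk[0]['task_id']
        if all(e['task_id'] == task_id for e in chunk):
            batches.append(chunk)
        else:
            # mirrors the skip-and-report behaviour in get_examples_raw
            print(f'one batch in task {task_id} is skipped')
    return batches

toy = [{'task_id': 0, 'query': 'q1'}, {'task_id': 0, 'query': 'q2'},
       {'task_id': 1, 'query': 'q3'}, {'task_id': 1, 'query': 'q4'}]
assert len(make_single_task_batches(toy, 2)) == 2  # both windows are single-task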