1
1
import torch
2
2
import torch .nn as nn
3
- from torch .utils .data import DataLoader , random_split
3
+ from torch .utils .data import Dataset , DataLoader , random_split
4
4
5
5
# Distributed training
6
6
from torch .utils .data .distributed import DistributedSampler
25
25
26
26
from model import build_transformer
27
27
from dataset import BilingualDataset , causal_mask
28
- from config import get_default_config , get_weights_file_path , get_latest_weights_file_path
28
+ from config import get_default_config , get_weights_file_path , get_latest_weights_file_path , ModelConfig
29
29
30
- def greedy_decode (model , source , source_mask , tokenizer_src , tokenizer_tgt , max_len , device ):
30
+ def greedy_decode (model : nn . Module , source : torch . Tensor , source_mask : torch . Tensor , tokenizer_src : Tokenizer , tokenizer_tgt : Tokenizer , max_len : int , device : torch . device ):
31
31
sos_idx = tokenizer_tgt .token_to_id ('[SOS]' )
32
32
eos_idx = tokenizer_tgt .token_to_id ('[EOS]' )
33
33
@@ -58,7 +58,7 @@ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_
58
58
return decoder_input .squeeze (0 )
59
59
60
60
61
- def run_validation (model , validation_ds , tokenizer_src , tokenizer_tgt , max_len , device , print_msg , global_step , num_examples = 2 ):
61
+ def run_validation (model : nn . Module , validation_ds : DataLoader , tokenizer_src : Tokenizer , tokenizer_tgt : Tokenizer , max_len : int , device : torch . device , print_msg : callable , global_step : int , num_examples : int = 2 ):
62
62
model .eval ()
63
63
count = 0
64
64
@@ -122,11 +122,11 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
122
122
bleu = metric (predicted , expected )
123
123
wandb .log ({'validation/BLEU' : bleu , 'global_step' : global_step })
124
124
125
def get_all_sentences(ds: Dataset, lang: str):
    """Lazily yield the ``lang`` entry of every item's ``'translation'`` mapping.

    Args:
        ds: Iterable of items; each item is indexed as
            ``item['translation'][lang]``.
        lang: Key selecting which entry of the ``'translation'`` mapping to
            yield (presumably a language code -- confirm against the caller).

    Yields:
        ``item['translation'][lang]`` for each item, in iteration order.
    """
    for item in ds:
        yield item['translation'][lang]
128
128
129
- def get_or_build_tokenizer (config , ds , lang ) :
129
+ def get_or_build_tokenizer (config : ModelConfig , ds : Dataset , lang : str ) -> Tokenizer :
130
130
tokenizer_path = Path (config .tokenizer_file .format (lang ))
131
131
if not Path .exists (tokenizer_path ):
132
132
# Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
@@ -139,7 +139,7 @@ def get_or_build_tokenizer(config, ds, lang):
139
139
tokenizer = Tokenizer .from_file (str (tokenizer_path ))
140
140
return tokenizer
141
141
142
- def get_ds (config ):
142
+ def get_ds (config : ModelConfig ):
143
143
# It only has the train split, so we divide it ourselves
144
144
ds_raw = load_dataset ('opus_books' , f"{ config .lang_src } -{ config .lang_tgt } " , split = 'train' )
145
145
@@ -177,11 +177,11 @@ def get_ds(config):
177
177
178
178
return train_dataloader , val_dataloader , tokenizer_src , tokenizer_tgt
179
179
180
def get_model(config: ModelConfig, vocab_src_len: int, vocab_tgt_len: int):
    """Build the transformer via ``build_transformer`` using ``config``.

    ``config.seq_len`` is passed for both the third and fourth positional
    arguments of ``build_transformer`` and ``config.d_model`` as ``d_model``.

    Args:
        config: Configuration object; ``seq_len`` and ``d_model`` are read.
        vocab_src_len: Passed through as the first positional argument
            (presumably the source vocabulary size -- confirm in ``model``).
        vocab_tgt_len: Passed through as the second positional argument
            (presumably the target vocabulary size -- confirm in ``model``).

    Returns:
        Whatever ``build_transformer`` returns.
    """
    model = build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config.seq_len,
        config.seq_len,
        d_model=config.d_model,
    )
    return model
183
183
184
- def train_model (config ):
184
+ def train_model (config : ModelConfig ):
185
185
# Define the device
186
186
assert torch .cuda .is_available (), "Training on CPU is not supported"
187
187
device = torch .device ("cuda" )
@@ -324,9 +324,9 @@ def train_model(config):
324
324
parser .add_argument ('--model_basename' , type = str , default = config .model_basename )
325
325
parser .add_argument ('--preload' , type = str , default = config .preload )
326
326
parser .add_argument ('--tokenizer_file' , type = str , default = config .tokenizer_file )
327
+ args = parser .parse_args ()
327
328
328
329
# Update default configuration with command line arguments
329
- args = parser .parse_args ()
330
330
config .__dict__ .update (vars (args ))
331
331
332
332
# Add local rank and global rank to the config
0 commit comments