Skip to content

Commit 92d8020

Browse files
committed
added comments and type definitions
1 parent 4569908 commit 92d8020

File tree

3 files changed

+27
-27
lines changed

3 files changed

+27
-27
lines changed

config.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
@dataclass
55
class ModelConfig:
66

7-
batch_size: int
8-
num_epochs: int
9-
lr: float
10-
seq_len: int
11-
d_model: int
12-
lang_src: str
13-
lang_tgt: str
14-
model_folder: str
15-
model_basename: str
16-
preload: str
17-
tokenizer_file: str
18-
local_rank: int = -1
19-
global_rank: int = -1
7+
batch_size: int # Batch size
8+
num_epochs: int # Number of epochs to train
9+
lr: float # Learning rate
10+
seq_len: int # Sequence length
11+
d_model: int # Size of the embedding vector
12+
lang_src: str # Source language
13+
lang_tgt: str # Target language
14+
model_folder: str # Folder where to save the checkpoints
15+
model_basename: str # Basename of the checkpoint files
16+
preload: str # Preload weights from a previous checkpoint
17+
tokenizer_file: str # Path where to save the tokenizer
18+
local_rank: int = -1 # LOCAL_RANK assigned by torchrun
19+
global_rank: int = -1 # RANK assigned by torchrun
2020

2121
def get_default_config() -> ModelConfig:
2222

@@ -34,13 +34,13 @@ def get_default_config() -> ModelConfig:
3434
tokenizer_file="tokenizer_{0}.json",
3535
)
3636

37-
def get_weights_file_path(config, epoch: str) -> str:
37+
def get_weights_file_path(config: ModelConfig, epoch: str) -> str:
3838
model_folder = config.model_folder
3939
model_basename = config.model_basename
4040
model_filename = model_basename.format(epoch)
4141
return str(Path('.') / model_folder / model_filename)
4242

43-
def get_latest_weights_file_path(config) -> str:
43+
def get_latest_weights_file_path(config: ModelConfig) -> str:
4444
model_folder = config.model_folder
4545
model_basename = config.model_basename
4646
# Check all files in the model folder

dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len
2020
def __len__(self):
2121
return len(self.ds)
2222

23-
def __getitem__(self, idx):
23+
def __getitem__(self, idx: int):
2424
src_target_pair = self.ds[idx]
2525
src_text = src_target_pair['translation'][self.src_lang]
2626
tgt_text = src_target_pair['translation'][self.tgt_lang]
@@ -84,6 +84,6 @@ def __getitem__(self, idx):
8484
"tgt_text": tgt_text,
8585
}
8686

87-
def causal_mask(size):
87+
def causal_mask(size: int):
8888
mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
8989
return mask == 0

train.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import torch.nn as nn
3-
from torch.utils.data import DataLoader, random_split
3+
from torch.utils.data import Dataset, DataLoader, random_split
44

55
# Distributed training
66
from torch.utils.data.distributed import DistributedSampler
@@ -25,9 +25,9 @@
2525

2626
from model import build_transformer
2727
from dataset import BilingualDataset, causal_mask
28-
from config import get_default_config, get_weights_file_path, get_latest_weights_file_path
28+
from config import get_default_config, get_weights_file_path, get_latest_weights_file_path, ModelConfig
2929

30-
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
30+
def greedy_decode(model: nn.Module, source: torch.Tensor, source_mask: torch.Tensor, tokenizer_src: Tokenizer, tokenizer_tgt: Tokenizer, max_len: int, device: torch.device):
3131
sos_idx = tokenizer_tgt.token_to_id('[SOS]')
3232
eos_idx = tokenizer_tgt.token_to_id('[EOS]')
3333

@@ -58,7 +58,7 @@ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_
5858
return decoder_input.squeeze(0)
5959

6060

61-
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, num_examples=2):
61+
def run_validation(model: nn.Module, validation_ds: DataLoader, tokenizer_src: Tokenizer, tokenizer_tgt: Tokenizer, max_len: int, device: torch.device, print_msg: callable, global_step: int, num_examples: int = 2):
6262
model.eval()
6363
count = 0
6464

@@ -122,11 +122,11 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
122122
bleu = metric(predicted, expected)
123123
wandb.log({'validation/BLEU': bleu, 'global_step': global_step})
124124

125-
def get_all_sentences(ds, lang):
125+
def get_all_sentences(ds: Dataset, lang: str):
126126
for item in ds:
127127
yield item['translation'][lang]
128128

129-
def get_or_build_tokenizer(config, ds, lang):
129+
def get_or_build_tokenizer(config: ModelConfig, ds: Dataset, lang: str) -> Tokenizer:
130130
tokenizer_path = Path(config.tokenizer_file.format(lang))
131131
if not Path.exists(tokenizer_path):
132132
# Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
@@ -139,7 +139,7 @@ def get_or_build_tokenizer(config, ds, lang):
139139
tokenizer = Tokenizer.from_file(str(tokenizer_path))
140140
return tokenizer
141141

142-
def get_ds(config):
142+
def get_ds(config: ModelConfig):
143143
# It only has the train split, so we divide it overselves
144144
ds_raw = load_dataset('opus_books', f"{config.lang_src}-{config.lang_tgt}", split='train')
145145

@@ -177,11 +177,11 @@ def get_ds(config):
177177

178178
return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
179179

180-
def get_model(config, vocab_src_len, vocab_tgt_len):
180+
def get_model(config: ModelConfig, vocab_src_len: int, vocab_tgt_len: int):
181181
model = build_transformer(vocab_src_len, vocab_tgt_len, config.seq_len, config.seq_len, d_model=config.d_model)
182182
return model
183183

184-
def train_model(config):
184+
def train_model(config: ModelConfig):
185185
# Define the device
186186
assert torch.cuda.is_available(), "Training on CPU is not supported"
187187
device = torch.device("cuda")
@@ -324,9 +324,9 @@ def train_model(config):
324324
parser.add_argument('--model_basename', type=str, default=config.model_basename)
325325
parser.add_argument('--preload', type=str, default=config.preload)
326326
parser.add_argument('--tokenizer_file', type=str, default=config.tokenizer_file)
327+
args = parser.parse_args()
327328

328329
# Update default configuration with command line arguments
329-
args = parser.parse_args()
330330
config.__dict__.update(vars(args))
331331

332332
# Add local rank and global rank to the config

0 commit comments

Comments
 (0)