Commit 4569908

using statically typed config class
1 parent 9601659 commit 4569908

File tree

2 files changed: +95 -72 lines changed

config.py

Lines changed: 37 additions & 18 deletions
@@ -1,29 +1,48 @@
 from pathlib import Path
+from dataclasses import dataclass

-def get_config():
-    return {
-        "batch_size": 4,
-        "num_epochs": 30,
-        "lr": 10**-4,
-        "seq_len": 350,
-        "d_model": 512,
-        "lang_src": "en",
-        "lang_tgt": "it",
-        "model_folder": "weights",
-        "model_basename": "tmodel_{0:02d}.pt",
-        "preload": "latest",
-        "tokenizer_file": "tokenizer_{0}.json",
-    }
+@dataclass
+class ModelConfig:
+
+    batch_size: int
+    num_epochs: int
+    lr: float
+    seq_len: int
+    d_model: int
+    lang_src: str
+    lang_tgt: str
+    model_folder: str
+    model_basename: str
+    preload: str
+    tokenizer_file: str
+    local_rank: int = -1
+    global_rank: int = -1
+
+def get_default_config() -> ModelConfig:
+
+    return ModelConfig(
+        batch_size=4,
+        num_epochs=30,
+        lr=10**-4,
+        seq_len=350,
+        d_model=512,
+        lang_src="en",
+        lang_tgt="it",
+        model_folder="weights",
+        model_basename="tmodel_{0:02d}.pt",
+        preload="latest",
+        tokenizer_file="tokenizer_{0}.json",
+    )

 def get_weights_file_path(config, epoch: str) -> str:
-    model_folder = config["model_folder"]
-    model_basename = config["model_basename"]
+    model_folder = config.model_folder
+    model_basename = config.model_basename
     model_filename = model_basename.format(epoch)
     return str(Path('.') / model_folder / model_filename)

 def get_latest_weights_file_path(config) -> str:
-    model_folder = config["model_folder"]
-    model_basename = config["model_basename"]
+    model_folder = config.model_folder
+    model_basename = config.model_basename
     # Check all files in the model folder
     model_files = Path(model_folder).glob(f"*.pt")
     # Sort by epoch number (ascending order)
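
The practical gain of the dataclass is typed attribute access: a misspelled option name is now caught by a static checker (or fails fast as an AttributeError) instead of surfacing as a KeyError mid-training. A minimal sketch of how the new class is consumed; the experiment override at the end uses dataclasses.replace and is purely illustrative, not part of this commit:

from dataclasses import replace

from config import get_default_config, get_weights_file_path

config = get_default_config()

# Typed attribute access replaces the old string-keyed dict lookups.
print(config.seq_len)                    # 350
print(get_weights_file_path(config, 5))  # weights/tmodel_05.pt

# replace() returns a copy with selected fields overridden, leaving the
# shared defaults untouched.
experiment = replace(config, batch_size=8, lr=3e-4)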

train.py

Lines changed: 58 additions & 54 deletions
@@ -3,9 +3,8 @@
 from torch.utils.data import DataLoader, random_split

 # Distributed training
-import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
-from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel
 from torch.distributed import init_process_group, destroy_process_group

 import warnings

@@ -26,7 +25,7 @@

 from model import build_transformer
 from dataset import BilingualDataset, causal_mask
-from config import get_config, get_weights_file_path, get_latest_weights_file_path
+from config import get_default_config, get_weights_file_path, get_latest_weights_file_path

 def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
     sos_idx = tokenizer_tgt.token_to_id('[SOS]')

@@ -128,7 +127,7 @@ def get_all_sentences(ds, lang):
         yield item['translation'][lang]

 def get_or_build_tokenizer(config, ds, lang):
-    tokenizer_path = Path(config['tokenizer_file'].format(lang))
+    tokenizer_path = Path(config.tokenizer_file.format(lang))
     if not Path.exists(tokenizer_path):
         # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
         tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))

@@ -142,78 +141,79 @@ def get_or_build_tokenizer(config, ds, lang):

 def get_ds(config):
     # It only has the train split, so we divide it ourselves
-    ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')
+    ds_raw = load_dataset('opus_books', f"{config.lang_src}-{config.lang_tgt}", split='train')

     # Build tokenizers
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Loading tokenizers...")
-    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
-    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
+    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.lang_src)
+    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.lang_tgt)

     # Keep 90% for training, 10% for validation
     train_ds_size = int(0.9 * len(ds_raw))
     val_ds_size = len(ds_raw) - train_ds_size
     train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

-    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
-    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
+    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len)
+    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len)

     # Find the maximum length of each sentence in the source and target sentence
     max_len_src = 0
     max_len_tgt = 0

     for item in ds_raw:
-        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
-        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
+        src_ids = tokenizer_src.encode(item['translation'][config.lang_src]).ids
+        tgt_ids = tokenizer_tgt.encode(item['translation'][config.lang_tgt]).ids
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))

-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print(f'Max length of source sentence: {max_len_src}')
         print(f'Max length of target sentence: {max_len_tgt}')


-    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
+    train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
     val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

     return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

 def get_model(config, vocab_src_len, vocab_tgt_len):
-    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
+    model = build_transformer(vocab_src_len, vocab_tgt_len, config.seq_len, config.seq_len, d_model=config.d_model)
     return model

 def train_model(config):
     # Define the device
     assert torch.cuda.is_available(), "Training on CPU is not supported"
     device = torch.device("cuda")
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Using device:", device)

     # Make sure the weights folder exists
-    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
+    Path(config.model_folder).mkdir(parents=True, exist_ok=True)

     # Load the dataset
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

     optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

-    # If the user specified a model to preload before training, load it
+    # By default, load the latest checkpoint
     initial_epoch = 0
     global_step = 0
     wandb_run_id = None
-    if config['preload'] != '':
+    if config.preload != '':

-        if config['preload'] == 'latest':
+        if config.preload == 'latest':
+            # Get the filename of the latest checkpoint
             model_filename = get_latest_weights_file_path(config)
         else:
-            model_filename = get_weights_file_path(config, int(config['preload']))
+            # In case we want to preload a specific checkpoint
+            model_filename = get_weights_file_path(config, int(config.preload))

-        # If we couldn't find a model to preload, just start from scratch
         if model_filename is not None:
-            if config['local_rank'] == 0:
+            if config.local_rank == 0:
                 print(f'Preloading model {model_filename}')
             state = torch.load(model_filename)
             model.load_state_dict(state['model_state_dict'])
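
One unchanged context line in this hunk was apparently missed by the conversion: the optimizer is still built with lr=config['lr'], and a plain dataclass is not subscriptable, so that line would raise a TypeError at runtime. The attribute form matching the rest of the commit would be:

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, eps=1e-9)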
@@ -223,11 +223,12 @@ def train_model(config):
             wandb_run_id = state['wandb_run_id']
             del state
         else:
-            if config['local_rank'] == 0:
-                print(f'Could not find model to preload: {config["preload"]}. Starting from scratch')
+            # If we couldn't find a model to preload, just start from scratch
+            if config.local_rank == 0:
+                print(f'Could not find model to preload: {config.preload}. Starting from scratch')

-    # Only initialize W&B on the rank 0 node
-    if config['global_rank'] == 0:
+    # Only initialize W&B on the global rank 0 node
+    if config.global_rank == 0:
         wandb.init(
             # set the wandb project where this run will be logged
             project="pytorch-transformer-distributed",

@@ -240,22 +241,22 @@ def train_model(config):

     # Convert the model to DistributedDataParallel
     # Here we can also specify the bucket_cap_mb parameter to control the size of the buckets
-    model = DDP(model, device_ids=[config['local_rank']])
+    model = DistributedDataParallel(model, device_ids=[config.local_rank])

     loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

-    if config['global_rank'] == 0:
+    if config.global_rank == 0:
         # define our custom x axis metric
         wandb.define_metric("global_step")
         # define which metrics will be plotted against it
         wandb.define_metric("validation/*", step_metric="global_step")
         wandb.define_metric("train/*", step_metric="global_step")

-    for epoch in range(initial_epoch, config['num_epochs']):
+    for epoch in range(initial_epoch, config.num_epochs):
         torch.cuda.empty_cache()
         model.train()
-        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config['global_rank']}")
-        if config['local_rank'] != 0:
+        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config.global_rank}")
+        if config.local_rank != 0:
             batch_iterator.disable = True

         for batch in batch_iterator:
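
The comment above mentions bucket_cap_mb: DistributedDataParallel groups gradients into fixed-size buckets and launches one all-reduce per bucket as the backward pass fills them, and bucket_cap_mb sets that size in megabytes (PyTorch's default is 25). A self-contained sketch, using a single process and the gloo backend so it runs without torchrun or a GPU; the real script uses nccl with one process per GPU and device_ids=[config.local_rank]:

import os
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group, destroy_process_group

# Minimal single-process rendezvous so DDP can initialize locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(512, 512)
# Larger buckets mean fewer, larger all-reduces; smaller buckets overlap
# communication with the backward pass more aggressively.
model = DistributedDataParallel(model, bucket_cap_mb=25)

destroy_process_group()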
@@ -277,7 +278,7 @@ def train_model(config):
             loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
             batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}", "global_step": global_step})

-            if config['global_rank'] == 0:
+            if config.global_rank == 0:
                 # Log the loss
                 wandb.log({'train/loss': loss.item(), 'global_step': global_step})

@@ -291,9 +292,9 @@ def train_model(config):
             global_step += 1

         # Only run validation and checkpoint saving on the rank 0 node
-        if config['global_rank'] == 0:
+        if config.global_rank == 0:
             # Run validation at the end of every epoch
-            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
+            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config.seq_len, device, lambda msg: batch_iterator.write(msg), global_step)

             # Save the model at the end of every epoch
             model_filename = get_weights_file_path(config, epoch)

@@ -308,39 +309,42 @@

 if __name__ == '__main__':
     warnings.filterwarnings("ignore")
-    config = get_config()
+    config = get_default_config()

     # Read command line arguments and overwrite config accordingly
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', type=int, default=config['batch_size'])
-    parser.add_argument('--num_epochs', type=int, default=config['num_epochs'])
-    parser.add_argument('--lr', type=float, default=config['lr'])
-    parser.add_argument('--seq_len', type=int, default=config['seq_len'])
-    parser.add_argument('--d_model', type=int, default=config['d_model'])
-    parser.add_argument('--lang_src', type=str, default=config['lang_src'])
-    parser.add_argument('--lang_tgt', type=str, default=config['lang_tgt'])
-    parser.add_argument('--model_folder', type=str, default=config['model_folder'])
-    parser.add_argument('--model_basename', type=str, default=config['model_basename'])
-    parser.add_argument('--preload', type=str, default=config['preload'])
-    parser.add_argument('--tokenizer_file', type=str, default=config['tokenizer_file'])
+    parser.add_argument('--batch_size', type=int, default=config.batch_size)
+    parser.add_argument('--num_epochs', type=int, default=config.num_epochs)
+    parser.add_argument('--lr', type=float, default=config.lr)
+    parser.add_argument('--seq_len', type=int, default=config.seq_len)
+    parser.add_argument('--d_model', type=int, default=config.d_model)
+    parser.add_argument('--lang_src', type=str, default=config.lang_src)
+    parser.add_argument('--lang_tgt', type=str, default=config.lang_tgt)
+    parser.add_argument('--model_folder', type=str, default=config.model_folder)
+    parser.add_argument('--model_basename', type=str, default=config.model_basename)
+    parser.add_argument('--preload', type=str, default=config.preload)
+    parser.add_argument('--tokenizer_file', type=str, default=config.tokenizer_file)

     # Update default configuration with command line arguments
     args = parser.parse_args()
-    config.update(vars(args))
+    config.__dict__.update(vars(args))

     # Add local rank and global rank to the config
-    config['local_rank'] = int(os.environ['LOCAL_RANK'])
-    config['global_rank'] = int(os.environ['RANK'])
+    config.local_rank = int(os.environ['LOCAL_RANK'])
+    config.global_rank = int(os.environ['RANK'])
+
+    assert config.local_rank != -1, "LOCAL_RANK environment variable not set"
+    assert config.global_rank != -1, "RANK environment variable not set"

     # Print configuration
-    if config['local_rank'] == 0:
+    if config.local_rank == 0:
         print("Configuration:")
-        for key, value in config.items():
+        for key, value in config.__dict__.items():
             print(f"{key:>20}: {value}")

     # Setup distributed training
     init_process_group(backend='nccl')
-    torch.cuda.set_device(config['local_rank'])
+    torch.cuda.set_device(config.local_rank)

     # Train the model
     train_model(config)
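
LOCAL_RANK and RANK, read from the environment above, are variables that torchrun exports into every process it spawns (along with WORLD_SIZE and the rendezvous settings), so the script is meant to be launched as, for example, torchrun --nproc_per_node=2 train.py rather than with plain python, which would leave them unset and fail at startup. A short sketch of that contract; the variable names are torchrun's, the rest is illustrative:

import os

# Set by torchrun for each worker process:
#   RANK        - global rank of this process across all nodes
#   LOCAL_RANK  - rank of this process on its own node, used to pick its GPU
#   WORLD_SIZE  - total number of processes in the job
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

A side note on the override step: config.__dict__.update(vars(args)) works here because the argparse defaults were seeded from the same dataclass, but dataclasses.replace(config, **vars(args)) would achieve the same result while rejecting any key that is not a declared ModelConfig field.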
