
Commit 26e8b00

printing debug info from all local nodes, but each indicating its rank
1 parent 92d8020 commit 26e8b00

File tree: 1 file changed

train.py

Lines changed: 11 additions & 18 deletions
@@ -144,8 +144,7 @@ def get_ds(config: ModelConfig):
     ds_raw = load_dataset('opus_books', f"{config.lang_src}-{config.lang_tgt}", split='train')

     # Build tokenizers
-    if config.local_rank == 0:
-        print("Loading tokenizers...")
+    print(f"GPU {config.local_rank} - Loading tokenizers...")
     tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.lang_src)
     tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.lang_tgt)

@@ -167,9 +166,8 @@ def get_ds(config: ModelConfig):
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))

-    if config.local_rank == 0:
-        print(f'Max length of source sentence: {max_len_src}')
-        print(f'Max length of target sentence: {max_len_tgt}')
+    print(f'GPU {config.local_rank} - Max length of source sentence: {max_len_src}')
+    print(f'GPU {config.local_rank} - Max length of target sentence: {max_len_tgt}')


     train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
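For context, the unchanged DataLoader line above uses the usual DDP pattern: shuffling is delegated to the DistributedSampler (shuffle=True) and disabled on the DataLoader itself, since the two options are mutually exclusive. A minimal, self-contained sketch of that pattern follows; the toy dataset and the set_epoch call are illustrative and not part of this commit:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Assumes the script is launched with torchrun, which sets RANK, WORLD_SIZE and MASTER_ADDR.
dist.init_process_group(backend="gloo")

train_ds = TensorDataset(torch.arange(1000))  # toy stand-in for the tokenized dataset
sampler = DistributedSampler(train_ds, shuffle=True)
train_dataloader = DataLoader(train_ds, batch_size=8, shuffle=False, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle the shard assignment each epoch
    for (batch,) in train_dataloader:
        pass  # each rank iterates over its own disjoint shard

dist.destroy_process_group()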
@@ -185,15 +183,13 @@ def train_model(config: ModelConfig):
     # Define the device
     assert torch.cuda.is_available(), "Training on CPU is not supported"
     device = torch.device("cuda")
-    if config.local_rank == 0:
-        print("Using device:", device)
+    print(f"GPU {config.local_rank} - Using device: {device}")

     # Make sure the weights folder exists
     Path(config.model_folder).mkdir(parents=True, exist_ok=True)

     # Load the dataset
-    if config.local_rank == 0:
-        print("Loading dataset...")
+    print(f"GPU {config.local_rank} - Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

@@ -213,8 +209,7 @@ def train_model(config: ModelConfig):
         model_filename = get_weights_file_path(config, int(config.preload))

     if model_filename is not None:
-        if config.local_rank == 0:
-            print(f'Preloading model {model_filename}')
+        print(f'GPU {config.local_rank} - Preloading model {model_filename}')
         state = torch.load(model_filename)
         model.load_state_dict(state['model_state_dict'])
         initial_epoch = state['epoch'] + 1
@@ -224,8 +219,7 @@ def train_model(config: ModelConfig):
         del state
     else:
         # If we couldn't find a model to preload, just start from scratch
-        if config.local_rank == 0:
-            print(f'Could not find model to preload: {config.preload}. Starting from scratch')
+        print(f'GPU {config.local_rank} - Could not find model to preload: {config.preload}. Starting from scratch')

     # Only initialize W&B on the global rank 0 node
     if config.global_rank == 0:
@@ -255,12 +249,11 @@ def train_model(config: ModelConfig):
     for epoch in range(initial_epoch, config.num_epochs):
         torch.cuda.empty_cache()
         model.train()
-        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config.global_rank}")
-        if config.local_rank != 0:
-            batch_iterator.disable = True

-        for batch in batch_iterator:
+        # Disable tqdm on all nodes except the rank 0 GPU on each server
+        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d} on rank {config.global_rank}", disable=config.local_rank != 0)

+        for batch in batch_iterator:
             encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
             decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
             encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
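The hunk above replaces the post-construction mutation of batch_iterator.disable with tqdm's disable keyword, which silences the progress bar on every process whose local rank is non-zero. A small stand-alone sketch of that keyword; the LOCAL_RANK lookup and the dummy loop are illustrative, not taken from the commit:

import os
import time
from tqdm import tqdm

local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # torchrun exports LOCAL_RANK per process

for epoch in range(2):
    # Only local rank 0 on each node renders a bar; the other processes iterate silently.
    batch_iterator = tqdm(range(100), desc=f"Processing Epoch {epoch:02d}", disable=local_rank != 0)
    for batch in batch_iterator:
        time.sleep(0.01)  # placeholder for the real training step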
@@ -336,7 +329,7 @@ def train_model(config: ModelConfig):
     assert config.local_rank != -1, "LOCAL_RANK environment variable not set"
     assert config.global_rank != -1, "RANK environment variable not set"

-    # Print configuration
+    # Print configuration (only once per server)
     if config.local_rank == 0:
         print("Configuration:")
         for key, value in config.__dict__.items():
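The asserts in the last hunk imply that config.local_rank and config.global_rank are populated from the LOCAL_RANK and RANK environment variables that torchrun sets. The commit does not show how ModelConfig picks them up; one plausible sketch, with the field defaults purely as an assumption:

import os
from dataclasses import dataclass, field

@dataclass
class ModelConfig:
    # -1 signals "not launched under torchrun"; the asserts below catch it.
    local_rank: int = field(default_factory=lambda: int(os.environ.get("LOCAL_RANK", -1)))
    global_rank: int = field(default_factory=lambda: int(os.environ.get("RANK", -1)))

config = ModelConfig()
assert config.local_rank != -1, "LOCAL_RANK environment variable not set"
assert config.global_rank != -1, "RANK environment variable not set"
print(f"GPU {config.local_rank} - global rank {config.global_rank}")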
