Skip to content

Commit 687c455

Browse files
committed
Fix bugs; rename variables (properties of class)
1 parent 492541c commit 687c455

File tree

1 file changed

+33
-28
lines changed

1 file changed

+33
-28
lines changed

pyTorchAutoForge/optimization/ModelTrainingManager.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ def __init__(self, model: nn.Module | None,
394394
"Neither model nor model checkpoint path provided. Cannot continue with optimization process.")
395395

396396
self.model: torch.nn.Module = (model).to(self.device)
397-
self.bestModel: torch.nn.Module | None = None
397+
398+
self.best_model: torch.nn.Module | None = None
399+
self.best_epoch: int = 0
400+
398401
self.loss_fcn: torch.nn.Module = lossFcn
399402

400403
self.trainingDataloader: torch.utils.data.Dataloader | None = None
@@ -406,6 +409,7 @@ def __init__(self, model: nn.Module | None,
406409
self.current_epoch: int = 0
407410
self.num_of_updates: int = 0
408411

412+
409413
self.currentTrainingLoss: float | None = None
410414
self.currentValidationLoss: float | None = None
411415
self.currentMlflowRun = mlflow.active_run() # Returns None if no active run
@@ -581,7 +585,7 @@ def get_traced_model(self, device=None):
581585
device = self.device
582586

583587
# Get internal model (best or model)
584-
model = self.bestModel if (self.bestModel is not None) else self.model
588+
model = self.best_model if (self.best_model is not None) else self.model
585589

586590
try:
587591
raise NotImplementedError('Method not implemented yet.')
@@ -992,9 +996,10 @@ def trainAndValidate(self):
992996
if isinstance(validation_loss_value, tuple):
993997
validation_loss_value = validation_loss_value[0]
994998

995-
self.currentValidationLoss: float = validation_loss_value
996-
self.bestValidationLoss: float = validation_loss_value
997-
self.bestModel: torch.nn.Module | None = copy.deepcopy(self.model).to('cpu')
999+
self.currentValidationLoss = validation_loss_value
1000+
self.bestValidationLoss = validation_loss_value
1001+
self.best_model = copy.deepcopy(self.model).to('cpu')
1002+
self.best_epoch = self.current_epoch
9981003

9991004
########################################
10001005
### Loop over epochs
@@ -1068,7 +1073,7 @@ def trainAndValidate(self):
10681073

10691074
# Update stats if new best model found (independently of keep_best flag)
10701075
if tmp_valid_loss <= self.bestValidationLoss:
1071-
self.bestEpoch = epoch_num
1076+
self.best_epoch = epoch_num
10721077
self.bestValidationLoss = tmp_valid_loss
10731078
no_new_best_counter = 0
10741079
else:
@@ -1080,7 +1085,7 @@ def trainAndValidate(self):
10801085
if tmp_valid_loss <= self.bestValidationLoss:
10811086

10821087
# Transfer best model to CPU to avoid additional memory allocation on GPU
1083-
self.bestModel: torch.nn.Module | None = copy.deepcopy(
1088+
self.best_model: torch.nn.Module | None = copy.deepcopy(
10841089
self.model).to('cpu')
10851090

10861091
# Delete previous best model checkpoint if it exists
@@ -1099,10 +1104,10 @@ def trainAndValidate(self):
10991104

11001105
# Save temporary best model
11011106
model_save_name = os.path.join(
1102-
self.checkpoint_dir, self.modelName + f"_epoch_{self.bestEpoch}")
1107+
self.checkpoint_dir, self.modelName + f"_epoch_{self.best_epoch}")
11031108

1104-
if self.bestModel is not None:
1105-
SaveModel(model=self.bestModel, model_filename=model_save_name,
1109+
if self.best_model is not None:
1110+
SaveModel(model=self.best_model, model_filename=model_save_name,
11061111
save_mode=AutoForgeModuleSaveMode.MODEL_ARCH_STATE,
11071112
target_device='cpu')
11081113

@@ -1125,7 +1130,7 @@ def trainAndValidate(self):
11251130
'num_of_updates', self.num_of_updates, step=self.current_epoch)
11261131

11271132
print('\tCurrent best at epoch {best_epoch}, with validation loss: {best_loss:.06g}'.format(
1128-
best_epoch=self.bestEpoch, best_loss=self.bestValidationLoss))
1133+
best_epoch=self.best_epoch, best_loss=self.bestValidationLoss))
11291134
print(
11301135
f'\tEpoch cycle duration: {((time.time() - cycle_start_time) / 60):.4f} [min]')
11311136

@@ -1150,31 +1155,31 @@ def trainAndValidate(self):
11501155

11511156
if self.keep_best:
11521157
print('Best model saved from epoch: {best_epoch} with validation loss: {best_loss:.4f}'.format(
1153-
best_epoch=self.bestEpoch, best_loss=self.bestValidationLoss))
1158+
best_epoch=self.best_epoch, best_loss=self.bestValidationLoss))
11541159

11551160
with torch.no_grad():
11561161
examplePair = next(iter(self.validationDataloader))
11571162
model_save_name = os.path.join(
1158-
self.checkpoint_dir, self.modelName + f"_epoch_{self.bestEpoch}")
1163+
self.checkpoint_dir, self.modelName + f"_epoch_{self.best_epoch}")
11591164

1160-
if self.bestModel is not None:
1161-
SaveModel(model=self.bestModel, model_filename=model_save_name,
1165+
if self.best_model is not None:
1166+
SaveModel(model=self.best_model, model_filename=model_save_name,
11621167
save_mode=AutoForgeModuleSaveMode.MODEL_ARCH_STATE,
11631168
example_input=examplePair[0],
11641169
target_device=self.device)
11651170
else:
11661171
print(
1167-
"\033[38;5;208mWARNING: SaveModel skipped due to bestModel being None!\033[0m")
1172+
"\033[38;5;208mWARNING: SaveModel skipped due to best_model being None!\033[0m")
11681173

11691174
if self.mlflow_logging:
1170-
mlflow.log_param('checkpoint_best_epoch', self.bestEpoch)
1175+
mlflow.log_param('checkpoint_best_epoch', self.best_epoch)
11711176

11721177
# Post-training operations
11731178
print('Training and validation cycle completed.')
11741179
if self.mlflow_logging:
11751180
mlflow.end_run(status='FINISHED')
11761181

1177-
return self.bestModel if self.keep_best else self.model
1182+
return self.best_model if self.keep_best else self.model
11781183

11791184
except KeyboardInterrupt:
11801185
print(
@@ -1183,13 +1188,13 @@ def trainAndValidate(self):
11831188
# Save best model up to current epoch if not None
11841189
try:
11851190
# TODO replace by a trainer export method
1186-
if self.bestModel is not None:
1191+
if self.best_model is not None:
11871192
examplePair = next(iter(self.validationDataloader))
11881193
model_save_name = os.path.join(
1189-
self.checkpoint_dir, self.modelName + f"_epoch_{self.bestEpoch}")
1194+
self.checkpoint_dir, self.modelName + f"_epoch_{self.best_epoch}")
11901195

1191-
if self.bestModel is not None:
1192-
SaveModel(model=self.bestModel, model_filename=model_save_name,
1196+
if self.best_model is not None:
1197+
SaveModel(model=self.best_model, model_filename=model_save_name,
11931198
save_mode=AutoForgeModuleSaveMode.MODEL_ARCH_STATE,
11941199
example_input=examplePair[0],
11951200
target_device=self.device)
@@ -1938,8 +1943,8 @@ def TrainAndValidateModel(dataloaderIndex: DataloaderIndex,
19381943
bestSWAvalidationLoss = 1E10
19391944

19401945
# Deep copy the initial state of the model and move it to the CPU
1941-
bestModel = copy.deepcopy(model).to('cpu')
1942-
bestEpoch = epochStart
1946+
best_model = copy.deepcopy(model).to('cpu')
1947+
best_epoch = epochStart
19431948

19441949
if swa_model != None:
19451950
bestSWAmodel = copy.deepcopy(model).to('cpu')
@@ -1962,15 +1967,15 @@ def TrainAndValidateModel(dataloaderIndex: DataloaderIndex,
19621967
# If validation loss is better than previous best, update best model
19631968
if validationLossHistory[epochID] < bestValidationLoss:
19641969
# Replace best model with current model
1965-
bestModel = copy.deepcopy(model).to('cpu')
1966-
bestEpoch = epochID + epochStart
1970+
best_model = copy.deepcopy(model).to('cpu')
1971+
best_epoch = epochID + epochStart
19671972
bestValidationLoss = validationLossHistory[epochID]
19681973

1969-
bestModelData = {'model': bestModel, 'epoch': bestEpoch,
1974+
bestModelData = {'model': best_model, 'epoch': best_epoch,
19701975
'validationLoss': bestValidationLoss}
19711976

19721977
print(
1973-
f"Current best model found at epoch: {bestEpoch} with validation loss: {bestValidationLoss}")
1978+
f"Current best model found at epoch: {best_epoch} with validation loss: {bestValidationLoss}")
19741979

19751980
# SWA handling: if enabled, evaluate validation loss of SWA model, then decide if to update or reset
19761981
if swa_model != None and epochID >= swa_start_epoch:

0 commit comments

Comments
 (0)