@@ -394,7 +394,10 @@ def __init__(self, model: nn.Module | None,
                 "Neither model nor model checkpoint path provided. Cannot continue with optimization process.")
 
         self.model: torch.nn.Module = (model).to(self.device)
-        self.bestModel: torch.nn.Module | None = None
+
+        self.best_model: torch.nn.Module | None = None
+        self.best_epoch: int = 0
+
         self.loss_fcn: torch.nn.Module = lossFcn
 
         self.trainingDataloader: torch.utils.data.Dataloader | None = None
@@ -406,6 +409,7 @@ def __init__(self, model: nn.Module | None,
         self.current_epoch: int = 0
         self.num_of_updates: int = 0
 
+
         self.currentTrainingLoss: float | None = None
         self.currentValidationLoss: float | None = None
         self.currentMlflowRun = mlflow.active_run()  # Returns None if no active run
@@ -581,7 +585,7 @@ def get_traced_model(self, device=None):
             device = self.device
 
         # Get internal model (best or model)
-        model = self.bestModel if (self.bestModel is not None) else self.model
+        model = self.best_model if (self.best_model is not None) else self.model
 
         try:
             raise NotImplementedError('Method not implemented yet.')
@@ -992,9 +996,10 @@ def trainAndValidate(self):
             if isinstance(validation_loss_value, tuple):
                 validation_loss_value = validation_loss_value[0]
 
-            self.currentValidationLoss: float = validation_loss_value
-            self.bestValidationLoss: float = validation_loss_value
-            self.bestModel: torch.nn.Module | None = copy.deepcopy(self.model).to('cpu')
+            self.currentValidationLoss = validation_loss_value
+            self.bestValidationLoss = validation_loss_value
+            self.best_model = copy.deepcopy(self.model).to('cpu')
+            self.best_epoch = self.current_epoch
 
             ########################################
             ### Loop over epochs
@@ -1068,7 +1073,7 @@ def trainAndValidate(self):
 
                 # Update stats if new best model found (independently of keep_best flag)
                 if tmp_valid_loss <= self.bestValidationLoss:
-                    self.bestEpoch = epoch_num
+                    self.best_epoch = epoch_num
                     self.bestValidationLoss = tmp_valid_loss
                     no_new_best_counter = 0
                 else:
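The hunk above renames `bestEpoch` to `best_epoch` inside the early-stopping bookkeeping, where `no_new_best_counter` counts epochs without improvement. Below is a minimal, self-contained sketch of that counter pattern; the `should_stop_early` helper and the `patience` value are illustrative assumptions, not part of the trainer.

```python
def should_stop_early(val_losses: list[float], patience: int = 5) -> bool:
    """Return True once `patience` consecutive epochs fail to improve the best loss."""
    best_loss = float("inf")
    no_new_best_counter = 0
    for loss in val_losses:
        if loss <= best_loss:
            best_loss = loss
            no_new_best_counter = 0  # Reset on improvement, mirroring the trainer loop
        else:
            no_new_best_counter += 1
            if no_new_best_counter >= patience:
                return True
    return False


# Usage example: improvement stalls after the third epoch
print(should_stop_early([0.9, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75], patience=4))  # True
```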
@@ -1080,7 +1085,7 @@ def trainAndValidate(self):
                 if tmp_valid_loss <= self.bestValidationLoss:
 
                     # Transfer best model to CPU to avoid additional memory allocation on GPU
-                    self.bestModel: torch.nn.Module | None = copy.deepcopy(
+                    self.best_model: torch.nn.Module | None = copy.deepcopy(
                         self.model).to('cpu')
 
                     # Delete previous best model checkpoint if it exists
@@ -1099,10 +1104,10 @@ def trainAndValidate(self):
 
                     # Save temporary best model
                     model_save_name = os.path.join(
-                        self.checkpoint_dir, self.modelName + f"_epoch_{self.bestEpoch}")
+                        self.checkpoint_dir, self.modelName + f"_epoch_{self.best_epoch}")
 
-                    if self.bestModel is not None:
-                        SaveModel(model=self.bestModel, model_filename=model_save_name,
+                    if self.best_model is not None:
+                        SaveModel(model=self.best_model, model_filename=model_save_name,
                                   save_mode=AutoForgeModuleSaveMode.MODEL_ARCH_STATE,
                                   target_device='cpu')
 
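For context on the temporary-checkpoint logic being renamed here, the sketch below shows the same rolling "save new best, delete the stale file" pattern with plain `torch.save`/`os.remove` standing in for the project's `SaveModel` helper; all names (`save_best_checkpoint`, `checkpoint_dir`, `model_name`, `best_epoch`) are illustrative assumptions.

```python
import os

import torch
import torch.nn as nn


def save_best_checkpoint(best_model: nn.Module, checkpoint_dir: str, model_name: str,
                         best_epoch: int, previous_path: str | None = None) -> str:
    """Save the current best model's state_dict and drop the previous best file, if any."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Remove the stale checkpoint so only one "best" file remains in checkpoint_dir
    if previous_path is not None and os.path.isfile(previous_path):
        os.remove(previous_path)

    save_path = os.path.join(checkpoint_dir, f"{model_name}_epoch_{best_epoch}.pt")
    torch.save(best_model.state_dict(), save_path)  # Weights only, already on CPU
    return save_path


# Usage example
model = nn.Linear(4, 2)
path = save_best_checkpoint(model, "./checkpoints", "demo_model", best_epoch=3)
print(f"Best checkpoint written to: {path}")
```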
@@ -1125,7 +1130,7 @@ def trainAndValidate(self):
                         'num_of_updates', self.num_of_updates, step=self.current_epoch)
 
             print('\tCurrent best at epoch {best_epoch}, with validation loss: {best_loss:.06g}'.format(
-                best_epoch=self.bestEpoch, best_loss=self.bestValidationLoss))
+                best_epoch=self.best_epoch, best_loss=self.bestValidationLoss))
             print(
                 f'\tEpoch cycle duration: {((time.time() - cycle_start_time) / 60):.4f} [min]')
 
@@ -1150,31 +1155,31 @@ def trainAndValidate(self):
 
 
             if self.keep_best:
                 print('Best model saved from epoch: {best_epoch} with validation loss: {best_loss:.4f}'.format(
-                    best_epoch=self.bestEpoch, best_loss=self.bestValidationLoss))
+                    best_epoch=self.best_epoch, best_loss=self.bestValidationLoss))
 
                 with torch.no_grad():
                     examplePair = next(iter(self.validationDataloader))
                     model_save_name = os.path.join(
-                        self.checkpoint_dir, self.modelName + f"_epoch_{self.bestEpoch}")
+                        self.checkpoint_dir, self.modelName + f"_epoch_{self.best_epoch}")
 
-                    if self.bestModel is not None:
-                        SaveModel(model=self.bestModel, model_filename=model_save_name,
+                    if self.best_model is not None:
+                        SaveModel(model=self.best_model, model_filename=model_save_name,
                                   save_mode=AutoForgeModuleSaveMode.MODEL_ARCH_STATE,
                                   example_input=examplePair[0],
                                   target_device=self.device)
                     else:
                         print(
-                            "\033[38;5;208mWARNING: SaveModel skipped due to bestModel being None!\033[0m")
+                            "\033[38;5;208mWARNING: SaveModel skipped due to best_model being None!\033[0m")
 
             if self.mlflow_logging:
-                mlflow.log_param('checkpoint_best_epoch', self.bestEpoch)
+                mlflow.log_param('checkpoint_best_epoch', self.best_epoch)
 
             # Post-training operations
             print('Training and validation cycle completed.')
             if self.mlflow_logging:
                 mlflow.end_run(status='FINISHED')
 
-            return self.bestModel if self.keep_best else self.model
+            return self.best_model if self.keep_best else self.model
 
         except KeyboardInterrupt:
             print(
@@ -1183,13 +1188,13 @@ def trainAndValidate(self):
             # Save best model up to current epoch if not None
             try:
                 # TODO replace by a trainer export method
-                if self.bestModel is not None:
+                if self.best_model is not None:
                     examplePair = next(iter(self.validationDataloader))
                     model_save_name = os.path.join(
-                        self.checkpoint_dir, self.modelName + f"_epoch_{self.bestEpoch}")
+                        self.checkpoint_dir, self.modelName + f"_epoch_{self.best_epoch}")
 
-                    if self.bestModel is not None:
-                        SaveModel(model=self.bestModel, model_filename=model_save_name,
+                    if self.best_model is not None:
+                        SaveModel(model=self.best_model, model_filename=model_save_name,
                                   save_mode=AutoForgeModuleSaveMode.MODEL_ARCH_STATE,
                                   example_input=examplePair[0],
                                   target_device=self.device)
@@ -1938,8 +1943,8 @@ def TrainAndValidateModel(dataloaderIndex: DataloaderIndex,
     bestSWAvalidationLoss = 1E10
 
     # Deep copy the initial state of the model and move it to the CPU
-    bestModel = copy.deepcopy(model).to('cpu')
-    bestEpoch = epochStart
+    best_model = copy.deepcopy(model).to('cpu')
+    best_epoch = epochStart
 
     if swa_model != None:
         bestSWAmodel = copy.deepcopy(model).to('cpu')
@@ -1962,15 +1967,15 @@ def TrainAndValidateModel(dataloaderIndex: DataloaderIndex,
         # If validation loss is better than previous best, update best model
         if validationLossHistory[epochID] < bestValidationLoss:
             # Replace best model with current model
-            bestModel = copy.deepcopy(model).to('cpu')
-            bestEpoch = epochID + epochStart
+            best_model = copy.deepcopy(model).to('cpu')
+            best_epoch = epochID + epochStart
             bestValidationLoss = validationLossHistory[epochID]
 
-            bestModelData = {'model': bestModel, 'epoch': bestEpoch,
+            bestModelData = {'model': best_model, 'epoch': best_epoch,
                              'validationLoss': bestValidationLoss}
 
             print(
-                f"Current best model found at epoch: {bestEpoch} with validation loss: {bestValidationLoss}")
+                f"Current best model found at epoch: {best_epoch} with validation loss: {bestValidationLoss}")
 
         # SWA handling: if enabled, evaluate validation loss of SWA model, then decide if to update or reset
         if swa_model != None and epochID >= swa_start_epoch:
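Both `trainAndValidate` and `TrainAndValidateModel` track the best model the same way: when the validation loss improves, deep-copy the weights to the CPU and record the epoch. A minimal sketch of that bookkeeping follows, under assumed names (`BestModelTracker` and its attributes are illustrative, not the trainer's API).

```python
import copy

import torch.nn as nn


class BestModelTracker:
    """Track the best-performing model seen so far, keeping its weights on the CPU."""

    def __init__(self) -> None:
        self.best_model: nn.Module | None = None
        self.best_epoch: int = 0
        self.best_validation_loss: float = float("inf")

    def update(self, model: nn.Module, validation_loss: float, epoch: int) -> bool:
        """Record `model` as the new best if `validation_loss` improves; return True if it did."""
        if validation_loss <= self.best_validation_loss:
            self.best_validation_loss = validation_loss
            # Deep copy to CPU so the snapshot does not hold extra GPU memory
            self.best_model = copy.deepcopy(model).to('cpu')
            self.best_epoch = epoch
            return True
        return False


# Usage example with dummy validation losses
tracker = BestModelTracker()
net = nn.Linear(8, 1)
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    tracker.update(net, loss, epoch)
print(f"Best epoch: {tracker.best_epoch}, loss: {tracker.best_validation_loss:.3f}")
```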