@@ -67,52 +67,67 @@ def get_kfold_ds(fold, source_imgs, target_imgs):
    return train_ds, valid_ds

def valid(model, valid_ds, shuffle=False):
-    valid_dl = DataLoader(valid_ds, batch_size=config.BATCH_SIZE, shuffle=shuffle)
+    valid_dl = DataLoader(valid_ds, batch_size=config.BATCH_SIZE, num_workers=int(os.cpu_count() / 8), shuffle=shuffle)
+    criterion = nn.MSELoss(reduction='mean')
    losses = []
    with torch.no_grad():
        model.eval()
        with tqdm(valid_dl, desc='Eval', miniters=10) as progress:
            for i, (source, target) in enumerate(progress):
+                source = source.to(config.DEVICE)
+                target = target.to(config.DEVICE)
                with autocast():
                    img_pred = model(source)
-                    ssim_loss = 1 - ssim(img_pred, target)
-                    losses.append(ssim_loss)
+                    ssim_loss = criterion(img_pred, target)
+                    # ssim_loss = 1 - ssim(img_pred, target)
+                    losses.append(ssim_loss.item())
+                    progress.set_description(f'Valid loss: {ssim_loss:.02f}')
+
    return np.mean(losses)


-def train(train_ds, logger, name):
+def train(train_ds, valid_ds, logger, name):
    print(len(train_ds))
-    set_seed(123)
-    train_dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True)
+    set_seed(11)
+    train_dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, num_workers=int(os.cpu_count() / 8), shuffle=True)
    model = fujiModel()
    model = model.to(config.DEVICE)
    optim = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=config.ONE_CYCLE_MAX_LR, epochs=config.NUM_EPOCHS, steps_per_epoch=len(train_dl))
    scaler = GradScaler()
+
+    criterion = nn.MSELoss(reduction='mean')
+
    for epoch in tqdm(range(config.NUM_EPOCHS)):
-        for batch_idx, (source, target) in enumerate(train_dl):
-            optim.zero_grad()
-            source = source.to(config.DEVICE)
-            target = target.to(config.DEVICE)
-            with autocast():
-                img_pred = model(source)
-                ssim_loss = 1 - ssim(img_pred, target)
-                print(ssim_loss.is_cuda)
-            if torch.isinf(ssim_loss).any() or torch.isnan(ssim_loss).any():
-                print(f'Bad loss, skipping the batch {batch_idx}')
-                del ssim_loss, img_pred
-                gc_collect()
-                continue
-
-            # scaler is needed to prevent "gradient underflow"
-            scaler.scale(ssim_loss).backward()
-            scaler.step(optim)
-            scaler.update()
-            logger.log({'loss': (ssim_loss), 'lr': scheduler.get_last_lr()[0]})
+        with tqdm(train_dl, desc='Train', miniters=10) as progress:
+            for batch_idx, (source, target) in enumerate(progress):
+                optim.zero_grad()
+                source = source.to(config.DEVICE)
+                target = target.to(config.DEVICE)
+                with autocast():
+                    img_pred = model(source)
+                    # ssim_loss = 1 - ssim(img_pred, target)
+                    ssim_loss = criterion(img_pred, target)
+                if torch.isinf(ssim_loss).any() or torch.isnan(ssim_loss).any():
+                    print(ssim_loss)
+                    print(f'Bad loss: {ssim_loss}, skipping the batch {batch_idx}')
+                    del ssim_loss, img_pred
+                    gc_collect()
+                    continue
+
+                # scaler is needed to prevent "gradient underflow"
+                scaler.scale(ssim_loss).backward()
+                scaler.step(optim)
+                scaler.update()
+                # optim.step()
+                logger.log({'loss': ssim_loss, 'lr': scheduler.get_last_lr()[0]})
+                progress.set_description(f'Train loss: {ssim_loss:.02f}')
+
        scheduler.step()
+        valid(model, valid_ds)
        save_model(name, model)
    return model

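The new training loop combines autocast, GradScaler and OneCycleLR. For reference, here is a minimal, self-contained sketch of that mixed-precision step pattern, using a placeholder model and in-memory batches rather than the repo's fujiModel, config or DataLoader setup; note that OneCycleLR constructed with epochs and steps_per_epoch is designed to be advanced once per batch, which is how the sketch steps it.

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = device == 'cuda'

model = nn.Linear(16, 16).to(device)                      # stand-in for the repo's model
optim = torch.optim.Adam(model.parameters())
train_dl = [(torch.randn(4, 16), torch.randn(4, 16)) for _ in range(8)]
num_epochs = 2
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optim, max_lr=1e-3, epochs=num_epochs, steps_per_epoch=len(train_dl))
scaler = GradScaler(enabled=use_amp)
criterion = nn.MSELoss(reduction='mean')

for epoch in range(num_epochs):
    model.train()
    for source, target in train_dl:
        source, target = source.to(device), target.to(device)
        optim.zero_grad()
        with autocast(enabled=use_amp):        # forward pass and loss in reduced precision
            loss = criterion(model(source), target)
        if not torch.isfinite(loss):           # skip batches that produced NaN/Inf
            continue
        scaler.scale(loss).backward()          # scale the loss so fp16 gradients don't underflow
        scaler.step(optim)                     # unscales gradients, then runs optim.step()
        scaler.update()                        # adapt the scale factor for the next iteration
        scheduler.step()                       # OneCycleLR is advanced once per batch
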
@@ -140,7 +155,7 @@ def main():
        name = f'fakeji-fold{fold}'
        with wandb.init(project='fakeji', name=name, entity='jimmydut') as run:
            gc_collect()
-            models.append(train(train_ds, run, name))
+            models.append(train(train_ds, valid_ds, run, name))


if __name__ == '__main__':
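
train() now calls valid() at the end of every epoch, and valid() switches the model to eval mode. Below is a hedged sketch of a per-epoch validation helper along the same lines (run_valid, valid_dl, device and the default criterion are placeholder names, not from this repo) that also hands the model back in training mode and averages the per-batch losses as plain floats.

import torch
from torch import nn

@torch.no_grad()
def run_valid(model, valid_dl, device, criterion=nn.MSELoss(reduction='mean')):
    model.eval()                                  # disable dropout / use running BN statistics
    losses = []
    for source, target in valid_dl:
        pred = model(source.to(device))
        losses.append(criterion(pred, target.to(device)).item())
    model.train()                                 # return the model in training mode for the next epoch
    return sum(losses) / max(len(losses), 1)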