Skip to content

Commit 74974cd

Browse files
committed
update
1 parent 99d2eb3 commit 74974cd

File tree

6 files changed

+8
-5
lines changed

6 files changed

+8
-5
lines changed

__pycache__/config.cpython-39.pyc

-1.08 KB
Binary file not shown.

__pycache__/model.cpython-39.pyc

-1.24 KB
Binary file not shown.

__pycache__/ssim.cpython-39.pyc

-2.54 KB
Binary file not shown.

__pycache__/utils.cpython-39.pyc

-1.45 KB
Binary file not shown.

config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(self, ):
1313
print(f'Batch Size: {self.BATCH_SIZE}')
1414
self.ONE_CYCLE_MAX_LR = 0.0001
1515
self.MODEL_PATH = Path('model')
16+
self.check_path(self.MODEL_PATH)
1617

1718
def get_batch_size(self, ):
1819
if self.DEVICE == 'cuda':
@@ -25,4 +26,7 @@ def get_batch_size(self, ):
2526
BATCH_SIZE = 4
2627
else:
2728
BATCH_SIZE = 2
28-
return BATCH_SIZE
29+
return BATCH_SIZE
30+
31+
def check_path(self, path):
32+
path.mkdir(exist_ok=True)

train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def train(train_ds, logger, name):
9191
optim = torch.optim.Adam(model.parameters())
9292
scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=config.ONE_CYCLE_MAX_LR, epochs=config.NUM_EPOCHS, steps_per_epoch=len(train_dl))
9393
scaler = GradScaler()
94-
with tqdm(train_dl, desc='Train', miniters=10) as progress:
95-
for batch_idx, (source, target) in enumerate(progress):
94+
for epoch in tqdm(range(config.NUM_EPOCHS)):
95+
for batch_idx, (source, target) in enumerate(train_dl):
9696
optim.zero_grad()
9797
source = source.to(config.DEVICE)
9898
target = target.to(config.DEVICE)
@@ -110,9 +110,8 @@ def train(train_ds, logger, name):
110110
scaler.scale(ssim_loss).backward()
111111
scaler.step(optim)
112112
scaler.update()
113-
scheduler.step()
114-
progress.set_description(f'Train loss: {ssim_loss :.02f}')
115113
logger.log({'loss': (ssim_loss), 'lr': scheduler.get_last_lr()[0]})
114+
scheduler.step()
116115
save_model(name, model)
117116
return model
118117

0 commit comments

Comments
 (0)