Commit 7963e84

update

1 parent 924eb0e commit 7963e84

File tree

5 files changed: +129 −114 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 __pycache__/
+.ipynb_checkpoints/
 wandb/
 data/
 data_gen/

config.py

Lines changed: 2 additions & 2 deletions

@@ -5,10 +5,10 @@

 class Config:
     def __init__(self, ):
-        self.DATA_DIR = Path('data')
+        self.DATA_DIR = Path('data_gen')
         self.SOURCE_DIR = self.DATA_DIR / 'source'
         self.TARGET_DIR = self.DATA_DIR / 'target'
-        self.NUM_EPOCHS = 100
+        self.NUM_EPOCHS = 50
         self.N_FOLD = 5
         self.CROP_RATIO = 12
         self.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
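
For orientation, config.py after this commit plausibly reads as below. Only the fields visible in the hunk are certain; the imports, BATCH_SIZE, and ONE_CYCLE_MAX_LR (both referenced from train.py) are assumptions with placeholder values.

# Sketch of config.py after this commit. Imports and the BATCH_SIZE /
# ONE_CYCLE_MAX_LR values are assumed; they do not appear in the diff.
from pathlib import Path

import torch


class Config:
    def __init__(self, ):
        self.DATA_DIR = Path('data_gen')   # was Path('data')
        self.SOURCE_DIR = self.DATA_DIR / 'source'
        self.TARGET_DIR = self.DATA_DIR / 'target'
        self.NUM_EPOCHS = 50               # was 100
        self.N_FOLD = 5
        self.CROP_RATIO = 12
        self.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.BATCH_SIZE = 4                # assumed; used by train.py
        self.ONE_CYCLE_MAX_LR = 1e-3       # assumed; used by train.py


config = Config()                          # assumed module-level instance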

train.ipynb renamed to gen.ipynb

Lines changed: 71 additions & 80 deletions
Large diffs are not rendered by default.

test.py

Lines changed: 14 additions & 6 deletions

@@ -1,15 +1,23 @@
 import torch
 import torch.nn as nn
-
+from ssim import ssim

 class fujiModel(nn.Module):
     def __init__(self):
         super().__init__()
-        self.fc1 = nn.Conv2d(3, 3, (3, 3))
+        self.fc1 = nn.Linear(4896, 4896)
+        self.sigmoid = nn.Sigmoid()
+        self.fc2 = nn.Linear(4896, 4896)
+
     def forward(self, x):
-        x = self.fc1(x)
-        return x
+        out = self.fc1(x)
+        out = self.sigmoid(out)
+        out = self.fc2(out)
+        out = self.sigmoid(out)
+        return out.mul(x)

-test = torch.randn((32, 3, 240, 128))
+test = torch.randn((4, 3, 4896, 4896))
 model = fujiModel()
-print(model(test).size())
+pred = model(test)
+loss = 1 - ssim(test, pred)
+print(loss)
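
In the rewritten test, the model gates its input (out.mul(x)) with a two-layer sigmoid mask, and 1 - ssim(test, pred) serves as the loss, since SSIM equals 1.0 for identical images. The new nn.Linear(4896, 4896) layers are large; a quick size check, assuming standard PyTorch semantics (none of this is project code):

# Back-of-the-envelope size check for the new fujiModel layers.
import torch
import torch.nn as nn

fc = nn.Linear(4896, 4896)
n_params = sum(p.numel() for p in fc.parameters())
print(n_params)                      # 4896*4896 + 4896 = 23,975,712 per layer

batch_bytes = 4 * 3 * 4896 * 4896 * 4
print(batch_bytes)                   # fp32 batch from test.py: ~1.15 GB

x_small = torch.randn(2, 3, 5, 4896)
print(fc(x_small).shape)             # (2, 3, 5, 4896): Linear acts on the last dim only

Two such layers come to roughly 48M parameters (~192 MB in fp32), on top of the ~1.15 GB batch.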

train.py

Lines changed: 41 additions & 26 deletions

@@ -67,52 +67,67 @@ def get_kfold_ds(fold, source_imgs, target_imgs):
     return train_ds, valid_ds

 def valid(model, valid_ds, shuffle=False):
-    valid_dl = DataLoader(valid_ds, batch_size=config.BATCH_SIZE, shuffle=shuffle)
+    valid_dl = DataLoader(valid_ds, batch_size=config.BATCH_SIZE, num_workers=int(os.cpu_count()/8), shuffle=shuffle)
+    criterion = nn.MSELoss(reduction='mean')
     losses = []
     with torch.no_grad():
         model.eval()
         with tqdm(valid_dl, desc='Eval', miniters=10) as progress:
             for i, (source, target) in enumerate(progress):
+                source = source.to(config.DEVICE)
+                target = target.to(config.DEVICE)
                 with autocast():
                     img_pred = model(source)
-                    ssim_loss = 1 - ssim(img_pred, target)
-                    losses.append(ssim_loss)
+                    ssim_loss = criterion(img_pred, target)
+                    # ssim_loss = 1 - ssim(img_pred, target)
+                    # losses.append(ssim_loss)
+                    progress.set_description(f'Valid loss: {ssim_loss :.02f}')
+
     return np.mean(losses)




-def train(train_ds, logger, name):
+def train(train_ds, valid_ds, logger, name):
     print(len(train_ds))
-    set_seed(123)
-    train_dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True)
+    set_seed(11)
+    train_dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, num_workers=int(os.cpu_count()/8), shuffle=True)
     model = fujiModel()
     model = model.to(config.DEVICE)
     optim = torch.optim.Adam(model.parameters())
     scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=config.ONE_CYCLE_MAX_LR, epochs=config.NUM_EPOCHS, steps_per_epoch=len(train_dl))
     scaler = GradScaler()
+
+    criterion = nn.MSELoss(reduction='mean')
+
     for epoch in tqdm(range(config.NUM_EPOCHS)):
-        for batch_idx, (source, target) in enumerate(train_dl):
-            optim.zero_grad()
-            source = source.to(config.DEVICE)
-            target = target.to(config.DEVICE)
-            with autocast():
-                img_pred = model(source)
-                ssim_loss = 1 - ssim(img_pred, target)
-                print(ssim_loss.is_cuda)
-            if torch.isinf(ssim_loss).any() or torch.isnan(ssim_loss).any():
-                print(f'Bad loss, skipping the batch {batch_idx}')
-                del ssim_loss, img_pred
-                gc_collect()
-                continue
-
-            # scaler is needed to prevent "gradient underflow"
-            scaler.scale(ssim_loss).backward()
-            scaler.step(optim)
-            scaler.update()
-            logger.log({'loss': (ssim_loss), 'lr': scheduler.get_last_lr()[0]})
+        with tqdm(train_dl, desc='Train', miniters=10) as progress:
+            for batch_idx, (source, target) in enumerate(progress):
+                optim.zero_grad()
+                source = source.to(config.DEVICE)
+                target = target.to(config.DEVICE)
+                with autocast():
+                    img_pred = model(source)
+                    # ssim_loss = 1 - ssim(img_pred, target)
+                    ssim_loss = criterion(img_pred, target)
+                if torch.isinf(ssim_loss).any() or torch.isnan(ssim_loss).any():
+                    print(ssim_loss)
+                    print(f'Bad loss: {ssim_loss}, skipping the batch {batch_idx}')
+                    del ssim_loss, img_pred
+                    gc_collect()
+                    continue
+
+                # scaler is needed to prevent "gradient underflow"
+                scaler.scale(ssim_loss).backward()
+                scaler.step(optim)
+                scaler.update()
+                # optim.steap()
+                logger.log({'loss': (ssim_loss), 'lr': scheduler.get_last_lr()[0]})
+                progress.set_description(f'Train loss: {ssim_loss :.02f}')
+
         scheduler.step()
+        valid(model, valid_ds)
     save_model(name, model)
     return model

@@ -140,7 +155,7 @@ def main():
         name = f'fakeji-fold{fold}'
         with wandb.init(project='fakeji', name=name, entity='jimmydut') as run:
             gc_collect()
-            models.append(train(train_ds, run, name))
+            models.append(train(train_ds, valid_ds, run, name))


 if __name__ == '__main__':
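
The retrofitted loop follows PyTorch's standard mixed-precision recipe: autocast for the forward pass, GradScaler so fp16 gradients do not underflow on backward. A minimal self-contained sketch of that step (the model, data, and MSE criterion below are placeholders, not project code):

# Canonical autocast + GradScaler step, as used in train().
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(16, 16).to(device)
optim = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=1e-3, total_steps=100)
scaler = GradScaler(enabled=(device == 'cuda'))
criterion = nn.MSELoss(reduction='mean')

for step in range(100):
    source = torch.randn(4, 16, device=device)
    target = torch.randn(4, 16, device=device)
    optim.zero_grad()
    with autocast(enabled=(device == 'cuda')):
        pred = model(source)
        loss = criterion(pred, target)
    scaler.scale(loss).backward()   # scale the loss up before backward
    scaler.step(optim)              # unscale gradients, then optimizer step
    scaler.update()
    scheduler.step()                # OneCycleLR is sized for one step per batch

One caveat: OneCycleLR constructed with epochs and steps_per_epoch sizes its schedule for one step() per optimizer step, so the per-epoch scheduler.step() in this commit traverses only a small fraction of the cycle.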
