Skip to content

Commit 2ba52ca

Browse files
committed
save, test
1 parent 6ab7185 commit 2ba52ca

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

train_segment.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,14 @@
2727
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
2828
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
2929

30-
parser.add_argument("--begin_epoch", type=int, default=0, help="number of gpu")
31-
parser.add_argument("--end_epoch", type=int, default=31, help="number of gpu")
30+
parser.add_argument("--begin_epoch", type=int, default=0, help="begin_epoch")
31+
parser.add_argument("--end_epoch", type=int, default=51, help="end_epoch")
32+
33+
parser.add_argument("--need_test", type=bool, default=True, help="need to test")
3234
parser.add_argument("--test_interval", type=int, default=10, help="interval of test")
33-
parser.add_argument("--save_interval", type=int, default=10, help="interval of test")
35+
parser.add_argument("--need_save", type=bool, default=True, help="need to save")
36+
parser.add_argument("--save_interval", type=int, default=10, help="interval of save weights")
37+
3438

3539
parser.add_argument("--img_height", type=int, default=704, help="size of image height")
3640
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
@@ -52,34 +56,34 @@
5256

5357
# Build nets
5458
segment_net = SegmentNet(init_weights=True)
55-
decision_net = DecisionNet(init_weights=True)
59+
#decision_net = DecisionNet(init_weights=True)
5660

5761
# Loss functions
5862
criterion_segment = torch.nn.MSELoss()
59-
criterion_decision = torch.nn.L1Loss()
63+
#criterion_decision = torch.nn.L1Loss()
6064

6165
if opt.cuda:
6266
segment_net = segment_net.cuda()
63-
decision_net = decision_net.cuda()
67+
#decision_net = decision_net.cuda()
6468
criterion_segment.cuda()
65-
criterion_decision.cuda()
69+
#criterion_decision.cuda()
6670

6771
if opt.gpu_num > 1:
6872
segment_net = torch.nn.DataParallel(segment_net, device_ids=list(range(opt.gpu_num)))
69-
decision_net = torch.nn.DataParallel(decision_net, device_ids=list(range(opt.gpu_num)))
73+
# decision_net = torch.nn.DataParallel(decision_net, device_ids=list(range(opt.gpu_num)))
7074

7175
if opt.begin_epoch != 0:
7276
# Load pretrained models
73-
segment_net.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
74-
decision_net.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
77+
segment_net.load_state_dict(torch.load("./saved_models/segment_net_%d.pth" % (opt.begin_epoch)))
78+
# decision_net.load_state_dict(torch.load("./saved_models/decision_net_%d.pth" % (opt.begin_epoch)))
7579
else:
7680
# Initialize weights
7781
segment_net.apply(weights_init_normal)
78-
decision_net.apply(weights_init_normal)
82+
# decision_net.apply(weights_init_normal)
7983

8084
# Optimizers
8185
optimizer_seg = torch.optim.Adam(segment_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
82-
optimizer_dec = torch.optim.Adam(decision_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
86+
#optimizer_dec = torch.optim.Adam(decision_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
8387

8488
transforms_ = transforms.Compose([
8589
transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
@@ -125,7 +129,7 @@
125129
lenNum = 2*(lenNum-1)
126130

127131
segment_net.train()
128-
132+
# train *****************************************************************
129133
for i in range(0, lenNum):
130134
if i % 2 == 0:
131135
batchData = iterOK.__next__()
@@ -161,10 +165,11 @@
161165
loss_seg.item()
162166
)
163167
)
164-
165-
if epoch % opt.test_interval == 0 and epoch >= opt.test_interval:
168+
169+
# test ****************************************************************************
170+
if opt.need_test and epoch % opt.test_interval == 0 and epoch >= opt.test_interval:
166171
segment_net.eval()
167-
172+
168173
for i, testBatch in enumerate(testloader):
169174
imgTest = testBatch["img"].cuda()
170175
t1 = time.time()
@@ -180,9 +185,18 @@
180185
print("processing image NO %d, time comsuption %fs"%(i, t2 - t1))
181186
save_image(imgTest.data, "%s/img_%d.jpg"% (save_path_str, i))
182187
save_image(segTest.data, "%s/img_%d_seg.jpg"% (save_path_str, i))
188+
189+
segment_net.train()
190+
191+
# save parameters *****************************************************************
192+
if opt.need_save and epoch % opt.save_interval == 0 and epoch >= opt.save_interval:
193+
segment_net.eval()
183194

195+
save_path_str = "./saved_models"
196+
if os.path.exists(save_path_str) == False:
197+
os.makedirs(save_path_str, exist_ok=True)
184198

185-
if epoch % opt.save_interval == 0 and epoch >= opt.save_interval:
186-
187-
199+
torch.save(segment_net.state_dict(), "%s/segment_net_%d.pth" % (save_path_str, epoch))
200+
print("save weights ! epoch = %d"%epoch)
201+
segment_net.train()
188202
pass

0 commit comments

Comments
 (0)