27 | 27 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
28 | 28 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
29 | 29 |
30 | | -parser.add_argument("--begin_epoch", type=int, default=0, help="number of gpu")
31 | | -parser.add_argument("--end_epoch", type=int, default=31, help="number of gpu")
| 30 | +parser.add_argument("--begin_epoch", type=int, default=0, help="epoch to start training from")
| 31 | +parser.add_argument("--end_epoch", type=int, default=51, help="epoch to stop training at")
| 32 | +
| 33 | +parser.add_argument("--need_test", type=bool, default=True, help="whether to run the test pass")
32 | 34 | parser.add_argument("--test_interval", type=int, default=10, help="interval of test")
33 | | -parser.add_argument("--save_interval", type=int, default=10, help="interval of test")
| 35 | +parser.add_argument("--need_save", type=bool, default=True, help="whether to save model weights")
| 36 | +parser.add_argument("--save_interval", type=int, default=10, help="interval (in epochs) between weight saves")
| 37 | +
34 | 38 |
35 | 39 | parser.add_argument("--img_height", type=int, default=704, help="size of image height")
36 | 40 | parser.add_argument("--img_width", type=int, default=256, help="size of image width")
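A note on the new `--need_test` / `--need_save` flags: `argparse` with `type=bool` converts any non-empty string to `True`, so passing `--need_test False` on the command line still yields `True`; only the defaults behave as intended. Below is a minimal sketch of the string-to-boolean converter commonly used instead; the `str2bool` name is illustrative and not part of this commit.

```python
import argparse

def str2bool(value):
    # Interpret common truthy/falsy strings explicitly instead of relying on bool(),
    # which treats every non-empty string (including "False") as True.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got %r" % value)

parser = argparse.ArgumentParser()
# Same flags as above, but now "--need_test False" genuinely disables the test pass.
parser.add_argument("--need_test", type=str2bool, default=True, help="whether to run the test pass")
parser.add_argument("--need_save", type=str2bool, default=True, help="whether to save model weights")
args = parser.parse_args(["--need_test", "False"])
print(args.need_test, args.need_save)  # False True
```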
52 | 56 |
53 | 57 | # Build nets
54 | 58 | segment_net = SegmentNet(init_weights=True)
55 | | -decision_net = DecisionNet(init_weights=True)
| 59 | +#decision_net = DecisionNet(init_weights=True)
56 | 60 |
57 | 61 | # Loss functions
58 | 62 | criterion_segment = torch.nn.MSELoss()
59 | | -criterion_decision = torch.nn.L1Loss()
| 63 | +#criterion_decision = torch.nn.L1Loss()
60 | 64 |
61 | 65 | if opt.cuda:
62 | 66 |     segment_net = segment_net.cuda()
63 | | -    decision_net = decision_net.cuda()
| 67 | +    #decision_net = decision_net.cuda()
64 | 68 |     criterion_segment.cuda()
65 | | -    criterion_decision.cuda()
| 69 | +    #criterion_decision.cuda()
66 | 70 |
67 | 71 | if opt.gpu_num > 1:
68 | 72 |     segment_net = torch.nn.DataParallel(segment_net, device_ids=list(range(opt.gpu_num)))
69 | | -    decision_net = torch.nn.DataParallel(decision_net, device_ids=list(range(opt.gpu_num)))
| 73 | +    # decision_net = torch.nn.DataParallel(decision_net, device_ids=list(range(opt.gpu_num)))
70 | 74 |
71 | 75 | if opt.begin_epoch != 0:
72 | 76 |     # Load pretrained models
73 | | -    segment_net.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
74 | | -    decision_net.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
| 77 | +    segment_net.load_state_dict(torch.load("./saved_models/segment_net_%d.pth" % (opt.begin_epoch)))
| 78 | +    # decision_net.load_state_dict(torch.load("./saved_models/decision_net_%d.pth" % (opt.begin_epoch)))
75 | 79 | else:
76 | 80 |     # Initialize weights
77 | 81 |     segment_net.apply(weights_init_normal)
78 | | -    decision_net.apply(weights_init_normal)
| 82 | +    # decision_net.apply(weights_init_normal)
79 | 83 |
80 | 84 | # Optimizers
81 | 85 | optimizer_seg = torch.optim.Adam(segment_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
82 | | -optimizer_dec = torch.optim.Adam(decision_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
| 86 | +#optimizer_dec = torch.optim.Adam(decision_net.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
83 | 87 |
84 | 88 | transforms_ = transforms.Compose([
85 | 89 |     transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
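Since the resume path now loads `./saved_models/segment_net_%d.pth` directly into a model that may already be wrapped in `DataParallel`, two things can bite when the load environment differs from the save environment: a checkpoint written on GPU will not load on a CPU-only machine without `map_location`, and a state dict saved from a `DataParallel` wrapper carries a `module.` prefix that a bare `SegmentNet` rejects. A hedged sketch of a more defensive load, assuming the same file layout as this commit (the helper name is illustrative, not part of the code here):

```python
import torch

def load_segment_checkpoint(model, path, device="cpu"):
    # map_location keeps GPU-written checkpoints loadable on CPU-only machines.
    state = torch.load(path, map_location=device)
    # Strip the "module." prefix that torch.nn.DataParallel adds to parameter names,
    # so the same file loads into both wrapped and unwrapped models.
    state = {(k[len("module."):] if k.startswith("module.") else k): v
             for k, v in state.items()}
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state)
    return model

# e.g. resuming as in the diff above, with the path built from opt.begin_epoch:
# segment_net = load_segment_checkpoint(segment_net, "./saved_models/segment_net_40.pth")
```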
125 | 129 |     lenNum = 2*(lenNum-1)
126 | 130 |
127 | 131 |     segment_net.train()
128 | | -
| 132 | +    # train *****************************************************************
129 | 133 |     for i in range(0, lenNum):
130 | 134 |         if i % 2 == 0:
131 | 135 |             batchData = iterOK.__next__()
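The training loop above alternates between two dataset iterators: defect-free batches from `iterOK` on even steps and, judging by the `i % 2` branch, defect batches from a second iterator on the odd steps that are elided here. Calling `iterator.__next__()` works, but the built-in `next()` is the idiomatic spelling and makes it easy to restart an exhausted iterator. A self-contained sketch of the same alternating pattern; the `ng` side and the stand-in lists are assumptions, not taken from this diff:

```python
# Lists stand in for the OK / defect DataLoaders used by the script.
ok_batches = ["ok_0", "ok_1", "ok_2"]   # assumed stand-in for the OK loader
ng_batches = ["ng_0", "ng_1"]           # assumed stand-in for a defect loader

iterOK, iterNG = iter(ok_batches), iter(ng_batches)
lenNum = 2 * (len(ok_batches) - 1)      # same length rule as the loop above

for i in range(lenNum):
    even = (i % 2 == 0)
    try:
        batchData = next(iterOK) if even else next(iterNG)  # next(x) == x.__next__()
    except StopIteration:
        # Restart whichever iterator ran out so the alternation can continue.
        if even:
            iterOK = iter(ok_batches)
            batchData = next(iterOK)
        else:
            iterNG = iter(ng_batches)
            batchData = next(iterNG)
    print(i, batchData)
```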
161 | 165 |                 loss_seg.item()
162 | 166 |             )
163 | 167 |         )
164 | | -
165 | | -    if epoch % opt.test_interval == 0 and epoch >= opt.test_interval:
| 168 | +
| 169 | +    # test ****************************************************************************
| 170 | +    if opt.need_test and epoch % opt.test_interval == 0 and epoch >= opt.test_interval:
166 | 171 |         segment_net.eval()
167 | | -
| 172 | +
168 | 173 |         for i, testBatch in enumerate(testloader):
169 | 174 |             imgTest = testBatch["img"].cuda()
170 | 175 |             t1 = time.time()

180 | 185 |             print("processing image NO %d, time consumption %fs"%(i, t2 - t1))
181 | 186 |             save_image(imgTest.data, "%s/img_%d.jpg"% (save_path_str, i))
182 | 187 |             save_image(segTest.data, "%s/img_%d_seg.jpg"% (save_path_str, i))
| 188 | +
| 189 | +        segment_net.train()
| 190 | +
| 191 | +    # save parameters *****************************************************************
| 192 | +    if opt.need_save and epoch % opt.save_interval == 0 and epoch >= opt.save_interval:
| 193 | +        segment_net.eval()
183 | 194 |
| 195 | +        save_path_str = "./saved_models"
| 196 | +        if os.path.exists(save_path_str) == False:
| 197 | +            os.makedirs(save_path_str, exist_ok=True)
184 | 198 |
185 | | -    if epoch % opt.save_interval == 0 and epoch >= opt.save_interval:
186 | | -
187 | | -
| 199 | +        torch.save(segment_net.state_dict(), "%s/segment_net_%d.pth" % (save_path_str, epoch))
| 200 | +        print("save weights ! epoch = %d"%epoch)
| 201 | +        segment_net.train()
188 | 202 |     pass
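One caveat on the new save block: when `opt.gpu_num > 1`, `segment_net` is a `DataParallel` wrapper, so `segment_net.state_dict()` is written with `module.`-prefixed keys and will not load back into a single-GPU run without extra handling. Saving the underlying module's state dict sidesteps that. A minimal sketch under that assumption; the helper name is illustrative, not part of the commit:

```python
import os
import torch

def save_segment_checkpoint(model, epoch, save_dir="./saved_models"):
    os.makedirs(save_dir, exist_ok=True)   # covers the exists()/makedirs() pair above
    # Unwrap DataParallel so the saved keys carry no "module." prefix and the
    # checkpoint loads cleanly into both single- and multi-GPU setups.
    target = model.module if hasattr(model, "module") else model
    path = os.path.join(save_dir, "segment_net_%d.pth" % epoch)
    torch.save(target.state_dict(), path)
    print("save weights ! epoch = %d" % epoch)
    return path

# e.g. inside the epoch loop shown above:
# if opt.need_save and epoch % opt.save_interval == 0 and epoch >= opt.save_interval:
#     save_segment_checkpoint(segment_net, epoch)
```

Separately, if the elided test-time forward pass is not already wrapped in `torch.no_grad()`, doing so during the test block avoids building the autograd graph and reduces memory use; the `eval()`/`train()` toggling this commit adds handles only layer behaviour (dropout, batch norm), not gradient tracking.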