Skip to content

Commit 1753aca

Browse files
committed
little problem
1 parent 2ba52ca commit 1753aca

File tree

5 files changed

+364
-14
lines changed

5 files changed

+364
-14
lines changed

dataset.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from PIL import Image
88
import torchvision.transforms as transforms
99

10+
import torchvision.transforms.functional as VF
11+
1012
class SegDataset(Dataset):
1113
def __init__(self, dataRoot, transforms_= None, transforms_mask = None, subFold="Train_NG", isTrain=True):
1214

@@ -32,18 +34,33 @@ def __getitem__(self, index):
3234

3335
idx = index % self.len
3436

35-
img = Image.open(self.imgFiles[idx]).convert("RGB")
36-
img = self.transform(img)
37+
3738

3839
if self.isTrain==True:
40+
41+
img = Image.open(self.imgFiles[idx]).convert("RGB")
42+
3943
#mask = Image.open(self.labelFiles[idx]).convert("RGB")
4044
mat = cv2.imread(self.labelFiles[idx], cv2.IMREAD_GRAYSCALE)
4145
kernel = np.ones((5, 5), np.uint8)
4246
matD = cv2.dilate(mat, kernel)
4347
mask = Image.fromarray(matD) # image2 is a PIL image
48+
49+
if np.random.rand(1) > 0.5:
50+
mask = VF.hflip(mask)
51+
img = VF.hflip(img)
52+
53+
if np.random.rand(1) > 0.5:
54+
mask = VF.vflip(mask)
55+
img = VF.vflip(img)
56+
57+
img = self.transform(img)
4458
mask = self.maskTransform(mask)
59+
4560
return {"img":img, "mask":mask}
4661
else:
62+
img = Image.open(self.imgFiles[idx]).convert("RGB")
63+
img = self.transform(img)
4764
return {"img":img}
4865

4966
def __len__(self):

models.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,29 +88,29 @@ def __init__(self, init_weights=True):
8888

8989
self.layer1 = nn.Sequential(
9090
nn.MaxPool2d(2),
91-
nn.Conv2d(1025, 8, 5),
91+
nn.Conv2d(1025, 8, 5, stride=1, padding=2),
9292
nn.BatchNorm2d(8),
9393
nn.ReLU(inplace=True),
9494
nn.MaxPool2d(2),
95-
nn.Conv2d(8, 16, 5),
95+
nn.Conv2d(8, 16, 5, stride=1, padding=2),
9696
nn.BatchNorm2d(16),
9797
nn.ReLU(inplace=True),
98-
nn.Conv2d(16, 32, 5),
98+
nn.Conv2d(16, 32, 5, stride=1, padding=2),
9999
nn.BatchNorm2d(32),
100100
nn.ReLU(inplace=True)
101101
)
102102

103103
self.fc = nn.Sequential(
104-
nn.Linear(66, 1, bias=False)
104+
nn.Linear(66, 1, bias=False),
105+
nn.Sigmoid()
105106
)
106107

107108
if init_weights == True:
108109
pass
109110

110111
def forward(self, f, s):
111-
x = torch.cat((f, s), 1)
112-
x1 = self.layer1(x)
113-
112+
xx = torch.cat((f, s), 1)
113+
x1 = self.layer1(xx)
114114
x2 = x1.view(x1.size(0), x1.size(1), -1)
115115
s2 = s.view(s.size(0), s.size(1), -1)
116116

@@ -129,12 +129,20 @@ def forward(self, f, s):
129129

130130
snet = SegmentNet()
131131
dnet = DecisionNet()
132+
img = torch.randn(4, 3, 704, 256)
133+
134+
snet.eval()
132135

133-
img = torch.randn(4, 3, 512, 512)
136+
snet = snet.cuda()
137+
dnet = dnet.cuda()
138+
img = img.cuda()
134139

135-
f,s = snet.forward(img)
136-
c = dnet.forward(f,s)
140+
ret = snet(img)
141+
f = ret["f"]
142+
s = ret["seg"]
137143

144+
c = dnet(f, s)
145+
print(c)
138146
pass
139147

140148

test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from models import SegmentNet, DecisionNet, weights_init_normal
2+
from dataset import SegDataset
3+
4+
import torch.nn as nn
5+
import torch
6+
7+
from torchvision import datasets
8+
from torchvision.utils import save_image
9+
import torchvision.transforms as transforms
10+
from torch.autograd import Variable
11+
from torch.utils.data import DataLoader
12+
13+
import os
14+
import sys
15+
import argparse
16+
import time
17+
import PIL.Image as Image
18+
19+
parser = argparse.ArgumentParser()
20+
21+
parser.add_argument("--cuda", type=bool, default=True, help="number of gpu")
22+
parser.add_argument("--test_seg_epoch", type=int, default=60, help="test segment epoch")
23+
parser.add_argument("--test_dec_epoch", type=int, default=50, help="test segment epoch")
24+
parser.add_argument("--img_height", type=int, default=704, help="size of image height")
25+
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
26+
27+
opt = parser.parse_args()
28+
29+
print(opt)
30+
31+
dataSetRoot = "/home/sean/Data/KolektorSDD_sean"
32+
33+
# ***********************************************************************
34+
35+
# Build nets
36+
segment_net = SegmentNet(init_weights=True)
37+
decision_net = DecisionNet(init_weights=True)
38+
39+
if opt.cuda:
40+
segment_net = segment_net.cuda()
41+
decision_net = decision_net.cuda()
42+
43+
if opt.test_seg_epoch != 0:
44+
# Load pretrained models
45+
segment_net.load_state_dict(torch.load("./saved_models/segment_net_%d.pth" % (opt.test_seg_epoch)))
46+
47+
if opt.test_dec_epoch != 0:
48+
# Load pretrained models
49+
decision_net.load_state_dict(torch.load("./saved_models/decision_net_%d.pth" % (opt.test_dec_epoch)))
50+
51+
transforms_ = transforms.Compose([
52+
transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
53+
transforms.ToTensor(),
54+
#transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
55+
])
56+
57+
58+
testloader = DataLoader(
59+
SegDataset(dataSetRoot, transforms_=transforms_, transforms_mask= None, subFold="Test", isTrain=False),
60+
batch_size=1,
61+
shuffle=False,
62+
num_workers=0,
63+
)
64+
65+
segment_net.eval()
66+
decision_net.eval()
67+
68+
for i, testBatch in enumerate(testloader):
69+
t1 = time.time()
70+
imgTest = testBatch["img"].cuda()
71+
rstTest = segment_net(imgTest)
72+
73+
fTest = rstTest["f"]
74+
segTest = rstTest["seg"]
75+
76+
cTest = decision_net(fTest, segTest)
77+
78+
t2 = time.time()
79+
80+
if cTest.item() > 0.5:
81+
labelStr = "NG"
82+
else:
83+
labelStr = "OK"
84+
85+
save_path_str = os.path.join(dataSetRoot, "testResult")
86+
87+
if os.path.exists(save_path_str) == False:
88+
os.makedirs(save_path_str, exist_ok=True)
89+
90+
print("processing image NO %d, time comsuption %fs"%(i, t2 - t1))
91+
save_image(imgTest.data, "%s/img_%d_%s.jpg"% (save_path_str, i, labelStr))
92+
save_image(segTest.data, "%s/img_%d_seg_%s.jpg"% (save_path_str, i, labelStr))

0 commit comments

Comments
 (0)