Skip to content

Commit d4c7cb4

Browse files
committed
working code committed
1 parent b314d66 commit d4c7cb4

File tree

4 files changed

+245
-234
lines changed

4 files changed

+245
-234
lines changed

eval.py

Lines changed: 98 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,99 @@
1-
import os
2-
import argparse
3-
4-
import cv2
5-
import numpy as np
6-
7-
import torch
8-
from torch.utils.data import DataLoader
9-
10-
from siamese import SiameseNetwork
11-
from libs.dataset import Dataset
12-
13-
if __name__ == "__main__":
14-
parser = argparse.ArgumentParser()
15-
16-
parser.add_argument(
17-
'--val_path',
18-
type=str,
19-
help="Path to directory containing validation dataset.",
20-
default="../dataset/test"
21-
)
22-
parser.add_argument(
23-
'-o',
24-
'--out_path',
25-
type=str,
26-
help="Path for outputting model weights and tensorboard summary.",
27-
default="output/images"
28-
)
29-
parser.add_argument(
30-
'-c',
31-
'--checkpoint',
32-
type=str,
33-
help="Path to model to be used for inference.",
34-
default="output/epoch_200.pth"
35-
)
36-
37-
args = parser.parse_args()
38-
39-
os.makedirs(args.out_path, exist_ok=True)
40-
41-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42-
43-
val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False, testing=True)
44-
val_dataloader = DataLoader(val_dataset, batch_size=1)
45-
46-
model = SiameseNetwork()
47-
model.to(device)
48-
49-
checkpoint = torch.load(args.checkpoint)
50-
model.load_state_dict(checkpoint['model_state_dict'])
51-
52-
model.eval()
53-
54-
losses = []
55-
correct = 0
56-
total = 0
57-
58-
inv_transform = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
59-
std = [ 1/0.229, 1/0.224, 1/0.225 ]),
60-
transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
61-
std = [ 1., 1., 1. ]),
62-
])
63-
64-
for i, ((img1, img2), y, (class1, class2)) in enumerate(val_dataloader):
65-
print("[{} / {}]".format(i, len(val_dataloader)))
66-
67-
img1, img2, y = map(lambda x: x.to(device), [img1, img2, y])
68-
69-
prob = model(img1, img2)
70-
loss = criterion(prob, y)
71-
72-
losses.append(loss.item())
73-
correct += torch.count_nonzero(y == (prob > 0.5)).item()
74-
total += len(y)
75-
76-
fig = plt.figure("class1={}\tclass2={}".format(class1, class2), figsize=(4, 2))
77-
plt.suptitle("cls1={} conf={:.2f} cls2={}".format(class1, prob[0], class2))
78-
79-
# show first image
80-
ax = fig.add_subplot(1, 2, 1)
81-
plt.imshow(inv_transform(img1[0]), cmap=plt.cm.gray)
82-
plt.axis("off")
83-
84-
# show the second image
85-
ax = fig.add_subplot(1, 2, 2)
86-
plt.imshow(inv_transform(img2[0]), cmap=plt.cm.gray)
87-
plt.axis("off")
88-
89-
# show the plot
90-
plt.savefig(os.path.join(args.checkpoint, 'images/{}.png').format(i))
91-
1+
import os
2+
import argparse
3+
4+
import cv2
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
8+
import torch
9+
from torch.utils.data import DataLoader
10+
from torchvision import transforms
11+
12+
from siamese import SiameseNetwork
13+
from libs.dataset import Dataset
14+
15+
if __name__ == "__main__":
16+
parser = argparse.ArgumentParser()
17+
18+
parser.add_argument(
19+
'--val_path',
20+
type=str,
21+
help="Path to directory containing validation dataset.",
22+
default="../dataset/test"
23+
)
24+
parser.add_argument(
25+
'-o',
26+
'--out_path',
27+
type=str,
28+
help="Path for outputting model weights and tensorboard summary.",
29+
default="output/images"
30+
)
31+
parser.add_argument(
32+
'-c',
33+
'--checkpoint',
34+
type=str,
35+
help="Path to model to be used for inference.",
36+
default="output/epoch_200.pth"
37+
)
38+
39+
args = parser.parse_args()
40+
41+
os.makedirs(args.out_path, exist_ok=True)
42+
43+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44+
45+
val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False, testing=True)
46+
val_dataloader = DataLoader(val_dataset, batch_size=1)
47+
48+
model = SiameseNetwork()
49+
model.to(device)
50+
criterion = torch.nn.BCELoss()
51+
52+
checkpoint = torch.load(args.checkpoint)
53+
model.load_state_dict(checkpoint['model_state_dict'])
54+
55+
model.eval()
56+
57+
losses = []
58+
correct = 0
59+
total = 0
60+
61+
inv_transform = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
62+
std = [ 1/0.229, 1/0.224, 1/0.225 ]),
63+
transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
64+
std = [ 1., 1., 1. ]),
65+
])
66+
67+
for i, ((img1, img2), y, (class1, class2)) in enumerate(val_dataloader):
68+
print("[{} / {}]".format(i, len(val_dataloader)))
69+
70+
img1, img2, y = map(lambda x: x.to(device), [img1, img2, y])
71+
class1 = class1[0]
72+
class2 = class2[0]
73+
74+
prob = model(img1, img2)
75+
loss = criterion(prob, y)
76+
77+
losses.append(loss.item())
78+
correct += torch.count_nonzero(y == (prob > 0.5)).item()
79+
total += len(y)
80+
81+
fig = plt.figure("class1={}\tclass2={}".format(class1, class2), figsize=(4, 2))
82+
plt.suptitle("cls1={} conf={:.2f} cls2={}".format(class1, prob[0][0].item(), class2))
83+
84+
img1 = inv_transform(img1).cpu().numpy()[0]
85+
img2 = inv_transform(img2).cpu().numpy()[0]
86+
# show first image
87+
ax = fig.add_subplot(1, 2, 1)
88+
plt.imshow(img1[0], cmap=plt.cm.gray)
89+
plt.axis("off")
90+
91+
# show the second image
92+
ax = fig.add_subplot(1, 2, 2)
93+
plt.imshow(img2[0], cmap=plt.cm.gray)
94+
plt.axis("off")
95+
96+
# show the plot
97+
plt.savefig(os.path.join(args.out_path, '{}.png').format(i))
98+
9299
print("Validation: Loss={:.2f}\t Accuracy={:.2f}\t".format(sum(losses)/len(losses), correct / total))

libs/dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def create_pairs(self):
5656
if self.shuffle_pairs:
5757
np.random.seed(int(time.time()))
5858
np.random.shuffle(self.indices1)
59+
elif self.testing:
60+
np.random.seed(int(time.time()))
5961
else:
6062
np.random.seed(1)
6163

@@ -92,12 +94,12 @@ def __iter__(self):
9294
image1 = self.transform(image1).float()
9395
image2 = self.transform(image2).float()
9496

95-
plt.imshow(image1[0])
96-
plt.imshow(image2[0])
97-
plt.show()
97+
# plt.imshow(image1[0])
98+
# plt.imshow(image2[0])
99+
# plt.show()
98100

99101
if self.testing:
100-
yield (image1, image2), torch.FloatTensor([class1==class2]), (image_path1, image_path2)
102+
yield (image1, image2), torch.FloatTensor([class1==class2]), (class1, class2)
101103
else:
102104
yield (image1, image2), torch.FloatTensor([class1==class2])
103105

siamese/siamese_network.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@ def __init__(self, backbone="resnet18"):
1515
out_features = list(self.backbone.modules())[-1].out_features
1616

1717
self.cls_head = nn.Sequential(
18-
# nn.Dropout(p=0.5),
19-
nn.Linear(out_features * 2, 512),
18+
nn.Dropout(p=0.5),
19+
nn.Linear(out_features, 512),
2020
nn.BatchNorm1d(512),
21-
nn.Sigmoid(),
21+
nn.ReLU(),
2222

23-
# nn.Dropout(p=0.5),
23+
nn.Dropout(p=0.5),
2424
nn.Linear(512, 64),
2525
nn.BatchNorm1d(64),
2626
nn.Sigmoid(),
27+
nn.Dropout(),
2728

2829
nn.Linear(64, 1),
2930
nn.Sigmoid(),
@@ -33,7 +34,7 @@ def forward(self, img1, img2):
3334
feat1 = self.backbone(img1)
3435
feat2 = self.backbone(img2)
3536

36-
combined_features = torch.cat((feat1, feat2), dim=-1)
37+
combined_features = feat1 * feat2
3738

3839
output = self.cls_head(combined_features)
3940
return output

0 commit comments

Comments
 (0)