1
- import os
2
- import argparse
3
-
4
- import cv2
5
- import numpy as np
6
-
7
- import torch
8
- from torch .utils .data import DataLoader
9
-
10
- from siamese import SiameseNetwork
11
- from libs .dataset import Dataset
12
-
13
- if __name__ == "__main__" :
14
- parser = argparse .ArgumentParser ()
15
-
16
- parser .add_argument (
17
- '--val_path' ,
18
- type = str ,
19
- help = "Path to directory containing validation dataset." ,
20
- default = "../dataset/test"
21
- )
22
- parser .add_argument (
23
- '-o' ,
24
- '--out_path' ,
25
- type = str ,
26
- help = "Path for outputting model weights and tensorboard summary." ,
27
- default = "output/images"
28
- )
29
- parser .add_argument (
30
- '-c' ,
31
- '--checkpoint' ,
32
- type = str ,
33
- help = "Path to model to be used for inference." ,
34
- default = "output/epoch_200.pth"
35
- )
36
-
37
- args = parser .parse_args ()
38
-
39
- os .makedirs (args .out_path , exist_ok = True )
40
-
41
- device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
42
-
43
- val_dataset = Dataset (args .val_path , shuffle_pairs = False , augment = False , testing = True )
44
- val_dataloader = DataLoader (val_dataset , batch_size = 1 )
45
-
46
- model = SiameseNetwork ()
47
- model .to (device )
48
-
49
- checkpoint = torch .load (args .checkpoint )
50
- model .load_state_dict (checkpoint ['model_state_dict' ])
51
-
52
- model .eval ()
53
-
54
- losses = []
55
- correct = 0
56
- total = 0
57
-
58
- inv_transform = transforms .Compose ([ transforms .Normalize (mean = [ 0. , 0. , 0. ],
59
- std = [ 1 / 0.229 , 1 / 0.224 , 1 / 0.225 ]),
60
- transforms .Normalize (mean = [ - 0.485 , - 0.456 , - 0.406 ],
61
- std = [ 1. , 1. , 1. ]),
62
- ])
63
-
64
- for i , ((img1 , img2 ), y , (class1 , class2 )) in enumerate (val_dataloader ):
65
- print ("[{} / {}]" .format (i , len (val_dataloader )))
66
-
67
- img1 , img2 , y = map (lambda x : x .to (device ), [img1 , img2 , y ])
68
-
69
- prob = model (img1 , img2 )
70
- loss = criterion (prob , y )
71
-
72
- losses .append (loss .item ())
73
- correct += torch .count_nonzero (y == (prob > 0.5 )).item ()
74
- total += len (y )
75
-
76
- fig = plt .figure ("class1={}\t class2={}" .format (class1 , class2 ), figsize = (4 , 2 ))
77
- plt .suptitle ("cls1={} conf={:.2f} cls2={}" .format (class1 , prob [0 ], class2 ))
78
-
79
- # show first image
80
- ax = fig .add_subplot (1 , 2 , 1 )
81
- plt .imshow (inv_transform (img1 [0 ]), cmap = plt .cm .gray )
82
- plt .axis ("off" )
83
-
84
- # show the second image
85
- ax = fig .add_subplot (1 , 2 , 2 )
86
- plt .imshow (inv_transform (img2 [0 ]), cmap = plt .cm .gray )
87
- plt .axis ("off" )
88
-
89
- # show the plot
90
- plt .savefig (os .path .join (args .checkpoint , 'images/{}.png' ).format (i ))
91
-
1
+ import os
2
+ import argparse
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import matplotlib .pyplot as plt
7
+
8
+ import torch
9
+ from torch .utils .data import DataLoader
10
+ from torchvision import transforms
11
+
12
+ from siamese import SiameseNetwork
13
+ from libs .dataset import Dataset
14
+
15
+ if __name__ == "__main__" :
16
+ parser = argparse .ArgumentParser ()
17
+
18
+ parser .add_argument (
19
+ '--val_path' ,
20
+ type = str ,
21
+ help = "Path to directory containing validation dataset." ,
22
+ default = "../dataset/test"
23
+ )
24
+ parser .add_argument (
25
+ '-o' ,
26
+ '--out_path' ,
27
+ type = str ,
28
+ help = "Path for outputting model weights and tensorboard summary." ,
29
+ default = "output/images"
30
+ )
31
+ parser .add_argument (
32
+ '-c' ,
33
+ '--checkpoint' ,
34
+ type = str ,
35
+ help = "Path to model to be used for inference." ,
36
+ default = "output/epoch_200.pth"
37
+ )
38
+
39
+ args = parser .parse_args ()
40
+
41
+ os .makedirs (args .out_path , exist_ok = True )
42
+
43
+ device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
44
+
45
+ val_dataset = Dataset (args .val_path , shuffle_pairs = False , augment = False , testing = True )
46
+ val_dataloader = DataLoader (val_dataset , batch_size = 1 )
47
+
48
+ model = SiameseNetwork ()
49
+ model .to (device )
50
+ criterion = torch .nn .BCELoss ()
51
+
52
+ checkpoint = torch .load (args .checkpoint )
53
+ model .load_state_dict (checkpoint ['model_state_dict' ])
54
+
55
+ model .eval ()
56
+
57
+ losses = []
58
+ correct = 0
59
+ total = 0
60
+
61
+ inv_transform = transforms .Compose ([ transforms .Normalize (mean = [ 0. , 0. , 0. ],
62
+ std = [ 1 / 0.229 , 1 / 0.224 , 1 / 0.225 ]),
63
+ transforms .Normalize (mean = [ - 0.485 , - 0.456 , - 0.406 ],
64
+ std = [ 1. , 1. , 1. ]),
65
+ ])
66
+
67
+ for i , ((img1 , img2 ), y , (class1 , class2 )) in enumerate (val_dataloader ):
68
+ print ("[{} / {}]" .format (i , len (val_dataloader )))
69
+
70
+ img1 , img2 , y = map (lambda x : x .to (device ), [img1 , img2 , y ])
71
+ class1 = class1 [0 ]
72
+ class2 = class2 [0 ]
73
+
74
+ prob = model (img1 , img2 )
75
+ loss = criterion (prob , y )
76
+
77
+ losses .append (loss .item ())
78
+ correct += torch .count_nonzero (y == (prob > 0.5 )).item ()
79
+ total += len (y )
80
+
81
+ fig = plt .figure ("class1={}\t class2={}" .format (class1 , class2 ), figsize = (4 , 2 ))
82
+ plt .suptitle ("cls1={} conf={:.2f} cls2={}" .format (class1 , prob [0 ][0 ].item (), class2 ))
83
+
84
+ img1 = inv_transform (img1 ).cpu ().numpy ()[0 ]
85
+ img2 = inv_transform (img2 ).cpu ().numpy ()[0 ]
86
+ # show first image
87
+ ax = fig .add_subplot (1 , 2 , 1 )
88
+ plt .imshow (img1 [0 ], cmap = plt .cm .gray )
89
+ plt .axis ("off" )
90
+
91
+ # show the second image
92
+ ax = fig .add_subplot (1 , 2 , 2 )
93
+ plt .imshow (img2 [0 ], cmap = plt .cm .gray )
94
+ plt .axis ("off" )
95
+
96
+ # show the plot
97
+ plt .savefig (os .path .join (args .out_path , '{}.png' ).format (i ))
98
+
92
99
print ("Validation: Loss={:.2f}\t Accuracy={:.2f}\t " .format (sum (losses )/ len (losses ), correct / total ))
0 commit comments