Skip to content

Commit 9339723

Browse files
committed
add docstring and comments. remove redundant flag from dataset.
1 parent b1fea92 commit 9339723

File tree

5 files changed

+67
-36
lines changed

5 files changed

+67
-36
lines changed

eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141

4242
os.makedirs(args.out_path, exist_ok=True)
4343

44+
# Set device to CUDA if a CUDA device is available, else CPU
4445
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
4546

46-
val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False, testing=True)
47+
val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False)
4748
val_dataloader = DataLoader(val_dataset, batch_size=1)
4849

4950
criterion = torch.nn.BCELoss()
@@ -81,6 +82,7 @@
8182
fig = plt.figure("class1={}\tclass2={}".format(class1, class2), figsize=(4, 2))
8283
plt.suptitle("cls1={} conf={:.2f} cls2={}".format(class1, prob[0][0].item(), class2))
8384

85+
# Apply inverse transform (denormalization) on the images to retrieve original images.
8486
img1 = inv_transform(img1).cpu().numpy()[0]
8587
img2 = inv_transform(img2).cpu().numpy()[0]
8688
# show first image

libs/dataset.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,30 @@
1212
from torchvision import transforms
1313

1414
class Dataset(torch.utils.data.IterableDataset):
15-
def __init__(self, path, shuffle_pairs=True, augment=False, testing=False):
15+
def __init__(self, path, shuffle_pairs=True, augment=False):
16+
'''
17+
Create an iterable dataset from a directory containing sub-directories of
18+
entities with their images contained inside each sub-directory.
19+
20+
Parameters:
21+
path (str): Path to directory containing the dataset.
22+
shuffle_pairs (boolean): Pass True when training, False otherwise. When set to False, the image pair generation will be deterministic.
23+
augment (boolean): When True, images will be augmented using a standard set of transformations.
24+
25+
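Iterating over the dataset produces one image pair at a time (no batch dimension).
26+
27+
Yields:
28+
((image1, image2), label, (class1, class2)): the two transformed images, a torch.FloatTensor of shape [1] that is 1.0 when both images belong to the same class and 0.0 otherwise, and the classes of the two images.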
29+
'''
1630
self.path = path
1731

1832
self.feed_shape = [3, 224, 224]
1933
self.shuffle_pairs = shuffle_pairs
20-
self.testing = testing
2134

2235
self.augment = augment
2336

2437
if self.augment:
38+
# If images are to be augmented, add extra operations for it (first two).
2539
self.transform = transforms.Compose([
2640
transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=0.2),
2741
transforms.RandomHorizontalFlip(p=0.5),
@@ -30,6 +44,7 @@ def __init__(self, path, shuffle_pairs=True, augment=False, testing=False):
3044
transforms.Resize(self.feed_shape[1:])
3145
])
3246
else:
47+
# If no augmentation is needed then apply only the normalization and resizing operations.
3348
self.transform = transforms.Compose([
3449
transforms.ToTensor(),
3550
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@@ -39,6 +54,10 @@ def __init__(self, path, shuffle_pairs=True, augment=False, testing=False):
3954
self.create_pairs()
4055

4156
def create_pairs(self):
57+
'''
58+
Creates two lists of indices that will form the pairs, to be fed for training or evaluation.
59+
'''
60+
4261
self.image_paths = glob.glob(os.path.join(self.path, "*/*.png"))
4362
self.image_classes = []
4463
self.class_indices = {}
@@ -56,9 +75,8 @@ def create_pairs(self):
5675
if self.shuffle_pairs:
5776
np.random.seed(int(time.time()))
5877
np.random.shuffle(self.indices1)
59-
# elif self.testing:
60-
# np.random.seed(int(time.time()))
6178
else:
79+
# If shuffling is set to off, set the random seed to 1, to make it deterministic.
6280
np.random.seed(1)
6381

6482
select_pos_pair = np.random.rand(len(self.image_paths)) < 0.5
@@ -79,7 +97,6 @@ def __iter__(self):
7997
self.create_pairs()
8098

8199
for idx, idx2 in zip(self.indices1, self.indices2):
82-
# idx2 = self.indices_pairs[idx]
83100

84101
image_path1 = self.image_paths[idx]
85102
image_path2 = self.image_paths[idx2]
@@ -94,14 +111,7 @@ def __iter__(self):
94111
image1 = self.transform(image1).float()
95112
image2 = self.transform(image2).float()
96113

97-
# plt.imshow(image1[0])
98-
# plt.imshow(image2[0])
99-
# plt.show()
100-
101-
if self.testing:
102-
yield (image1, image2), torch.FloatTensor([class1==class2]), (class1, class2)
103-
else:
104-
yield (image1, image2), torch.FloatTensor([class1==class2])
114+
yield (image1, image2), torch.FloatTensor([class1==class2]), (class1, class2)
105115

106116
def __len__(self):
107117
return len(self.image_paths)

libs/plot_training.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

siamese/siamese_network.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,26 @@
66

77
class SiameseNetwork(nn.Module):
88
def __init__(self, backbone="resnet18"):
9+
'''
10+
Creates a siamese network with a network from torchvision.models as backbone.
11+
12+
Parameters:
13+
backbone (str): Options of the backbone networks can be found at https://pytorch.org/vision/stable/models.html
14+
'''
15+
916
super().__init__()
1017

1118
if backbone not in models.__dict__:
1219
raise Exception("No model named {} exists in torchvision.models.".format(backbone))
1320

21+
# Create a backbone network from the pretrained models provided in torchvision.models
1422
self.backbone = models.__dict__[backbone](pretrained=True, progress=True)
23+
24+
# Get the number of features output by the last layer of the backbone network.
1525
out_features = list(self.backbone.modules())[-1].out_features
1626

27+
# Create an MLP (multi-layer perceptron) as the classification head.
28+
# Classifies whether the provided combined feature vector of the 2 images represents the same player or different players.
1729
self.cls_head = nn.Sequential(
1830
nn.Dropout(p=0.5),
1931
nn.Linear(out_features, 512),
@@ -24,17 +36,34 @@ def __init__(self, backbone="resnet18"):
2436
nn.Linear(512, 64),
2537
nn.BatchNorm1d(64),
2638
nn.Sigmoid(),
27-
nn.Dropout(),
39+
nn.Dropout(p=0.5),
2840

2941
nn.Linear(64, 1),
3042
nn.Sigmoid(),
3143
)
3244

3345
def forward(self, img1, img2):
46+
'''
47+
Returns the similarity value between two images.
48+
49+
Parameters:
50+
img1 (torch.Tensor): shape=[b, 3, 224, 224]
51+
img2 (torch.Tensor): shape=[b, 3, 224, 224]
52+
53+
where b = batch size
54+
55+
Returns:
56+
output (torch.Tensor): shape=[b, 1], Similarity of each pair of images
57+
'''
58+
59+
# Pass both images through the backbone network to get their separate feature vectors.
3460
feat1 = self.backbone(img1)
3561
feat2 = self.backbone(img2)
3662

63+
# Multiply (element-wise) the feature vectors of the two images together,
64+
# to generate a combined feature vector representing the similarity between the two.
3765
combined_features = feat1 * feat2
3866

67+
# Pass the combined feature vector through classification head to get similarity value in the range of 0 to 1.
3968
output = self.cls_head(combined_features)
4069
return output

train.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@
6767

6868
os.makedirs(args.out_path, exist_ok=True)
6969

70+
# Set device to CUDA if a CUDA device is available, else CPU
7071
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7172

72-
train_dataset = Dataset(args.train_path, shuffle_pairs=True, augment=True, testing=False)
73-
val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False, testing=False)
73+
train_dataset = Dataset(args.train_path, shuffle_pairs=True, augment=True)
74+
val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False)
7475

7576
train_dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True)
7677
val_dataloader = DataLoader(val_dataset, batch_size=8)
@@ -93,8 +94,8 @@
9394
correct = 0
9495
total = 0
9596

96-
# pbar = tqdm()
97-
for (img1, img2), y in train_dataloader:
97+
# Training Loop Start
98+
for (img1, img2), y, (class1, class2) in train_dataloader:
9899
img1, img2, y = map(lambda x: x.to(device), [img1, img2, y])
99100

100101
prob = model(img1, img2)
@@ -112,14 +113,16 @@
112113
writer.add_scalar('train_acc', correct / total, epoch)
113114

114115
print("\tTraining: Loss={:.2f}\t Accuracy={:.2f}\t".format(sum(losses)/len(losses), correct / total))
116+
# Training Loop End
115117

118+
# Evaluation Loop Start
116119
model.eval()
117120

118121
losses = []
119122
correct = 0
120123
total = 0
121124

122-
for (img1, img2), y in val_dataloader:
125+
for (img1, img2), y, (class1, class2) in val_dataloader:
123126
img1, img2, y = map(lambda x: x.to(device), [img1, img2, y])
124127

125128
prob = model(img1, img2)
@@ -134,7 +137,9 @@
134137
writer.add_scalar('val_acc', correct / total, epoch)
135138

136139
print("\tValidation: Loss={:.2f}\t Accuracy={:.2f}\t".format(val_loss, correct / total))
140+
# Evaluation Loop End
137141

142+
# Update "best.pth" model if val_loss in current epoch is lower than the best validation loss
138143
if val_loss < best_val:
139144
best_val = val_loss
140145
torch.save(
@@ -147,6 +152,7 @@
147152
os.path.join(args.out_path, "best.pth")
148153
)
149154

155+
# Save model based on the frequency defined by "args.save_after"
150156
if (epoch + 1) % args.save_after == 0:
151157
torch.save(
152158
{

0 commit comments

Comments
 (0)