Commit 7570c7e

Allow changing the backbone network through command-line arguments. Add saving of the model with the best validation loss.
1 parent d4c7cb4 commit 7570c7e

3 files changed: +110 -14 lines changed

README.md

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# Siamese Network

A simple but pragmatic implementation of Siamese Networks in PyTorch using the pre-trained feature extraction networks provided in ```torchvision.models```.

## Design Choices:
- The siamese network provided in this repository uses a sigmoid at its output, thus making it a binary classification task (positive=same, negative=different) with binary cross entropy loss, as opposed to the triplet loss generally used.
- I have added dropout to the final classification head network along with BatchNorm. Discussions on online forums suggest that combining dropout with BatchNorm is ineffective; however, I found it to improve the results on my specific private dataset.
- Instead of concatenating the feature vectors of the two images, I opted to multiply them element-wise, which increased the validation accuracy for my specific dataset.


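The design choices above can be illustrated with a minimal sketch. It is not the repository's `SiameseNetwork`: the class name `TinySiamese`, the layer sizes, and the small linear encoder standing in for a torchvision backbone are all assumptions made to keep the example short and free of weight downloads.

```python
import torch
import torch.nn as nn


class TinySiamese(nn.Module):
    """Illustrative Siamese classifier; the encoder is a stand-in, not a torchvision backbone."""

    def __init__(self, in_features: int = 32, embed_dim: int = 16):
        super().__init__()
        # One encoder reused for both inputs, so the two branches share weights.
        self.encoder = nn.Sequential(nn.Linear(in_features, embed_dim), nn.ReLU())
        # Classification head with Dropout and BatchNorm together, ending in a sigmoid.
        self.head = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(embed_dim, 8),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            nn.Linear(8, 1),
            nn.Sigmoid(),
        )

    def forward(self, x1, x2):
        f1 = self.encoder(x1)
        f2 = self.encoder(x2)
        combined = f1 * f2  # element-wise product instead of concatenation
        return self.head(combined)


torch.manual_seed(0)
model = TinySiamese()
x1, x2 = torch.randn(4, 32), torch.randn(4, 32)
y = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # 1 = same, 0 = different

prob = model(x1, x2)          # shape (4, 1), values in [0, 1]
loss = nn.BCELoss()(prob, y)  # binary cross entropy on the sigmoid output
```

Because the two feature vectors are multiplied rather than concatenated, the head's input width equals the embedding width, and swapping the two inputs gives the same combined vector.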
## Setting up the dataset.
The expected format for both the training and validation dataset is the same. Images belonging to a single entity/class should be placed in a folder with the name of the class. The folders for every class are then to be placed within a common root directory (which will be passed to the training and evaluation scripts). The folder structure is also explained below:
```
|--Train or Validation dataset root directory
  |--Class1
    |-Image1
    |-Image2
    .
    .
    .
    |-ImageN
  |--Class2
  |--Class3
  .
  .
  .
  |--ClassN
```

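As a rough illustration of the layout described above, the sketch below indexes a dataset root into a mapping from class name to image files. The helper name `index_dataset` and the placeholder file names are hypothetical, and this is not the repository's `Dataset` class.

```python
from pathlib import Path
import tempfile


def index_dataset(root):
    """Map each class folder name to the sorted list of files inside it."""
    root = Path(root)
    index = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        images = sorted(f for f in class_dir.iterdir() if f.is_file())
        if images:  # skip empty class folders
            index[class_dir.name] = images
    return index


# Build a throwaway dataset with empty placeholder files to show the expected shape.
with tempfile.TemporaryDirectory() as tmp:
    for cls, count in [("Class1", 3), ("Class2", 2)]:
        (Path(tmp) / cls).mkdir()
        for i in range(count):
            (Path(tmp) / cls / f"Image{i + 1}.jpg").touch()
    index = index_dataset(tmp)
    summary = {cls: [f.name for f in files] for cls, files in index.items()}

print(summary)
```

The same function would apply to both the training and the validation root, since the two share one format.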
## Training the model:
To train the model, run the following command along with the required command line arguments:
```shell
python train.py [-h] --train_path TRAIN_PATH --val_path VAL_PATH -o OUT_PATH
                [-b BACKBONE] [-lr LEARNING_RATE] [-e EPOCHS] [-s SAVE_AFTER]

optional arguments:
  -h, --help            show this help message and exit
  --train_path TRAIN_PATH
                        Path to directory containing training dataset.
  --val_path VAL_PATH   Path to directory containing validation dataset.
  -o OUT_PATH, --out_path OUT_PATH
                        Path for outputting model weights and tensorboard
                        summary.
  -b BACKBONE, --backbone BACKBONE
                        Network backbone from torchvision.models to be used in
                        the siamese network.
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        Learning Rate
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs to train
  -s SAVE_AFTER, --save_after SAVE_AFTER
                        Model checkpoint is saved after each specified number
                        of epochs.
```
The backbone can be chosen from any of the networks listed in [torchvision.models](https://pytorch.org/vision/stable/models.html).

## Evaluating the model:
The following command can be used to evaluate the model on a validation set. Output images containing each pair and its corresponding similarity confidence are written to `{OUT_PATH}`.

Note: During evaluation the pairs are generated with a deterministic seed for the numpy random module, so as to allow comparisons between multiple evaluations.

```shell
python eval.py [-h] -v VAL_PATH -o OUT_PATH -c CHECKPOINT

optional arguments:
  -h, --help            show this help message and exit
  -v VAL_PATH, --val_path VAL_PATH
                        Path to directory containing validation dataset.
  -o OUT_PATH, --out_path OUT_PATH
                        Path for saving prediction images.
  -c CHECKPOINT, --checkpoint CHECKPOINT
                        Path of model checkpoint to be used for inference.
```
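
The note on deterministic pair generation can be made concrete with a small sketch. It assumes a hypothetical `make_pairs` helper and uses the standard library's `random.Random` as a stand-in for the seeded numpy generator mentioned above. The repository's actual pairing logic is not shown in this commit and may differ.

```python
import random


def make_pairs(index, seed=0):
    """Return (image_a, image_b, label) triples; label 1 = same class, 0 = different."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    classes = sorted(index)
    pairs = []
    for cls in classes:
        for image in index[cls]:
            # Roughly half the pairs are positives, the rest negatives.
            if rng.random() < 0.5:
                other_cls = cls
            else:
                other_cls = rng.choice([c for c in classes if c != cls])
            partner = rng.choice(index[other_cls])
            pairs.append((image, partner, int(other_cls == cls)))
    return pairs


# Hypothetical class-to-image mapping; at least two classes are required.
index = {"Class1": ["a1", "a2", "a3"], "Class2": ["b1", "b2"], "Class3": ["c1"]}
first = make_pairs(index, seed=42)
second = make_pairs(index, seed=42)
```

With the same seed, `first` and `second` are identical, so two evaluation runs score exactly the same pairs and their results can be compared directly. In this simplified version a positive pair may match an image with itself.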

eval.py

Lines changed: 8 additions & 8 deletions
@@ -16,24 +16,25 @@
 parser = argparse.ArgumentParser()
 
 parser.add_argument(
+    '-v',
     '--val_path',
     type=str,
     help="Path to directory containing validation dataset.",
-    default="../dataset/test"
+    required=True
 )
 parser.add_argument(
     '-o',
     '--out_path',
     type=str,
-    help="Path for outputting model weights and tensorboard summary.",
-    default="output/images"
+    help="Path for saving prediction images.",
+    required=True
 )
 parser.add_argument(
     '-c',
     '--checkpoint',
     type=str,
-    help="Path to model to be used for inference.",
-    default="output/epoch_200.pth"
+    help="Path of model checkpoint to be used for inference.",
+    required=True
 )
 
 args = parser.parse_args()
@@ -45,13 +46,12 @@
 val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False, testing=True)
 val_dataloader = DataLoader(val_dataset, batch_size=1)
 
-model = SiameseNetwork()
-model.to(device)
 criterion = torch.nn.BCELoss()
 
 checkpoint = torch.load(args.checkpoint)
+model = SiameseNetwork(backbone=checkpoint['backbone'])
+model.to(device)
 model.load_state_dict(checkpoint['model_state_dict'])
-
 model.eval()
 
 losses = []
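
One consequence of reading `checkpoint['backbone']` is that checkpoints written before this commit, which carry no `backbone` entry, would raise a `KeyError` in `eval.py`. The sketch below shows one way such a lookup could be made explicit. The helper name `backbone_from_checkpoint` and its `fallback` parameter are hypothetical and not part of the commit, and plain dictionaries stand in for the objects returned by `torch.load`.

```python
def backbone_from_checkpoint(checkpoint, fallback=None):
    """Return the backbone name stored in a checkpoint dict.

    If the key is missing (for example, in a checkpoint saved before the
    'backbone' entry existed), use the caller-supplied fallback or fail
    with a descriptive message instead of guessing an architecture.
    """
    if "backbone" in checkpoint:
        return checkpoint["backbone"]
    if fallback is not None:
        return fallback
    raise KeyError(
        "checkpoint has no 'backbone' entry; pass fallback=<name> explicitly"
    )


# Plain dicts standing in for loaded checkpoints.
new_style = {"epoch": 10, "backbone": "resnet50", "model_state_dict": {}}
old_style = {"epoch": 200, "model_state_dict": {}}

name_new = backbone_from_checkpoint(new_style)
name_old = backbone_from_checkpoint(old_style, fallback="resnet18")
```

The fallback is left to the caller on purpose: the diff does not show which architecture `SiameseNetwork()` used before this change, so any built-in default would be a guess.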

train.py

Lines changed: 29 additions & 6 deletions
@@ -19,20 +19,27 @@
     '--train_path',
     type=str,
     help="Path to directory containing training dataset.",
-    default="../dataset/train"
+    required=True
 )
 parser.add_argument(
     '--val_path',
     type=str,
     help="Path to directory containing validation dataset.",
-    default="../dataset/test"
+    required=True
 )
 parser.add_argument(
     '-o',
     '--out_path',
     type=str,
     help="Path for outputting model weights and tensorboard summary.",
-    default="output"
+    required=True
+)
+parser.add_argument(
+    '-b',
+    '--backbone',
+    type=str,
+    help="Network backbone from torchvision.models to be used in the siamese network.",
+    default="resnet18"
 )
 parser.add_argument(
     '-lr',
@@ -68,14 +75,16 @@
 train_dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True)
 val_dataloader = DataLoader(val_dataset, batch_size=8)
 
-model = SiameseNetwork()
+model = SiameseNetwork(backbone=args.backbone)
 model.to(device)
 
 optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 criterion = torch.nn.BCELoss()
 
 writer = SummaryWriter(os.path.join(args.out_path, "summary"))
 
+best_val = 10000000000
+
 for epoch in range(args.epochs):
     print("[{} / {}]".format(epoch, args.epochs))
     model.train()
@@ -120,16 +129,30 @@
         correct += torch.count_nonzero(y == (prob > 0.5)).item()
         total += len(y)
 
-    writer.add_scalar('val_loss', sum(losses)/len(losses), epoch)
+    val_loss = sum(losses)/max(1, len(losses))
+    writer.add_scalar('val_loss', val_loss, epoch)
     writer.add_scalar('val_acc', correct / total, epoch)
 
-    print("\tValidation: Loss={:.2f}\t Accuracy={:.2f}\t".format(sum(losses)/len(losses), correct / total))
+    print("\tValidation: Loss={:.2f}\t Accuracy={:.2f}\t".format(val_loss, correct / total))
+
+    if val_loss < best_val:
+        best_val = val_loss
+        torch.save(
+            {
+                "epoch": epoch + 1,
+                "model_state_dict": model.state_dict(),
+                "backbone": args.backbone,
+                "optimizer_state_dict": optimizer.state_dict()
+            },
+            os.path.join(args.out_path, "best.pth")
+        )
 
     if (epoch + 1) % args.save_after == 0:
         torch.save(
             {
                 "epoch": epoch + 1,
                 "model_state_dict": model.state_dict(),
+                "backbone": args.backbone,
                 "optimizer_state_dict": optimizer.state_dict()
             },
             os.path.join(args.out_path, "epoch_{}.pth".format(epoch + 1))
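
The best-validation-loss logic added in this diff can be isolated into a small, framework-free sketch. The class name `BestValueTracker` and the sample loss values are made up for illustration. `math.inf` is used in place of the large constant `10000000000`, which is a substitution and not what the commit does.

```python
import math


class BestValueTracker:
    """Remember the lowest value seen so far and the epoch it occurred in."""

    def __init__(self):
        self.best = math.inf
        self.best_epoch = None

    def update(self, value, epoch):
        """Return True when `value` is a new best (i.e. a checkpoint should be saved)."""
        if value < self.best:
            self.best = value
            self.best_epoch = epoch
            return True
        return False


# Made-up validation losses for five epochs.
val_losses = [0.91, 0.62, 0.70, 0.55, 0.58]

tracker = BestValueTracker()
saved_epochs = []
for epoch, val_loss in enumerate(val_losses):
    if tracker.update(val_loss, epoch + 1):
        # In train.py this is where best.pth would be overwritten.
        saved_epochs.append(epoch + 1)

print(saved_epochs, tracker.best, tracker.best_epoch)
```

Since any comparison with NaN is false, a NaN validation loss never replaces the stored best value under this scheme.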
