# Siamese Network

A simple but pragmatic implementation of Siamese Networks in PyTorch, using the pre-trained feature extraction networks provided in `torchvision.models`.

## Design Choices:
- The siamese network in this repository uses a sigmoid at its output, turning the problem into a binary classification task (positive = same, negative = different) trained with binary cross-entropy loss, as opposed to the triplet loss that is generally used.
- I have added dropout to the final classification head along with BatchNorm. Online discussions suggest that dropout combined with BatchNorm is ineffective; however, I found it to improve results on my specific private dataset.
- Instead of concatenating the feature vectors of the two images, I opted to multiply them element-wise, which increased the validation accuracy on my specific dataset. A sketch combining these choices is shown after this list.
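To make the above concrete, the classification head could look roughly like the minimal sketch below. This is an illustration only, not the exact code in this repository; the ResNet-18 backbone, the 512-dimensional feature size, the hidden-layer width, and the dropout probability are all assumptions.

```python
import torch.nn as nn
import torchvision.models as models


class SiameseNetwork(nn.Module):
    """Minimal sketch of a siamese network with a binary classification head.

    Assumptions (not taken from this repository): ResNet-18 backbone,
    512-dim features, one hidden layer, dropout p=0.5.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Keep everything up to (and including) the global average pool.
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # similarity score in [0, 1]
        )

    def forward(self, img1, img2):
        feat1 = self.feature_extractor(img1).flatten(start_dim=1)
        feat2 = self.feature_extractor(img2).flatten(start_dim=1)
        # Element-wise multiplication of the two feature vectors
        # instead of concatenation.
        return self.head(feat1 * feat2)


# Training uses binary cross-entropy on the sigmoid output, e.g.:
#   criterion = nn.BCELoss()
#   loss = criterion(model(img1, img2), labels)  # labels: 1 = same, 0 = different
```

With this formulation, each training pair only needs a binary same/different label rather than the explicit anchor/positive/negative triplets required by a triplet loss.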

## Setting up the dataset:
The expected format for the training and validation datasets is the same. Images belonging to a single entity/class should be placed in a folder named after that class. The folders for all classes are then placed inside a common root directory (which is passed to the training and evaluation scripts). The folder structure is illustrated below:
```
|--Train or Validation dataset root directory
    |--Class1
        |-Image1
        |-Image2
        .
        .
        .
        |-ImageN
    |--Class2
    |--Class3
    .
    .
    .
    |--ClassN
```
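For reference, image pairs can be drawn from this layout along the lines of the sketch below. This is only an illustration of the expected structure; the dataset class actually used by `train.py` and `eval.py` may sample pairs differently, and the class and parameter names here are assumptions.

```python
import os
import random

from PIL import Image
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Sketch of a dataset yielding (image1, image2, same_class) pairs
    from the root/ClassN/ImageN layout described above."""

    def __init__(self, root, transform=None):
        self.transform = transform
        self.paths_by_class = {
            cls: [
                os.path.join(root, cls, name)
                for name in os.listdir(os.path.join(root, cls))
            ]
            for cls in os.listdir(root)
            if os.path.isdir(os.path.join(root, cls))
        }
        self.classes = list(self.paths_by_class)
        self.samples = [
            (path, cls) for cls, paths in self.paths_by_class.items() for path in paths
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path1, cls1 = self.samples[index]
        # Draw a positive (same class) or negative (different class) partner
        # with equal probability.
        same = random.random() < 0.5
        cls2 = cls1 if same else random.choice([c for c in self.classes if c != cls1])
        path2 = random.choice(self.paths_by_class[cls2])
        img1 = Image.open(path1).convert("RGB")
        img2 = Image.open(path2).convert("RGB")
        if self.transform is not None:
            img1, img2 = self.transform(img1), self.transform(img2)
        return img1, img2, float(same)
```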

## Training the model:
To train the model, run the following command with the required command-line arguments:
```shell
python train.py [-h] --train_path TRAIN_PATH --val_path VAL_PATH -o OUT_PATH
                [-b BACKBONE] [-lr LEARNING_RATE] [-e EPOCHS] [-s SAVE_AFTER]

optional arguments:
  -h, --help            show this help message and exit
  --train_path TRAIN_PATH
                        Path to directory containing training dataset.
  --val_path VAL_PATH   Path to directory containing validation dataset.
  -o OUT_PATH, --out_path OUT_PATH
                        Path for outputting model weights and tensorboard
                        summary.
  -b BACKBONE, --backbone BACKBONE
                        Network backbone from torchvision.models to be used in
                        the siamese network.
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        Learning Rate
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs to train
  -s SAVE_AFTER, --save_after SAVE_AFTER
                        Model checkpoint is saved after each specified number
                        of epochs.
```
The backbone can be chosen from any of the networks listed in [torchvision.models](https://pytorch.org/vision/stable/models.html).
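For example, a run might look like the following (the paths and hyperparameter values are placeholders, not defaults shipped with the repository):

```shell
python train.py --train_path ./data/train --val_path ./data/val -o ./outputs \
    -b resnet18 -lr 1e-4 -e 50 -s 5
```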

## Evaluating the model:
The following command can be used to evaluate the model on a validation set. Output images containing each pair and its corresponding similarity confidence are written to `{OUT_PATH}`.

Note: During evaluation, the pairs are generated with a deterministic seed for the numpy random module, so that multiple evaluations can be compared against each other.

```shell
python eval.py [-h] -v VAL_PATH -o OUT_PATH -c CHECKPOINT

optional arguments:
  -h, --help            show this help message and exit
  -v VAL_PATH, --val_path VAL_PATH
                        Path to directory containing validation dataset.
  -o OUT_PATH, --out_path OUT_PATH
                        Path for saving prediction images.
  -c CHECKPOINT, --checkpoint CHECKPOINT
                        Path of model checkpoint to be used for inference.
```
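For example (again with placeholder paths; the checkpoint filename depends on how training was run):

```shell
python eval.py -v ./data/val -o ./eval_outputs -c ./outputs/checkpoint.pth
```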