After running this command, you will find the generated csv and tfrecords (`.record` or `.tfrecord`) files located in [data/raccoon_data](data/raccoon_data).
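A quick way to sanity-check the output (assuming you run this from the repo root; the exact file names depend on how you split your data):

```bash
ls data/raccoon_data/
# you should see the generated csv files plus the .record/.tfrecord files
```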
Et voilà, we have the tfrecord files generated, and we can use them in the next steps for training.
## Training object detection model with your custom dataset
To start training our model, we need to prepare a configuration file specifying the backbone model and all the required parameters for training and evaluation.
In this [tutorial](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/configuring_jobs.md) from the Object Detection API, you can find an explanation of all the required parameters.
Fortunately, they also provide many [example config files](https://github.com/tensorflow/models/tree/master/research/object_detection/configs/tf2) that we can use, modifying just a few parameters to match our requirements.
Here I will be using the config file of the SSD model with a MobileNetV2 backbone, as it is a small model that can fit in a small GPU memory.
So let's first download the model pretrained on the COCO dataset, provided in the [model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md), and use it as the initialization for our model.
This is called fine-tuning: we simply load the weights of a pretrained model and use them as a starting point for our training. This helps a lot, since we have a very small number of images.
You can read more about transfer learning methods [here](https://cs231n.github.io/transfer-learning/).
```bash
tar -xzvf ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz
```
Then you can download the original config file from [here](https://github.com/tensorflow/models/tree/master/research/object_detection/configs/tf2).
I downloaded [ssd_mobilenet_v2_320x320_coco17_tpu-8.config](https://github.com/tensorflow/models/blob/master/research/object_detection/configs/tf2/ssd_mobilenet_v2_320x320_coco17_tpu-8.config) and made the following changes:
* Changed `num_classes: 1` as we have only one class (raccoon), instead of the 90 classes in the COCO dataset.
* Changed `fine_tune_checkpoint_type: "classification"` to `fine_tune_checkpoint_type: "detection"` as we will be using the pre-trained detection model as initialization.
* Added the path of the pretrained model to the field `fine_tune_checkpoint:`; for example, using the MobileNetV2 model I added `fine_tune_checkpoint: "../models/ssd_mobilenet_v2_320x320_coco17_tpu-8/checkpoint/ckpt-0"`.
* Changed `batch_size: 512` to a number reasonable for my GPU memory. I have 4 GB of GPU memory, so I am using `batch_size: 16`.
* Added the maximum number of training iterations in `num_steps:`, and also used the same number in `total_steps:`.
* Adapted the learning rate to our model and batch size (originally they used higher learning rates because they had bigger batch sizes). These values need some testing and tuning, but I finally used this configuration:
```
cosine_decay_learning_rate {
  learning_rate_base: 0.03
  total_steps: 3000
  warmup_learning_rate: 0.005
  warmup_steps: 100
}
```
* The `label_map_path:` should point to your labelmap (here the raccoon labelmap): `label_map_path: "../models/raccoon_labelmap.pbtxt"`.
* You need to set the `tf_record_input_reader` under both `train_input_reader` and `eval_input_reader`. This should point to the tfrecords we generated (see the sketch after this list).
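To make these changes concrete, here is a rough sketch of how the modified parts of the pipeline config fit together. The record file names (`train.record`, `val.record`) and relative paths are assumptions; the actual modified file is included in this repo as [models/ssd_mobilenet_v2_raccoon.config](models/ssd_mobilenet_v2_raccoon.config).

```
model {
  ssd {
    num_classes: 1
    # ... rest of the model definition unchanged ...
  }
}
train_config {
  batch_size: 16
  num_steps: 3000
  fine_tune_checkpoint: "../models/ssd_mobilenet_v2_320x320_coco17_tpu-8/checkpoint/ckpt-0"
  fine_tune_checkpoint_type: "detection"
  # ... optimizer block with the learning rate shown above ...
}
train_input_reader {
  label_map_path: "../models/raccoon_labelmap.pbtxt"
  tf_record_input_reader {
    input_path: "../data/raccoon_data/train.record"  # assumed file name
  }
}
eval_input_reader {
  label_map_path: "../models/raccoon_labelmap.pbtxt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "../data/raccoon_data/val.record"  # assumed file name
  }
}
```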
You should also prepare the labelmap according to your data. For our raccoon dataset, the labelmap file contains:
```
item {
  id: 1
  name: 'raccoon'
}
```
The labelmap file and the modified configuration file are included in this repo for convenience.
You can find them in [models/raccoon_labelmap.pbtxt](models/raccoon_labelmap.pbtxt) and [models/ssd_mobilenet_v2_raccoon.config](models/ssd_mobilenet_v2_raccoon.config).
Once you prepare the configuration file, you can start training by typing the following commands:
```bash
cd train_tf2
bash start_train.sh
```
The [start_train.sh](train_tf2/start_train.sh) file is a simple shell script that contains all the parameters needed for training, and runs the training script.
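For reference, a training launcher for the TF2 Object Detection API usually looks roughly like the sketch below. The paths and values here are assumptions for illustration only; see the actual [start_train.sh](train_tf2/start_train.sh) in this repo for the real parameters.

```bash
#!/bin/bash
# Sketch of a TF2 Object Detection API training launcher.
# Paths and values below are assumptions; adjust them to your setup.
PIPELINE_CONFIG_PATH="../models/ssd_mobilenet_v2_raccoon.config"  # the modified config from above
MODEL_DIR="../models/ssd_mobilenet_v2_raccoon"                    # where checkpoints and logs are written

# model_main_tf2.py ships with the Object Detection API
# (models/research/object_detection/model_main_tf2.py); adjust the path to your installation.
python model_main_tf2.py \
  --pipeline_config_path="${PIPELINE_CONFIG_PATH}" \
  --model_dir="${MODEL_DIR}" \
  --checkpoint_every_n=500 \
  --alsologtostderr
```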
It is also recommended to run the validation script along with the training script.
The training script saves a checkpoint every _n_ steps while training, and this value can be specified in the parameter `--checkpoint_every_n`.
While training is running, the validation script reads these checkpoints once they are available and uses them to evaluate the model on the validation set.
This will help us monitor the training progress, either by printing the values to the terminal or by using a GUI monitoring package like [tensorboard](https://www.tensorflow.org/tensorboard/get_started), as we will see.
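As a quick preview, tensorboard just needs to be pointed at the training output directory (the `--model_dir` used for training; the path below is an assumption):

```bash
tensorboard --logdir=../models/ssd_mobilenet_v2_raccoon
```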
To run the validation script along with the training script, open another terminal and run:
```bash
bash start_eval.sh
```
The [start_eval.sh](train_tf2/start_eval.sh) file contains: