Cat-recognition-train
This repository demonstrates how to train a cat vs. dog recognition model with TensorFlow and export it to an optimized frozen graph that is easy to deploy. If you want to know how to deploy a Flask app that recognizes cats and dogs using TensorFlow, please visit cat-recognition-app.
Requirements
- Python3 (Tested on 3.6.8)
- TensorFlow (Tested on 1.12.0)
- NumPy (Tested on 1.15.1)
- tqdm (Tested on 4.29.1)
- Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats
- (Optional if you want to run tests) PyTorch (Tested on 1.0.0 and 1.0.1)
Build environment
We recommend using Anaconda3 / Miniconda3 to manage your Python environment.
If your machine does not have a GPU, you can simply run:
$ pip install -r requirements.txt
or
$ conda install --file requirements.txt
However, if you want to use a GPU to accelerate training, please visit TensorFlow - GPU support for more information.
Train a Convolutional Neural Network
In this part, we use TensorFlow to train a CNN that distinguishes cat images from dog images, using the Kaggle Dogs vs. Cats dataset. We will do the following:
- Create the training/validation sets (dataset.py)
- Load, augment, resize and normalize the images using the tensorflow.data.Dataset API (dataset.py)
- Define a CNN model (net.py)
  - Here we use the ShuffleNetV2 architecture, which offers a good balance between speed and accuracy.
  - We do transfer learning on ShuffleNetV2 using the pretrained weights from https://github.com/ericsun99/Shufflenet-v2-Pytorch.
  - If you want to know how to load PyTorch weights into a TensorFlow model graph, please check convert_pytorch_weight_test, starting from line 44 in module_tests.py.
- Train the CNN model (train.py)
- Serialize the model for deployment (train.py)
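Loading the pretrained PyTorch weights mostly comes down to tensor layouts, since ShuffleNetV2 is built from ordinary and depthwise convolutions. The sketch below is not the repository's convert_pytorch_weight_test; it is a minimal illustration, assuming NHWC inputs and TensorFlow's standard tf.nn.conv2d / tf.nn.depthwise_conv2d filter layouts, of how PyTorch weight arrays would be transposed. The helper names are hypothetical.

```python
import numpy as np


# Hypothetical helpers: illustrate layout conversion only, not the repo's code.
def torch_conv_to_tf(weight):
    """PyTorch Conv2d weight (out_ch, in_ch, k_h, k_w) ->
    tf.nn.conv2d filter (k_h, k_w, in_ch, out_ch)."""
    return np.transpose(np.asarray(weight), (2, 3, 1, 0))


def torch_depthwise_to_tf(weight):
    """PyTorch depthwise Conv2d weight (groups == in_ch, multiplier 1),
    shape (in_ch, 1, k_h, k_w) -> tf.nn.depthwise_conv2d filter
    (k_h, k_w, in_ch, 1)."""
    return np.transpose(np.asarray(weight), (2, 3, 0, 1))


def torch_linear_to_tf(weight):
    """PyTorch nn.Linear weight (out_features, in_features) -> matrix W
    usable as inputs @ W, i.e. (in_features, out_features)."""
    return np.asarray(weight).T


if __name__ == "__main__":
    conv = np.zeros((116, 24, 1, 1))  # an arbitrary example shape
    print(torch_conv_to_tf(conv).shape)  # prints (1, 1, 24, 116)
```

BatchNorm parameters need no transpose: PyTorch's weight, bias, running_mean and running_var map to TensorFlow's gamma, beta, moving_mean and moving_variance. If the TensorFlow graph uses tf.layers.batch_normalization, its default epsilon (0.001) differs from the PyTorch BatchNorm2d default (1e-5), so it is worth matching explicitly.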
If you want to execute the code, make sure all required packages are installed and the Dogs vs. Cats training set is placed in datasets. The folder structure should look like this:
cat-recognition-train
+-- train.py
+-- net.py
+-- dataset.py
+-- datasets
|   +-- train
|       +-- cat.0.jpg
|       +-- cat.1.jpg
|       ...
|       +-- cat.12499.jpg
|       +-- dog.0.jpg
|       +-- dog.1.jpg
|       ...
|       +-- dog.12499.jpg
+-- ...
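The file layout above is enough to build labeled training and validation sets. The following is a minimal sketch, not the code in dataset.py: it assumes labels come from the cat./dog. filename prefix, that cats are encoded as 1 and dogs as 0, and that validation images are a seeded random hold-out whose size is set by a ratio like the --valset_ratio flag. The function names are hypothetical.

```python
import os
import random


def list_labeled_images(train_dir):
    """Return (path, label) pairs; label is taken from the filename prefix.

    Assumed encoding: cat -> 1, dog -> 0. Other files are ignored.
    """
    samples = []
    for name in sorted(os.listdir(train_dir)):
        prefix = name.split(".", 1)[0]
        if prefix == "cat":
            samples.append((os.path.join(train_dir, name), 1))
        elif prefix == "dog":
            samples.append((os.path.join(train_dir, name), 0))
    return samples


def split_train_val(samples, valset_ratio, seed=0):
    """Shuffle with a fixed seed and hold out valset_ratio of the samples."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * valset_ratio)
    return shuffled[n_val:], shuffled[:n_val]  # (train, validation)


if __name__ == "__main__":
    train_dir = os.path.join("datasets", "train")
    if os.path.isdir(train_dir):
        train, val = split_train_val(list_labeled_images(train_dir), 0.1)
        print(len(train), len(val))
```

With the full 25,000-image training set and a ratio of .1, this holds out 2,500 images for validation. Fixing the seed keeps the split identical across runs, so validation accuracy stays comparable between experiments.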
Once all requirements are set up, run the following command to train with the default arguments:
$ python train.py
Or pass your own arguments:
$ python train.py --epochs 30 --batch_size 32 --valset_ratio .1 --optim sgd --lr_decay_step 10
See train.py for the available arguments.
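The --lr_decay_step flag suggests a step learning-rate schedule. As a hedged illustration only (train.py's actual decay factor, base learning rate and schedule are not documented here), a staircase decay that multiplies the learning rate by an assumed factor gamma every lr_decay_step epochs looks like this; step_decay_lr is a hypothetical name.

```python
def step_decay_lr(base_lr, epoch, lr_decay_step, gamma=0.1):
    """Staircase decay: base_lr * gamma ** (epoch // lr_decay_step).

    Hypothetical helper; epochs are 0-indexed and gamma=0.1 is an assumption.
    """
    return base_lr * gamma ** (epoch // lr_decay_step)


if __name__ == "__main__":
    # Mirrors the example command (--epochs 30 --lr_decay_step 10);
    # base_lr=0.01 is an assumed value, not the repository's default.
    for epoch in range(30):
        print(epoch, step_decay_lr(0.01, epoch, lr_decay_step=10))
```

With these assumed values, epochs 0 to 9 train at 0.01, epochs 10 to 19 at 0.001 and epochs 20 to 29 at 0.0001. In TensorFlow 1.x the same staircase shape is available through tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=True), where decay_steps counts training steps, i.e. lr_decay_step multiplied by the number of steps per epoch.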
Visualizing Learning using TensorBoard
During training, you can monitor its progress by running:
$ tensorboard --logdir runs
Then you can view the TensorBoard summaries at localhost:6006.
Training and Validation Flow
The whole training and validation flow, including the CNN model and other training/validation operations such as the optimizer, saver and accuracy counter.
Model Performance
Training/validation loss and validation accuracy for each epoch
Optimized Network Graph
Predict Using Optimized Frozen Graph
See predict.py for details and a demo.
Default image used for the predict.py demo
For a demonstration, run:
$ python predict.py
The result should be:
Predicting catness on images/test.png using model from baseline_model/optimized_net_best_acc.pb
Catness: 16.460064
Cat Probability: 1.000000
It's a cat.
Also, if you have your own cat or dog photo to test, run:
$ python predict.py --path path/to/your/img.png
PNG, JPEG and BMP images are supported.
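In the sample output above, a Catness of 16.460064 yields a Cat Probability of 1.000000. That is consistent with Catness being a raw logit passed through a logistic sigmoid: sigmoid(16.460064) is about 0.99999993, which rounds to 1.000000. Assuming that interpretation and a 0.5 decision threshold, the conversion can be sketched as below; the function names are hypothetical and the code is not taken from predict.py.

```python
import math


def catness_to_probability(catness):
    """Logistic sigmoid of a logit, written to avoid overflow in math.exp."""
    if catness >= 0:
        return 1.0 / (1.0 + math.exp(-catness))
    z = math.exp(catness)  # catness < 0, so z lies in (0, 1)
    return z / (1.0 + z)


def is_cat(catness, threshold=0.5):
    # Hypothetical decision rule; the 0.5 threshold is an assumption.
    return catness_to_probability(catness) >= threshold


if __name__ == "__main__":
    p = catness_to_probability(16.460064)
    print("Cat Probability: {:.6f}".format(p))  # prints Cat Probability: 1.000000
    print(is_cat(16.460064))  # prints True
```

A logit of 0 corresponds to a probability of exactly 0.5, so with this threshold the sign of Catness alone decides the label.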