nasnet-tensorflow
A NASNet implementation in TensorFlow, based on the TensorFlow-Slim library.
About Nasnet and this repository
At the time of writing, NASNet is the state-of-the-art image classification architecture on the ImageNet dataset (the ArXiv version was released on 21 Jul. 2017). The reported single-crop top-1 accuracy of the NASNet-A Large model is 82.7%. For details of NASNet, please refer to the paper Learning Transferable Architectures for Scalable Image Recognition by Barret Zoph et al.
With this repo, you should be able to:
- Train a NASNet from scratch on a customized dataset for image classification (if you want).
- Finetune NASNet (nasnet-a-large, nasnet-a-mobile) from an ImageNet pre-trained model for image classification.
- Test and evaluate the model you have trained.
- Deploy the model in your application, or transfer the feature extractor to other tasks such as object detection (left to you).
This repo is suitable for readers with solid CNN knowledge and a Python and TensorFlow background. If you have less background, the TensorFlow-Slim walkthrough tutorial is a good place to start.
Dependencies
tensorflow >= 1.4.0
tf.contrib.slim
numpy
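A quick way to confirm the TensorFlow requirement is a small version check. The sketch below is an illustration, not part of the repo; `version_at_least` is a hypothetical helper that compares only the leading digits of each version component. Keep in mind that tf.contrib (and therefore tf.contrib.slim) was removed in TensorFlow 2.0, so in practice a 1.x release of at least 1.4.0 is needed.

```python
"""Hypothetical check that the installed TensorFlow fits ">= 1.4.0" (1.x only)."""
import re


def version_at_least(version, minimum="1.4.0"):
    """Return True if a dotted version string is >= minimum.

    Only the leading digits of each component are compared, so a suffix
    such as "1.4.0-rc1" is treated as 1.4.0 (a simplification).
    """
    def parse(v):
        parts = []
        for piece in v.split(".")[:3]:
            match = re.match(r"\d+", piece)
            parts.append(int(match.group()) if match else 0)
        # Pad short versions such as "1.4" to three components.
        return tuple(parts + [0] * (3 - len(parts)))
    return parse(version) >= parse(minimum)


if __name__ == "__main__":
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow is not installed")
    else:
        version = tf.__version__
        if not version_at_least(version):
            print(version, "is too old; need >= 1.4.0")
        elif version_at_least(version, "2.0.0"):
            print(version, "has no tf.contrib; use a 1.x release")
        else:
            print(version, "OK")
```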
Usage
Clone the repo and enter the workspace.
git clone https://github.com/yeephycho/nasnet-tensorflow.git
cd nasnet-tensorflow
mkdir train pre-trained
Downloading and converting to TFRecord format (this part is the same as in the tf.slim tutorial)
Many people are interested in training NASNet on their own data. I'm not sure whether it's a good idea to promote my repo with a dataset provided by Google's tutorial: many people see that the TFRecord generation code is a copy of TensorFlow's solution, and they either give up or e-mail me asking how to train on a customized dataset, as I promised. However, if you spend some time on the code, you will find that although writing a TFRecord generation script from scratch may not be easy, modifying the template code is really easy, and the flowers dataset is a very good template to modify. So, before sending me an e-mail, please spend half an hour on the following scripts:
train_image_classifier.py
download_and_convert_data.py
datasets/dataset_factory.py
datasets/download_and_convert_flowers.py
datasets/flowers.py
By modifying just a few characters, you should be able to turn your own dataset into TFRecords.
The following instructions walk you through generating the tutorial TFRecords.
For each dataset, we'll need to download the raw data and convert it to TensorFlow's native TFRecord format. Each TFRecord contains a TF-Example protocol buffer. Below we demonstrate how to do this for the Flowers dataset.
$ DATA_DIR=/tmp/data/flowers
$ python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir="${DATA_DIR}"
When the script finishes you will find several TFRecord files created:
$ ls ${DATA_DIR}
flowers_train-00000-of-00005.tfrecord
...
flowers_train-00004-of-00005.tfrecord
flowers_validation-00000-of-00005.tfrecord
...
flowers_validation-00004-of-00005.tfrecord
labels.txt
These represent the training and validation data, sharded over 5 files each.
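Each shard is a TFRecord file: a sequence of length-prefixed, checksummed records, where (in these shards) each payload is one serialized Example. The pure-Python sketch below illustrates only that on-disk framing (little-endian uint64 length, masked CRC-32C of the length, data, masked CRC-32C of the data). It uses arbitrary byte payloads instead of real Example protos so it runs without TensorFlow; the helper names are hypothetical, and in practice you would read and write these files with TensorFlow's own TFRecord utilities.

```python
"""Pure-Python sketch of TFRecord framing (payloads are plain bytes here)."""
import io
import struct


def crc32c(data):
    """Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_records(stream, payloads):
    """Append each payload as one framed record."""
    for data in payloads:
        header = struct.pack("<Q", len(data))
        stream.write(header)
        stream.write(struct.pack("<I", masked_crc(header)))
        stream.write(data)
        stream.write(struct.pack("<I", masked_crc(data)))


def read_records(stream):
    """Yield payloads, verifying both checksums of every record."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length field")
        data = stream.read(length)
        (data_crc,) = struct.unpack("<I", stream.read(4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupt record data")
        yield data


if __name__ == "__main__":
    buffer = io.BytesIO()
    write_records(buffer, [b"first example", b"second example"])
    buffer.seek(0)
    print(list(read_records(buffer)))  # [b'first example', b'second example']
```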
You will also find the $DATA_DIR/labels.txt file, which contains the mapping from integer labels to class names.
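labels.txt holds one `<integer label>:<class name>` pair per line. As far as I know this matches the label-file helpers in TF-Slim's datasets/dataset_utils.py, which this repo inherits, but treat the format as an assumption and check your own file. A minimal, hypothetical parser:

```python
"""Hypothetical parser for a labels.txt file with '<id>:<class name>' lines."""
import os
import tempfile


def read_label_file(path):
    """Return {integer label: class name}; splits on the first ':' only."""
    labels = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue  # skip blank lines
            index = line.index(":")
            labels[int(line[:index])] = line[index + 1:]
    return labels


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "labels.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write("0:daisy\n1:dandelion\n")
        print(read_label_file(path))  # {0: 'daisy', 1: 'dandelion'}
```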
I also provide a user-friendly version of the TFRecord generation solution. All you need to modify is located in:
datasets/customized.py
Line 36: the number of samples in the training set and in the validation (test) set.
Line 39: the total number of classes.
datasets/convert_customized.py
Line 61: the number of samples in the validation (test) set.
# Create one directory per label, named after the label, then put the images into their label folders.
# ls /path/to/your/dataset/
# label0, label1, label2, ...
# ls /path/to/your/dataset/label0
# label0_image0.jpg, label0_image1.jpg, ...
#
# Image file names don't matter.
DATASET_DIR=/path/to/your/own/dataset/
# Convert the customized data into TFRecords. Note that dataset_name must be "customized"!
python convert_customized_data.py \
--dataset_name=customized \
--dataset_dir="${DATASET_DIR}"
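To make the folder layout above concrete, the sketch below walks the label folders, assigns integer labels by sorted folder name, makes a seeded train/validation split, and divides a split into shards. It is written in the spirit of the flowers template but is not the code in datasets/convert_customized.py; every helper name, the seed, the shard count of 5, and the "customized" file-name prefix are hypothetical, and the file-name pattern is modeled on the flowers listing shown earlier.

```python
"""Hypothetical sketch: label folders -> seeded split -> shard assignment."""
import math
import os
import random
import tempfile

NUM_SHARDS = 5   # assumption: same shard count as the flowers example
RANDOM_SEED = 0  # assumption: any fixed seed makes the split reproducible


def get_filenames_and_classes(dataset_dir):
    """Return (image paths, sorted class names); one sub-folder per label."""
    class_names = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name)))
    paths = []
    for name in class_names:
        folder = os.path.join(dataset_dir, name)
        paths.extend(os.path.join(folder, f) for f in sorted(os.listdir(folder)))
    return paths, class_names


def label_of(path, class_names):
    """Integer label = index of the parent folder in the sorted class list."""
    return class_names.index(os.path.basename(os.path.dirname(path)))


def split_train_validation(paths, num_validation, seed=RANDOM_SEED):
    """Shuffle with a fixed seed, then take the first num_validation items."""
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    return shuffled[num_validation:], shuffled[:num_validation]


def shard_ranges(num_items, num_shards=NUM_SHARDS):
    """[start, end) index ranges covering num_items with num_shards shards."""
    per_shard = int(math.ceil(num_items / float(num_shards)))
    return [(min(i * per_shard, num_items), min((i + 1) * per_shard, num_items))
            for i in range(num_shards)]


def shard_filename(prefix, split_name, shard_id, num_shards=NUM_SHARDS):
    """E.g. flowers_train-00000-of-00005.tfrecord (pattern is an assumption)."""
    return "%s_%s-%05d-of-%05d.tfrecord" % (prefix, split_name, shard_id, num_shards)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for label, count in [("label0", 3), ("label1", 2)]:
            os.mkdir(os.path.join(root, label))
            for i in range(count):
                open(os.path.join(root, label, "%s_image%d.jpg" % (label, i)), "w").close()
        paths, classes = get_filenames_and_classes(root)
        train, validation = split_train_validation(paths, num_validation=1)
        for shard_id, (start, end) in enumerate(shard_ranges(len(train))):
            names = [os.path.basename(p) for p in train[start:end]]
            print(shard_filename("customized", "train", shard_id), names)
```

Sorting the folder names keeps the label ids stable between runs, and the fixed seed keeps the same images in the validation set every time you regenerate the TFRecords.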
Train from scratch
DATASET_DIR=/tmp/data/flowers # /path/to/your/own/dataset/
TRAIN_DIR=./train
# For Nasnet-a-mobile
# For your own dataset, pass --dataset_name=customized instead
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=nasnet_mobile
# For Nasnet-a-large
# For your own dataset, pass --dataset_name=customized instead
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=nasnet_large
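The commands above pass no learning-rate flags, so training follows the script's default schedule. Assuming this repo keeps upstream TF-Slim's defaults for train_image_classifier.py (--learning_rate=0.01, --learning_rate_decay_factor=0.94, --num_epochs_per_decay=2.0, --batch_size=32, exponential decay with staircase steps), the rate drops by a factor of 0.94 every two epochs. Please verify these defaults in your copy of the script; the sketch below only reproduces the arithmetic, and its function names are hypothetical.

```python
"""Staircase exponential decay, assuming TF-Slim's default training flags."""


def decay_steps(num_samples_per_epoch, batch_size=32, num_epochs_per_decay=2.0):
    """Global steps between two decays (at least 1 to avoid division by zero)."""
    return max(1, int(num_samples_per_epoch * num_epochs_per_decay / batch_size))


def learning_rate(global_step, num_samples_per_epoch, initial=0.01,
                  decay_factor=0.94, batch_size=32, num_epochs_per_decay=2.0):
    """lr = initial * decay_factor ** floor(global_step / decay_steps)."""
    steps = decay_steps(num_samples_per_epoch, batch_size, num_epochs_per_decay)
    return initial * decay_factor ** (global_step // steps)


if __name__ == "__main__":
    # Flowers training split: 3,320 images (the size used by TF-Slim's flowers.py).
    for step in (0, 207, 1000, 10000):
        print(step, learning_rate(step, 3320))
```

For the flowers training split this gives 207 steps per decay, so a long from-scratch run quickly ends up with a much smaller rate than 0.01.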
Finetune from ImageNet pre-trained checkpoint
# This script downloads the pre-trained models from Google, moves them into the pre-trained folder, and unpacks them.
sh download_pretrained_model.sh
DATASET_DIR=/tmp/data/flowers # /path/to/your/own/dataset/
TRAIN_DIR=./train
# For Nasnet-a-mobile
# For your own dataset, pass --dataset_name=customized instead
CHECKPOINT_PATH=./pre-trained/nasnet-a_mobile_04_10_2017/model.ckpt
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--model_name=nasnet_mobile \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=final_layer,aux_7 \
--trainable_scopes=final_layer,aux_7
# For Nasnet-a-large
# For your own dataset, pass --dataset_name=customized instead
CHECKPOINT_PATH=./pre-trained/nasnet-a_large_04_10_2017/model.ckpt
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--model_name=nasnet_large \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=final_layer,aux_11 \
--trainable_scopes=final_layer,aux_11
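The two scope flags above work on variable-name prefixes: variables under --checkpoint_exclude_scopes are not restored from the ImageNet checkpoint (the new classifier heads have a different number of classes), and only variables under --trainable_scopes are updated. The sketch below assumes plain prefix matching on comma-separated scopes, which is how I understand upstream TF-Slim's train_image_classifier.py to behave; it works on name strings so it runs without TensorFlow, and the variable names in the demo are hypothetical.

```python
"""Sketch: select variables by comma-separated scope prefixes (names only)."""


def parse_scopes(flag_value):
    """'final_layer,aux_7' -> ['final_layer', 'aux_7']; empty flag -> []."""
    return [s.strip() for s in flag_value.split(",") if s.strip()]


def variables_to_restore(var_names, checkpoint_exclude_scopes):
    """Keep every variable that does NOT start with an excluded scope."""
    exclusions = parse_scopes(checkpoint_exclude_scopes)
    return [n for n in var_names if not any(n.startswith(e) for e in exclusions)]


def variables_to_train(var_names, trainable_scopes):
    """With no scopes, train everything; otherwise only matching prefixes."""
    scopes = parse_scopes(trainable_scopes)
    if not scopes:
        return list(var_names)
    return [n for n in var_names if any(n.startswith(s) for s in scopes)]


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    names = ["cell_0/comb_iter_0/left/weights",
             "aux_7/aux_logits/FC/weights",
             "final_layer/FC/weights",
             "final_layer/FC/biases"]
    print("restored:", variables_to_restore(names, "final_layer,aux_7"))
    print("trained: ", variables_to_train(names, "final_layer,aux_7"))
```

Because matching is by prefix, a scope such as aux_1 would also match aux_11; the scopes used above (aux_7 for mobile, aux_11 for large) name the auxiliary head of each model exactly.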
Evaluation
A NASNet model finetuned on the flowers dataset can be downloaded from Google Drive.
# Specify your own model.ckpt-xxxx file, for example:
CHECKPOINT_FILE=./train/model.ckpt-29735
# For Nasnet-a-mobile
# For your own dataset, pass --dataset_name=customized instead
python eval_image_classifier.py \
--alsologtostderr \
--checkpoint_path=${CHECKPOINT_FILE} \
--dataset_dir=/tmp/data/flowers \
--dataset_name=flowers \
--dataset_split_name=validation \
--model_name=nasnet_mobile
# For Nasnet-a-large
# For your own dataset, pass --dataset_name=customized instead
python eval_image_classifier.py \
--alsologtostderr \
--checkpoint_path=${CHECKPOINT_FILE} \
--dataset_dir=/tmp/data/flowers \
--dataset_name=flowers \
--dataset_split_name=validation \
--model_name=nasnet_large
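The evaluation script reports classification accuracy on the validation split; upstream TF-Slim's version reports top-1 accuracy ("Accuracy") and top-5 recall ("Recall_5"), and I assume this repo does the same. If you want to check those numbers on your own predictions, the hypothetical helper below computes top-k accuracy from raw logits in plain Python.

```python
"""Top-k accuracy over a batch of logits (pure Python sketch)."""


def top_k_accuracy(logits, labels, k=1):
    """Fraction of rows whose true label is among the k highest scores."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of logits and labels")
    hits = 0
    for scores, label in zip(logits, labels):
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)


if __name__ == "__main__":
    logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    labels = [1, 1, 0]
    print(top_k_accuracy(logits, labels, k=1))
    print(top_k_accuracy(logits, labels, k=2))
```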
Visualize the training progress
tensorboard --logdir=./train
Inference
Export static computation graph
Open export_inference_graph.py and change line 112 to num_classes=5 (5 is for the flowers dataset; set this parameter according to your own dataset).
# For large model
python export_inference_graph.py \
--alsologtostderr \
--model_name=nasnet_large \
--output_file=./inference/nasnet_large_inf_graph.pb
# For mobile model
python export_inference_graph.py \
--alsologtostderr \
--model_name=nasnet_mobile \
--output_file=./inference/nasnet_mobile_inf_graph.pb
Freeze the graph and checkpoint model
# For large model
python freeze_graph.py \
--input_graph=./inference/nasnet_large_inf_graph.pb \
--input_checkpoint=./train/model.ckpt-16547 \
--input_binary=true \
--output_graph=./inference/frozen_nasnet_large.pb \
--output_node_names=final_layer/predictions
# For mobile model
python freeze_graph.py \
--input_graph=./inference/nasnet_mobile_inf_graph.pb \
--input_checkpoint=./train/yourOwnModel.ckpt-xxxx \
--input_binary=true \
--output_graph=./inference/frozen_nasnet_mobile.pb \
--output_node_names=final_layer/predictions
Inference demo
Note:
- Make sure the input image is larger than 331 by 331 pixels. The current code applies no resize operation to the input; you can add one if you want.
- The inference demo only supports JPEG images.
- A five-crop inference strategy is used.
python image_inference.py
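The five-crop strategy mentioned in the notes takes several fixed-size crops from one image and combines the model's predictions over them. image_inference.py is not reproduced here; the sketch below assumes the common scheme of four corner crops plus a center crop, with the per-crop class probabilities averaged, and runs on a nested-list "image" so no TensorFlow is needed. All names are hypothetical, and 331 is the NASNet-A Large input size.

```python
"""Five-crop sketch: four corners plus the center, predictions averaged."""

CROP_SIZE = 331  # NASNet-A Large input resolution


def five_crop_offsets(height, width, size=CROP_SIZE):
    """(top, left) offsets of the four corner crops and the center crop."""
    if height < size or width < size:
        raise ValueError("image must be at least %d x %d" % (size, size))
    bottom, right = height - size, width - size
    center = (bottom // 2, right // 2)
    return [(0, 0), (0, right), (bottom, 0), (bottom, right), center]


def crop(image, top, left, size=CROP_SIZE):
    """Crop a row-major nested list (height x width [x channels])."""
    return [row[left:left + size] for row in image[top:top + size]]


def average_predictions(per_crop_probs):
    """Element-wise mean of the per-crop class probability vectors."""
    n = len(per_crop_probs)
    return [sum(column) / n for column in zip(*per_crop_probs)]


if __name__ == "__main__":
    image = [[(r, c) for c in range(5)] for r in range(4)]  # 4 x 5 toy image
    for top, left in five_crop_offsets(4, 5, size=3):
        print((top, left), crop(image, top, left, size=3)[0][0])
    print(average_predictions([[0.2, 0.8], [0.4, 0.6]]))
```

This is also why the input must be at least as large as the crop: with no resize step, a smaller image cannot supply a full 331 by 331 crop.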
Output
Reference
Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le. Learning Transferable Architectures for Scalable Image Recognition. arXiv:1707.07012, 2017.