anderspkd/tf_train_quantized

Quantization-aware training and conversion to tflite files

This folder contains a couple of scripts that can be used to train models and convert the result into a tflite file.

Currently, only MNIST models are supported, but it should be easy to use train_mnist.py as a starting point for other problem areas (e.g., CIFAR10).

Quickstart

Run

$ virtualenv --python $(which python3) venv
$ source venv/bin/activate
$ pip install tensorflow

And then, for example,

$ ./train_and_quantize mnist_simple1

Consult models.py to see which models can be trained.

Step-by-step

Training

The script train_mnist.py performs quantization-aware training on one of the models defined in models.py, selected by name (see the Models section below).
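
To give a rough idea of what quantization-aware training simulates, the sketch below rounds a real value to the nearest point on a uniform 8-bit grid and maps it back to a float. This is a simplified illustration, not the code used by train_mnist.py: the function name fake_quantize and the example range are assumptions, and real implementations typically also adjust the range so that zero is exactly representable.

```python
def fake_quantize(x, x_min, x_max, num_bits=8):
    """Quantize x onto a uniform grid over [x_min, x_max], then dequantize.

    Hypothetical helper for illustration only.
    """
    levels = 2 ** num_bits - 1            # number of steps between grid ends
    scale = (x_max - x_min) / levels      # real-valued width of one step
    clipped = min(max(x, x_min), x_max)   # values outside the range saturate
    q = round((clipped - x_min) / scale)  # integer code in [0, levels]
    return x_min + q * scale              # back to a float on the grid


if __name__ == "__main__":
    # Example range [0, 6], as would suit a relu6 activation (assumption).
    for value in (-1.0, 0.5, 1.234, 10.0):
        print(value, "->", fake_quantize(value, 0.0, 6.0))
```

During training the network sees the rounded values in the forward pass, so the learned weights tolerate the rounding error that the converted model will have.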

Example usage:

$ ./train_mnist.py
usage: train_mnist.py [-h] [-m name] [-l] [--epochs epochs]
                      [--checkpoint-dir dir] [--freeze model name]

trains a model with quantization aware training

optional arguments:
  -h, --help            show this help message and exit
  -m name, --model-name name
                        name of model. Must be defined in models.py
  -l, --list-models     lists available models
  --epochs epochs       number of epochs for training. Default is 1
  --checkpoint-dir dir  directory to save checkpoint information. Default is
                        "./chkpt/checkpoints"
  --freeze model name   freezes the model
$ ./train_mnist.py -m simple
[snip 🦀]

60000/60000 [==============================] - 2s 29us/sample - loss: 0.4387 - acc: 0.8832
10000/10000 [==============================] - 1s 54us/sample - loss: 0.2266 - acc: 0.9360

Evaluation results:
test loss 0.22657053579986094
test accuracy 0.936

This writes checkpoints to the folder ./chkpt/, which the next step needs.

Converting checkpoints to a frozen graph def

The script checkpoint2pb.py takes the checkpoints from the previous step and creates a frozen graph def file:

$ ./checkpoint2pb.py
usage: checkpoint2pb.py [-h] model_name checkpoints
checkpoint2pb.py: error: the following arguments are required: model_name, checkpoints
$ ./checkpoint2pb.py mnist_simple chkpt/checkpoints
[snip 🦀]
-----------------------
writing mnist_simple.pb
input_arrays: flatten_input
output_arrays: dense_1/Softmax

Notice that the model name needs to be prefixed with the name of the group it belongs to (here mnist).

Converting a frozen graph to a tflite file

Finally, pb2tflite.sh converts the .pb file into a quantized tflite model file. It takes the input and output array names printed by the previous step.

$ ./pb2tflite.sh
usage: ./pb2tflite.sh [frozen_model.pb] [input_arrays] [output_arrays]
$ ./pb2tflite.sh mnist_simple.pb flatten_input dense_1/Softmax
converting "mnist_simple.pb" to "mnist_simple.tflite"
[snip 🦀]
done!
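
The three steps above can be summarized as a list of commands. The sketch below only builds and prints them; it does not run anything. The helper pipeline_commands is hypothetical and is not part of this repository, and the array names must be taken from the output of checkpoint2pb.py for the model in question.

```python
import shlex


def pipeline_commands(group, model, input_array, output_array,
                      checkpoint_dir="chkpt/checkpoints"):
    """Return the three step-by-step commands as argument lists.

    Hypothetical helper: it mirrors the example invocations in this README.
    """
    # checkpoint2pb.py expects the model name prefixed with its group.
    full_name = f"{group}_{model}"
    return [
        ["./train_mnist.py", "-m", model],
        ["./checkpoint2pb.py", full_name, checkpoint_dir],
        ["./pb2tflite.sh", f"{full_name}.pb", input_array, output_array],
    ]


if __name__ == "__main__":
    commands = pipeline_commands("mnist", "simple",
                                 "flatten_input", "dense_1/Softmax")
    for argv in commands:
        # Quote each argument so the line can be pasted into a shell.
        print("$ " + " ".join(shlex.quote(arg) for arg in argv))
```

Each argument list could be passed to subprocess.run(argv, check=True) to execute the step, assuming the scripts are in the current directory and executable.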

Models

All models should be defined in models.py. See existing models for how to go about defining a new model.

MNIST

Two models are currently defined for the MNIST dataset:

Simple

Two fully connected layers with relu activation.

Simple2

One convolution and two fully connected layers with relu6 activation.
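
As a rough illustration of the Simple architecture described above, a Keras model of that shape might look like the sketch below. It is not copied from models.py, and how models.py registers its models is not shown here. The function name build_simple_sketch and the hidden layer width are assumptions; the softmax output is inferred from the output array name dense_1/Softmax reported by checkpoint2pb.py.

```python
HIDDEN_UNITS = 128  # assumption: the README does not state the hidden width
NUM_CLASSES = 10    # MNIST has ten digit classes


def build_simple_sketch():
    """Build a small dense MNIST classifier resembling the Simple model.

    Hypothetical function for illustration only.
    """
    # Imported lazily so the constants above load without TensorFlow.
    import tensorflow as tf

    return tf.keras.Sequential([
        # MNIST images are 28x28 pixels; flatten them to a 784-vector.
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(HIDDEN_UNITS, activation="relu"),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ])


if __name__ == "__main__":
    build_simple_sketch().summary()
```

A new model for another dataset would mainly change the input shape and the number of output classes.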

About

Quantized training using Keras
