Part of the series Learn TensorFlow Now
In the last post we looked at a modified version of VGGNet that achieved ~97.8% accuracy recognizing handwritten digits. Now that we're relatively satisfied with our network, we'd like to save a trained version of the network that we can restore and use to classify digits whenever we'd like. We'll do so by saving all of the tf.Variables() we've created to a checkpoint (.ckpt) file.
Saving a Checkpoint
When we save our computational graph, we serialize both the graph itself and the values of all of our parameters. When serializing nodes in our graph, TensorFlow keeps track of their names so that we can interact with them later. Nodes that we don't name will receive default names and be very hard to pick out. (While preparing this post I forgot to name input and labels, which received the default names Placeholder and Placeholder_1 instead.) For this reason, we'll take a minute to ensure that we give names to any nodes we'll want to interact with later.
Saving a single checkpoint is straightforward. If we just want to save the state of our network after training, we simply add the following lines to the end of our previous network:
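In sketch form (a single weights variable stands in here for the full VGGNet definition, and the checkpoint path matches the files we inspect below):

```python
import tensorflow as tf

# A single variable stands in for the full VGGNet graph in this sketch.
weights = tf.Variable(tf.zeros([10]), name='weights')

# Create a Saver; by default it saves every tf.Variable in the graph.
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    # ... training happens here ...
    save_path = saver.save(session, '/tmp/vggnet/vgg_net.ckpt')
    print('Model saved to:', save_path)
```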
This snippet of code first creates a tf.train.Saver, an object that coordinates both saving and restoration of models. Next we call saver.save(), passing in the current session. As a refresher, this session contains information about both the structure of the computational graph and the exact values of all parameters. By default the saver saves all tf.Variables() (weight/bias parameters) from our graph, but it also has the ability to save only portions of the graph.
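For example, a Saver that serializes only a subset of variables can be created by passing an explicit var_list (a sketch; weights is the stand-in variable from above):

```python
# Save only the listed variables instead of every tf.Variable in the graph.
partial_saver = tf.train.Saver(var_list=[weights])
```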
After saving the checkpoint, the saver returns the save_path. Why return the save_path if we just provided it with a path? The saver also allows you to shard the saved checkpoint by device (e.g. when using multiple GPUs to train a model). In this situation, the returned save_path is appended with information on the number of shards created.
After running this code, we can navigate to the folder /tmp/vggnet/ and run ls -tralh to look at the contents:
```
-rw-rw-r-- 1 jovarty jovarty 184M Mar 12 19:57 vgg_net.ckpt.data-00000-of-00001
-rw-rw-r-- 1 jovarty jovarty 2.7K Mar 12 19:57 vgg_net.ckpt.index
-rw-rw-r-- 1 jovarty jovarty  105 Mar 12 19:57 checkpoint
-rw-rw-r-- 1 jovarty jovarty 188K Mar 12 19:57 vgg_net.ckpt.meta
```
The first file, vgg_net.ckpt.data-00000-of-00001, is 184 MB in size and contains the values of all of our parameters. This is a reasonably large size, and one of the reasons it's nice to use networks with fewer parameters; this model is larger than most of the apps on my phone, so it could be difficult to deploy to mobile devices. The vgg_net.ckpt.meta file contains information on the structure of our computational graph and the names of all of our nodes. Later we'll use this file to rebuild our computational graph from scratch.
Saving Multiple Checkpoints
Some neural networks are trained over the course of multiple weeks, so we'd like a way to periodically take checkpoints as our network learns. This lets us go back in time and hand-tune hyperparameters such as learning rate to squeeze the best performance out of our network. Fortunately, TensorFlow makes it easy to take checkpoints at any point during training. For example, we can modify our training loop to save a checkpoint whenever we print our progress:
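In sketch form again, with the same stand-in variable and a print in place of the loop's usual logging:

```python
import tensorflow as tf

weights = tf.Variable(tf.zeros([10]), name='weights')
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(1000):
        # ... run one training step here ...
        if step % 100 == 0:
            print('Saving checkpoint at step', step)
            # global_step appends the step number to the checkpoint files,
            # e.g. vgg_net.ckpt-100, vgg_net.ckpt-200, ...
            saver.save(session, '/tmp/vggnet/vgg_net.ckpt', global_step=step)
```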
The only real modification we've made here is to pass in global_step=step to track when each checkpoint was created. Be aware that this can eat up disk space relatively quickly depending on the size of your model; each of our VGG checkpoints requires 184 MB of space.
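If disk space is a concern, note that tf.train.Saver keeps only the most recent checkpoints by default (five, unless told otherwise via max_to_keep):

```python
# Keep only the two most recent checkpoints; older ones are deleted automatically.
saver = tf.train.Saver(max_to_keep=2)
```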
Restoring a Model
Now that we know how to save our model's parameters, how do we restore them? One way is to declare the original computational graph in Python and then restore the values to all the tf.Variables() (parameters) using saver.restore(). For example, we could remove the training and testing code from our previous network and replace it with the following:
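A sketch of that replacement, with a single variable once more standing in for the full graph definition:

```python
import tensorflow as tf

# The graph must be re-declared with the same names and shapes it had
# when the checkpoint was saved; `weights` stands in for the full VGGNet.
weights = tf.Variable(tf.zeros([10]), name='weights')

# 1. Create the tf.train.Saver.
saver = tf.train.Saver()

with tf.Session() as session:
    # 2. Restore the saved parameter values into the current session.
    saver.restore(session, '/tmp/vggnet/vgg_net.ckpt')
    # ... run the usual testing code here ...
```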
There are really only two additions to the code here:
- Create the tf.train.Saver.
- Restore the model to the current session. Note: this portion requires the graph to have been defined with identical names and parameters as when they were saved to a checkpoint.
Other than these changes, we test the network exactly as we would have before. If we wanted to test our network on new examples, we could load them into test_images and retrieve predictions from our graph instead of accuracy.
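That might look something like the following; predictions and input are assumed names for nodes defined in the previous posts' network:

```python
# `predictions` and `input` refer to nodes defined when we declared the
# graph above; both names are assumptions about the previous posts' code.
# Fetching predictions requires no labels, unlike fetching accuracy.
new_predictions = session.run(predictions, feed_dict={input: test_images})
```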
This approach works well for networks we've built ourselves, but it can be very cumbersome when we want to run networks designed by someone else. It can take hours to manually re-create each parameter and operation exactly as the original author defined them.
Restoring a Model from Scratch
One approach to using someone else's neural network is to load up the computational graph defined in the .meta file before restoring the values to this graph from the .ckpt file. Below is a self-contained example of restoring a model from scratch:
There are a few subtle changes worth pointing out. First, we create our tf.train.Saver indirectly by importing the computational graph with tf.train.import_meta_graph(). Next, we restore the values to our computational graph with saver.restore() exactly as we had done previously.
Since we don't have direct access to the input and labels nodes, we have to recover them from our graph with graph.get_tensor_by_name(). Notice that we are passing in the names that we had previously specified and appending :0 to these names. Some TensorFlow operations produce multiple outputs; when this happens, TensorFlow names them :0, :1 and so on until all the outputs have a unique name. All of the operations we're using have only one output, so we simply stick with :0.
Finally, the last change involves actually running the network. As in the previous step, we need to specify proper names for nodes such as accuracy because we don't have direct access to the computational nodes. Fortunately, it's simple to pass in strings such as 'accuracy:0' that specify which operations we want to run and return the values of. Alternatively, we could have recovered the nodes with graph.get_tensor_by_name() and passed them in directly.
Also note that if we had named our optimizer, we could have passed it into session.run() and continued to train our network. We could even have saved a new checkpoint at this point if we decided the network had improved in some way.
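For instance, assuming the training op had been given the name optimizer when the graph was built:

```python
# Operations (as opposed to tensors) are recovered with get_operation_by_name.
# batch_images and batch_labels stand in for a hypothetical training batch.
train_op = graph.get_operation_by_name('optimizer')
session.run(train_op, feed_dict={input_node: batch_images,
                                 labels_node: batch_labels})
```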
There are a variety of ways to save and restore models, and we've really only scratched the surface. The sketches above are self-contained starting points for each of the approaches we've looked at.