LTFN 9: Saving and Restoring

Part of the series Learn TensorFlow Now

In the last post we looked at a modified version of VGGNet that achieved ~97.8% accuracy recognizing handwritten digits. Now that we’re relatively satisfied with our network, we’d like to save a trained version of the network that we can restore and use to classify digits whenever we’d like. We’ll do so by saving all of the tf.Variables() we’ve created to a checkpoint (.ckpt) file.

Saving a Checkpoint

When we save our computational graph, we serialize both the graph itself and the values of all of our parameters. When serializing nodes in our graph, TensorFlow keeps track of their names so that we can interact with them later. Nodes that we don’t name receive default names and are very hard to pick out. (While preparing this post I forgot to name input and labels, which received the default names Placeholder and Placeholder_1 instead.) For this reason, we’ll take a minute to ensure that we give names to input, labels, cost, accuracy and predictions.
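To see the difference naming makes, here is a minimal sketch written against the TF 1.x graph API (imported through tensorflow.compat.v1, so it should also run under TensorFlow 2.x). The one-layer model is a hypothetical stand-in for the VGG network, not the network from the previous post; the name= arguments are the point.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph-mode API, as used throughout this series
tf.reset_default_graph()

# Unnamed placeholders get default names that are hard to tell apart later.
unnamed_input = tf.placeholder(tf.float32, shape=[None, 784])
unnamed_labels = tf.placeholder(tf.float32, shape=[None, 10])
print(unnamed_input.name, unnamed_labels.name)  # Placeholder:0 Placeholder_1:0

# Named nodes (the trailing underscore avoids shadowing Python's built-in input()).
input_ = tf.placeholder(tf.float32, shape=[None, 784], name="input")
labels = tf.placeholder(tf.float32, shape=[None, 10], name="labels")

# Hypothetical single-layer stand-in for the VGG network.
weights = tf.Variable(tf.zeros([784, 10]), name="weights")
bias = tf.Variable(tf.zeros([10]), name="bias")
logits = tf.matmul(input_, weights) + bias

cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits),
    name="cost")
predictions = tf.argmax(logits, axis=1, name="predictions")
correct = tf.equal(predictions, tf.argmax(labels, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

print([t.name for t in (input_, labels, cost, accuracy, predictions)])
# ['input:0', 'labels:0', 'cost:0', 'accuracy:0', 'predictions:0']
```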

Saving a single checkpoint is straightforward. If we just want to save the state of our network after training then we simply add the following lines to the end of our previous network:
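The original lines aren't reproduced here, so the following is a minimal, self-contained sketch of what they might look like, assuming a hypothetical one-layer model in place of VGGNet and the tensorflow.compat.v1 API. The path /tmp/vggnet/vgg_net.ckpt matches the files listed further down.

```python
import os

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Hypothetical stand-in for the trained network: one named weight/bias pair.
input_ = tf.placeholder(tf.float32, shape=[None, 784], name="input")
weights = tf.Variable(tf.random_normal([784, 10], stddev=0.01), name="weights")
bias = tf.Variable(tf.zeros([10]), name="bias")
logits = tf.add(tf.matmul(input_, weights), bias, name="logits")

checkpoint_dir = "/tmp/vggnet/"
os.makedirs(checkpoint_dir, exist_ok=True)  # make sure the target folder exists

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    # ... the training loop from the previous post would run here ...

    # Create a Saver (by default it covers every global variable) and write the checkpoint.
    saver = tf.train.Saver()
    save_path = saver.save(session, os.path.join(checkpoint_dir, "vgg_net.ckpt"))
    print("Model saved in file:", save_path)
```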

This snippet of code first creates a tf.train.Saver, an object that coordinates both saving and restoring models. Next we call saver.save(), passing in the current session and the path to save to. As a refresher, this session contains information about both the structure of the computational graph and the exact values of all parameters. By default the saver saves all tf.Variables() (weight/bias parameters) from our graph, but it can also save only a subset of them.
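Saving only part of the graph is done by handing the saver a var_list. Here is a minimal sketch, assuming a hypothetical two-variable graph and a temporary directory; tf.train.list_variables() is used afterwards to confirm which values actually ended up in the checkpoint.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Hypothetical two-layer model; suppose only the first layer is worth keeping.
conv_weights = tf.Variable(tf.ones([3, 3, 1, 8]), name="conv_weights")
dense_weights = tf.Variable(tf.ones([128, 10]), name="dense_weights")

checkpoint_dir = tempfile.mkdtemp()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    # var_list restricts the saver to the listed variables
    # (a dict mapping checkpoint names to variables is also accepted).
    partial_saver = tf.train.Saver(var_list=[conv_weights])
    save_path = partial_saver.save(session, os.path.join(checkpoint_dir, "partial.ckpt"))

# (name, shape) pairs stored in the checkpoint: only conv_weights should appear.
saved = tf.train.list_variables(save_path)
print(saved)
```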

After saving the checkpoint, the saver returns save_path. Why return the path if we just provided it? The saver can also shard the saved checkpoint by device (e.g. when training a model on multiple GPUs), and in that situation the returned save_path includes information on the number of shards created. The returned path also differs from the one we passed in when we supply a global_step, which appends the step number to the file name.

After running this code, we can navigate to the folder /tmp/vggnet/ and run ls -tralh to look at the contents:

-rw-rw-r--  1 jovarty jovarty 184M Mar 12 19:57 vgg_net.ckpt.data-00000-of-00001
-rw-rw-r--  1 jovarty jovarty 2.7K Mar 12 19:57 vgg_net.ckpt.index
-rw-rw-r--  1 jovarty jovarty  105 Mar 12 19:57 checkpoint
-rw-rw-r--  1 jovarty jovarty 188K Mar 12 19:57 vgg_net.ckpt.meta

The first file, vgg_net.ckpt.data-00000-of-00001, is 184 MB in size and contains the values of all of our parameters. That is fairly large, and it’s one reason to prefer networks with fewer parameters: this model is larger than most of the apps on my phone, so it could be difficult to deploy to mobile devices.
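We can check where the bulk of that file comes from by counting the values stored in a checkpoint. This sketch uses a small hypothetical model and assumes every variable is float32 (4 bytes per value).

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Hypothetical small model: 784*256 + 256 + 256*10 + 10 = 203,530 parameters.
w1 = tf.Variable(tf.zeros([784, 256]), name="w1")
b1 = tf.Variable(tf.zeros([256]), name="b1")
w2 = tf.Variable(tf.zeros([256, 10]), name="w2")
b2 = tf.Variable(tf.zeros([10]), name="b2")

checkpoint_dir = tempfile.mkdtemp()
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    save_path = tf.train.Saver().save(session, os.path.join(checkpoint_dir, "model.ckpt"))

# Count every value stored in the checkpoint and estimate its size,
# assuming all variables are float32 (4 bytes each).
num_params = sum(int(np.prod(shape)) for _, shape in tf.train.list_variables(save_path))
estimated_bytes = 4 * num_params
data_file = glob.glob(save_path + ".data-*")[0]
print(num_params, estimated_bytes, os.path.getsize(data_file))
```

Read the other way, a 184M data file corresponds to roughly 184 × 1024² / 4 ≈ 48 million float32 values. Because tf.train.Saver() saves every global variable by default, that count may include optimizer state (for example Adam’s moment estimates) as well as the weights themselves.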

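The vgg_net.ckpt.meta file contains information on the structure of our computational graph and the names of all of our nodes. Later we’ll use this file to rebuild our computational graph from scratch. The smaller vgg_net.ckpt.index file maps each variable to its location in the data file, and checkpoint is a small text file that records the path of the most recent checkpoint.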
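To confirm that the .meta file alone carries the graph structure and node names, a minimal sketch (hypothetical tiny graph, temporary directory) can save a checkpoint, clear the default graph, and rebuild it with tf.train.import_meta_graph().

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

save_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build and save a tiny hypothetical graph with named nodes.
tf.reset_default_graph()
input_ = tf.placeholder(tf.float32, shape=[None, 4], name="input")
weights = tf.Variable(tf.ones([4, 2]), name="weights")
logits = tf.matmul(input_, weights, name="logits")
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    tf.train.Saver().save(session, save_prefix)

# Start from an empty graph and rebuild the structure from the .meta file alone.
tf.reset_default_graph()
tf.train.import_meta_graph(save_prefix + ".meta")
op_names = [op.name for op in tf.get_default_graph().get_operations()]
print([name for name in op_names if name in ("input", "weights", "logits")])
```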

 

Saving Multiple Checkpoints

Some neural networks are trained over the course of several weeks, so we’d like a way to take checkpoints periodically as our network learns. This lets us go back in time and hand-tune hyperparameters such as the learning rate to squeeze the best performance out of our network. Fortunately, TensorFlow makes it easy to take checkpoints at any point during training. For example, we can modify our training loop to save a checkpoint whenever we print accuracy and cost.
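Here is a minimal sketch of such a training loop, assuming a hypothetical toy regression problem in place of MNIST and a checkpoint every 100 steps. The only checkpoint-specific change is the global_step argument to saver.save().

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Hypothetical toy regression data standing in for MNIST.
rng = np.random.RandomState(0)
train_x = rng.rand(256, 4).astype(np.float32)
train_y = train_x.dot(np.array([[1.0], [-2.0], [3.0], [0.5]], dtype=np.float32))

input_ = tf.placeholder(tf.float32, shape=[None, 4], name="input")
labels = tf.placeholder(tf.float32, shape=[None, 1], name="labels")
weights = tf.Variable(tf.zeros([4, 1]), name="weights")
cost = tf.reduce_mean(tf.square(tf.matmul(input_, weights) - labels), name="cost")
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cost)

checkpoint_dir = tempfile.mkdtemp()
saver = tf.train.Saver()
saved_paths = []

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(501):
        feed = {input_: train_x, labels: train_y}
        _, current_cost = session.run([train_step, cost], feed_dict=feed)
        if step % 100 == 0:
            print("Step", step, "cost", current_cost)
            # global_step is appended to the file name, e.g. "model.ckpt-100".
            path = saver.save(session, os.path.join(checkpoint_dir, "model.ckpt"),
                              global_step=step)
            saved_paths.append(path)

print(saved_paths)
```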

The only real modification we’ve made here is to pass global_step=step to saver.save(), which appends the step number to each checkpoint’s file name (e.g. vgg_net.ckpt-100) so we can tell when each one was created. Be aware that this can eat up disk space relatively quickly depending on the size of your model: each of our VGG checkpoints requires 184 MB of space.
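One way to keep disk usage bounded is the saver’s max_to_keep argument (it defaults to 5); keep_checkpoint_every_n_hours can additionally preserve occasional older checkpoints. The sketch below, with a hypothetical variable and a temporary directory, keeps only the three most recent checkpoints and reads the surviving paths back with tf.train.get_checkpoint_state().

```python
import glob
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

weights = tf.Variable(tf.zeros([256, 256]), name="weights")  # hypothetical parameter

checkpoint_dir = tempfile.mkdtemp()
# Keep only the three most recent checkpoints; older ones are deleted automatically.
saver = tf.train.Saver(max_to_keep=3)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(10):
        saver.save(session, os.path.join(checkpoint_dir, "model.ckpt"), global_step=step)

# The "checkpoint" bookkeeping file lists the checkpoints that are still on disk.
state = tf.train.get_checkpoint_state(checkpoint_dir)
print(list(state.all_model_checkpoint_paths))
index_files = glob.glob(os.path.join(checkpoint_dir, "*.index"))
print(len(index_files))
```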

 

Restoring a Model

Now that we know how to save our model’s parameters, how do we restore them? One way is to declare the original computational graph in Python and then restore the values to all the tf.Variables() (parameters) using tf.train.Saver.

For example, we could remove the training and testing code from our previous network and replace it with the following:
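A minimal sketch of this approach, assuming a hypothetical build_graph() function that stands in for the network definition. The first half saves a checkpoint so the example is self-contained; the second half is the part that matters: rebuild the same graph, create the saver, and restore instead of running the initializer.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_graph():
    """Hypothetical stand-in network; it must be identical when saving and restoring."""
    input_ = tf.placeholder(tf.float32, shape=[None, 4], name="input")
    weights = tf.Variable(tf.random_normal([4, 3]), name="weights")
    bias = tf.Variable(tf.zeros([3]), name="bias")
    logits = tf.add(tf.matmul(input_, weights), bias, name="logits")
    predictions = tf.argmax(logits, axis=1, name="predictions")
    return input_, weights, predictions


save_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Setup: initialize (in place of training) and save a checkpoint.
tf.reset_default_graph()
_, weights, _ = build_graph()
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    saved_weights = session.run(weights)
    tf.train.Saver().save(session, save_prefix)

# In a fresh graph, declare the same network again and restore into it.
tf.reset_default_graph()
input_, weights, predictions = build_graph()
saver = tf.train.Saver()  # addition 1: create the saver
with tf.Session() as session:
    saver.restore(session, save_prefix)  # addition 2: restore instead of initializing
    restored_weights = session.run(weights)
    test_images = np.eye(3, 4, dtype=np.float32)  # hypothetical "new examples"
    restored_predictions = session.run(predictions, feed_dict={input_: test_images})
    print(restored_predictions)
```

Because saver.restore() assigns every saved variable, there is no need to run tf.global_variables_initializer() afterwards; the random initial values from the second build_graph() call are simply overwritten.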

There are really only two additions to the code here:

  1. Create the tf.train.Saver()
  2. Restore the model to the current session. Note: this step requires the graph to be defined with the same names and parameter shapes as when it was saved to a checkpoint.

Other than these changes, we test the network exactly as we would have before. If we wanted to test our network on new examples, we could load them into test_images and retrieve predictions from our graph instead of cost and accuracy.

This approach works well for networks we’ve built ourselves, but it can be very cumbersome when we want to run networks designed by someone else: it can take hours to manually recreate each parameter and operation exactly as the original author defined it.

Restoring a Model from Scratch

One approach to using someone else’s neural network is to load up the computational graph defined in the .meta file before restoring the values to this graph from the .ckpt file. Below is a self-contained example of restoring a model from scratch:
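Here is a self-contained sketch of that workflow, assuming a hypothetical two-class model and random data. The first part plays the role of the original author’s code; the second part only has the .meta and checkpoint files to work with.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

rng = np.random.RandomState(0)
test_images = rng.rand(8, 4).astype(np.float32)
test_labels = np.eye(2, dtype=np.float32)[rng.randint(0, 2, size=8)]
save_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# "Someone else's" code: build a named graph and save it (hypothetical model).
tf.reset_default_graph()
input_ = tf.placeholder(tf.float32, shape=[None, 4], name="input")
labels = tf.placeholder(tf.float32, shape=[None, 2], name="labels")
weights = tf.Variable(tf.random_normal([4, 2]), name="weights")
logits = tf.matmul(input_, weights)
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits), name="cost")
correct = tf.equal(tf.argmax(logits, axis=1), tf.argmax(labels, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    expected = session.run([cost, accuracy],
                           feed_dict={input_: test_images, labels: test_labels})
    tf.train.Saver().save(session, save_prefix)

# Our code: no Python handles to the original nodes, only the saved files.
tf.reset_default_graph()
with tf.Session() as session:
    saver = tf.train.import_meta_graph(save_prefix + ".meta")  # rebuild the graph
    saver.restore(session, save_prefix)                        # load parameter values
    graph = tf.get_default_graph()
    restored_input = graph.get_tensor_by_name("input:0")
    restored_labels = graph.get_tensor_by_name("labels:0")
    # Fetch tensors by name: "cost:0" is the first output of the op named "cost".
    restored = session.run(["cost:0", "accuracy:0"],
                           feed_dict={restored_input: test_images,
                                      restored_labels: test_labels})
    print("Cost:", restored[0], "Accuracy:", restored[1])
```

Passing the strings 'cost:0' and 'accuracy:0' to session.run() is equivalent to fetching the tensors returned by graph.get_tensor_by_name(); feed_dict likewise accepts tensor names such as 'input:0' as keys.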

There are a few subtle changes worth pointing out. First, we create our tf.train.Saver indirectly by importing the computational graph with tf.train.import_meta_graph(). Next, we restore the values to our computational graph with saver.restore() exactly as we had done previously.

Since we don’t have direct access to the input and labels nodes, we have to recover them from our graph with graph.get_tensor_by_name(). Notice that we pass in the names we specified earlier with :0 appended. The suffix is the index of one of the operation’s outputs: some TensorFlow operations produce multiple outputs, and TensorFlow names them :0, :1 and so on so that every output has a unique name. All of the operations we’re using have only one output, so we simply use :0.
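A small sketch of an operation with more than one output, using tf.nn.top_k (which returns both values and indices) under the hypothetical name "top":

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

scores = tf.constant([[0.1, 0.7, 0.2]], name="scores")
# tf.nn.top_k is a single operation with two outputs: values and indices.
values, indices = tf.nn.top_k(scores, k=2, name="top")
print(values.name, indices.name)  # top:0 top:1

# The suffix selects which output of the operation "top" we want.
graph = tf.get_default_graph()
same_indices = graph.get_tensor_by_name("top:1")

with tf.Session() as session:
    top_indices = session.run(same_indices)
    print(top_indices)  # [[1 2]]
```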

Finally, the last change involves actually running the network. As in the previous step, we need to refer to cost and accuracy by name because we don’t have direct access to those nodes. Fortunately, we can simply pass the strings 'cost:0' and 'accuracy:0' to session.run() to specify which tensors we want to evaluate and return. Alternatively, we could have recovered the nodes with graph.get_tensor_by_name() and passed them in directly.

Also note that if we had named our optimizer’s training operation, we could have passed its name to session.run() and continued to train our network. We could even have saved a new checkpoint at this point if we decided the network had improved in some way.
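A minimal sketch of that idea, assuming a hypothetical regression model whose update step was wrapped with tf.group(..., name="train_op") so it has a predictable name. After restoring from the files alone, we run the training op by its operation name (no :0 suffix, since we want the operation rather than a tensor) and then save a new checkpoint.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

rng = np.random.RandomState(0)
train_x = rng.rand(64, 3).astype(np.float32)
train_y = train_x.dot(np.array([[2.0], [-1.0], [0.5]], dtype=np.float32))
checkpoint_dir = tempfile.mkdtemp()
save_prefix = os.path.join(checkpoint_dir, "model.ckpt")

# Original training script (hypothetical model) with a named training op.
tf.reset_default_graph()
input_ = tf.placeholder(tf.float32, shape=[None, 3], name="input")
labels = tf.placeholder(tf.float32, shape=[None, 1], name="labels")
weights = tf.Variable(tf.zeros([3, 1]), name="weights")
cost = tf.reduce_mean(tf.square(tf.matmul(input_, weights) - labels), name="cost")
minimize = tf.train.GradientDescentOptimizer(0.1).minimize(cost)
train_op = tf.group(minimize, name="train_op")  # gives the update step a known name
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    tf.train.Saver().save(session, save_prefix, global_step=0)

# Later: restore from the saved files alone and keep training.
tf.reset_default_graph()
with tf.Session() as session:
    saver = tf.train.import_meta_graph(save_prefix + "-0.meta")
    saver.restore(session, tf.train.latest_checkpoint(checkpoint_dir))
    feed = {"input:0": train_x, "labels:0": train_y}  # tensor names work as feed keys
    cost_before = session.run("cost:0", feed_dict=feed)
    for _ in range(100):
        session.run("train_op", feed_dict=feed)  # an operation name, so no ":0"
    cost_after = session.run("cost:0", feed_dict=feed)
    # Save the improved parameters as a new checkpoint.
    new_path = saver.save(session, save_prefix, global_step=100)
    print(cost_before, cost_after, new_path)
```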

There are a variety of ways to save and restore models and we’ve really only scratched the surface.

3 thoughts on “LTFN 9: Saving and Restoring”

  1. Thanks for your great tutorial. I just came from MATLAB and I trained the Xception model from scratch on my own dataset to get a network pre-trained on my images, so I have the checkpoint, .ckpt and .meta files in my log folder. My question is how to extract features from my data using these files. I have already restored the model like this:

    import tensorflow as tf
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('model.ckpt-6353.meta')
        saver.restore(sess, tf.train.latest_checkpoint(r'C:\Users…'))

    After this step I don’t know how to continue. My dataset has nine classes.

    1. I think you can either get access to tensors via something like:

      input = graph.get_tensor_by_name("input:0")

      or if you don’t know the name of your tensors you could probably try:

      allKeys = graph.get_all_collection_keys()
      print(allKeys)

      Then I think you should be able to use get_tensor_by_name() with the appropriate key.

      I actually haven’t done this myself but I’m going off of: https://www.tensorflow.org/api_docs/python/tf/Graph#get_all_collection_keys
