Part of the series Learn TensorFlow Now
In the last post we looked at a modified version of VGGNet that achieved ~97.8% accuracy recognizing handwritten digits. Now that we're relatively satisfied with our network, we'd like to save a trained version that we can restore and use to classify digits whenever we'd like. We'll do so by saving all of the tf.Variables we've created to a checkpoint (.ckpt) file.
Saving a Checkpoint
When we save our computational graph, we serialize both the graph itself and the values of all of our parameters. When serializing nodes in our graph, TensorFlow keeps track of their names so that we can interact with them later. Nodes that we don't name receive default names and are very hard to pick out. (While preparing this post I forgot to name input and labels, which received the names Placeholder and Placeholder_1 instead.) For this reason, we'll take a minute to ensure that we give names to input, labels, cost, accuracy and predictions.
input = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="input")
labels = tf.placeholder(tf.float32, shape=(None, 10), name="labels")
…
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels), name="cost")
…
predictions = tf.nn.softmax(logits, name="predictions")
correct_prediction = tf.equal(tf.argmax(labels, 1), tf.argmax(predictions, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
Saving a single checkpoint is straightforward. If we just want to save the state of our network after training then we simply add the following lines to the end of our previous network:
saver = tf.train.Saver() #Create a saver
save_path = saver.save(session, "/tmp/vggnet/vgg_net.ckpt") #Specify where to save the model
print("Saved model at: ", save_path) #Confirm the saved location
This snippet of code first creates a tf.train.Saver, an object that coordinates both saving and restoration of models. Next we call saver.save(), passing in the current session. As a refresher, this session contains information about both the structure of the computational graph and the exact values of all parameters. By default the saver saves all tf.Variables (weight/bias parameters) from our graph, but it also has the ability to save only portions of the graph, as sketched below.
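For instance, the var_list argument restricts a saver to a subset of variables. Here's a minimal sketch; the conv1 variables are hypothetical stand-ins for layers in our network:

#Hypothetical first-layer parameters we might want to checkpoint on their own
conv1_weights = tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1), name="conv1_weights")
conv1_bias = tf.Variable(tf.zeros([64]), name="conv1_bias")

#Only the variables in var_list end up in this checkpoint
partial_saver = tf.train.Saver(var_list=[conv1_weights, conv1_bias])
partial_saver.save(session, "/tmp/vggnet/conv1_only.ckpt")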
After saving the checkpoint, the saver returns the save_path. Why return the save_path if we just provided it with a path? The saver also allows you to shard the saved checkpoint by device (e.g. when using multiple GPUs to train a model). In this situation, the returned save_path is appended with information on the number of shards created.
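You opt in to sharding when constructing the saver. A minimal sketch, assuming the same session as above:

#Opt in to sharding the checkpoint by device
sharded_saver = tf.train.Saver(sharded=True)
#The path the saver returns may not exactly match the path we passed in,
#so we always keep the returned save_path for later restoration
save_path = sharded_saver.save(session, "/tmp/vggnet/vgg_net.ckpt")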
After running this code, we can navigate to the folder /tmp/vggnet/ and run ls -tralh to look at the contents:
-rw-rw-r-- 1 jovarty jovarty 184M Mar 12 19:57 vgg_net.ckpt.data-00000-of-00001
-rw-rw-r-- 1 jovarty jovarty 2.7K Mar 12 19:57 vgg_net.ckpt.index
-rw-rw-r-- 1 jovarty jovarty  105 Mar 12 19:57 checkpoint
-rw-rw-r-- 1 jovarty jovarty 188K Mar 12 19:57 vgg_net.ckpt.meta
The first file, vgg_net.ckpt.data-00000-of-00001, is 184 MB in size and contains the values of all of our parameters. This is reasonably large, and it's one of the reasons it's nice to use networks with fewer parameters. This model is larger than most of the apps on my phone, so it could be difficult to deploy to mobile devices.
The vgg_net.ckpt.meta file contains information on the structure of our computational graph and the names of all of our nodes. Later we'll use this file to rebuild our computational graph from scratch.
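If you're curious exactly which parameters made it into a checkpoint, TensorFlow provides a checkpoint reader we can use to list them. A minimal sketch:

#List every variable stored in the checkpoint along with its shape
reader = tf.train.NewCheckpointReader("/tmp/vggnet/vgg_net.ckpt")
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)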
Saving Multiple Checkpoints
Some neural networks are trained over the course of multiple weeks, so we'd like a way to periodically take checkpoints as the network learns. This allows us to go back in time and hand-tune hyperparameters such as learning rate to squeeze the best performance out of our network. Fortunately, TensorFlow makes it easy to take checkpoints at any point during training. For example, we can modify our training loop to save a checkpoint whenever we print accuracy and cost.
saver = tf.train.Saver() #Create saver

num_steps = 1000
batch_size = 100

for step in range(num_steps):
    offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
    batch_images = train_images[offset:(offset + batch_size), :]
    batch_labels = train_labels[offset:(offset + batch_size), :]
    feed_dict = {input: batch_images, labels: batch_labels}

    _, c, acc = session.run([optimizer, cost, accuracy], feed_dict=feed_dict)

    if step % 100 == 0:
        print("Cost: ", c)
        print("Accuracy: ", acc * 100.0, "%")
        saver.save(session, "/tmp/vggnet/vgg_net.ckpt", global_step=step) #Save session every 100 mini-batches
The only real modification we've made here is to pass global_step=step to saver.save() to track when each checkpoint was created. Be aware that this can eat up disk space relatively quickly depending on the size of your model: each of our VGG checkpoints requires 184 MB of space.
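If disk space becomes a problem, the saver can prune old checkpoints automatically. A minimal sketch:

#Keep only the 5 most recent checkpoints,
#plus one checkpoint for every 2 hours of training
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)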
Restoring a Model
Now that we know how to save our model's parameters, how do we restore them? One way is to declare the original computational graph in Python and then restore the values of all the tf.Variables (parameters) using tf.train.Saver.
For example, we could remove the training and testing code from our previous network and replace it with the following:
with tf.Session() as session:
    #Restore Model
    saver = tf.train.Saver() #Create a saver (object to save/restore sessions)
    saver.restore(session, "/tmp/vggnet/vgg_net.ckpt") #Restore the session from a previously saved checkpoint

    #Now we test our restored model exactly as before
    batch_size = 100
    num_test_batches = int(len(test_images) / batch_size)
    total_accuracy = 0
    total_cost = 0
    for step in range(num_test_batches):
        offset = (step * batch_size) % (test_labels.shape[0] - batch_size)
        batch_images = test_images[offset:(offset + batch_size)]
        batch_labels = test_labels[offset:(offset + batch_size)]
        feed_dict = {input: batch_images, labels: batch_labels}
        c, acc = session.run([cost, accuracy], feed_dict=feed_dict)
        total_cost = total_cost + c
        total_accuracy = total_accuracy + acc

    print("Test Cost: ", total_cost / num_test_batches)
    print("Test accuracy: ", total_accuracy * 100.0 / num_test_batches, "%")
There are really only two additions to the code here:

- Create the tf.train.Saver().
- Restore the model to the current session. Note: this requires the graph to have been defined with the same names and parameters it had when the checkpoint was saved.
Other than these changes, we test the network exactly as we would have before. If we wanted to test our network on new examples, we could load them into test_images and retrieve predictions from our graph instead of cost and accuracy, as sketched below.
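A minimal sketch of that, where new_images is a hypothetical float32 array of shape (batch, 28, 28, 1):

import numpy as np

#Run only the predictions node; no labels are needed for inference
probs = session.run(predictions, feed_dict={input: new_images})
predicted_digits = np.argmax(probs, axis=1) #Most likely digit for each image
print(predicted_digits)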
This approach works well for networks we've built ourselves, but it can be very cumbersome when we want to run networks designed by someone else, as it can take hours to manually recreate every parameter and operation exactly as the original author defined them.
Restoring a Model from Scratch
One approach to using someone else's neural network is to load the computational graph defined in the .meta file before restoring the values to this graph from the .ckpt file. Below is a self-contained example of restoring a model from scratch:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
test_images = np.reshape(mnist.test.images, (-1, 28, 28, 1))
test_labels = mnist.test.labels

graph = tf.Graph()
with tf.Session(graph=graph) as session:
    saver = tf.train.import_meta_graph('/tmp/vggnet/vgg_net.ckpt.meta') #Create a saver based on a saved graph
    saver.restore(session, '/tmp/vggnet/vgg_net.ckpt') #Restore the values to this graph

    input = graph.get_tensor_by_name("input:0") #Get access to the input node
    labels = graph.get_tensor_by_name("labels:0") #Get access to the labels node

    batch_size = 100
    num_test_batches = int(len(test_images) / batch_size)
    total_accuracy = 0
    total_cost = 0
    for step in range(num_test_batches):
        offset = (step * batch_size) % (test_labels.shape[0] - batch_size)
        batch_images = test_images[offset:(offset + batch_size)]
        batch_labels = test_labels[offset:(offset + batch_size)]
        feed_dict = {input: batch_images, labels: batch_labels}
        c, acc = session.run(['cost:0', 'accuracy:0'], feed_dict=feed_dict) #Note: We pass in strings 'cost:0' and 'accuracy:0'
        total_cost = total_cost + c
        total_accuracy = total_accuracy + acc

    print("Test Cost: ", total_cost / num_test_batches)
    print("Test accuracy: ", total_accuracy * 100.0 / num_test_batches, "%")
There are a few subtle changes worth pointing out. First, we create our tf.train.Saver indirectly by importing the computational graph with tf.train.import_meta_graph(). Next, we restore the values to our computational graph with saver.restore() exactly as we had done previously.
saver = tf.train.import_meta_graph('/tmp/vggnet/vgg_net.ckpt.meta')
saver.restore(session, '/tmp/vggnet/vgg_net.ckpt')
Since we don't have direct access to the input and labels nodes, we have to recover them from our graph with graph.get_tensor_by_name(). Notice that we pass in the names we previously specified, appending :0 to each. Some TensorFlow operations produce multiple outputs; when this happens, TensorFlow names them :0, :1 and so on until every output has a unique name. All of the operations we're using have only one output, so we simply stick with :0.
input = graph.get_tensor_by_name("input:0")
labels = graph.get_tensor_by_name("labels:0")
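To see the multi-output case in action, consider tf.nn.top_k, which returns both the top values and their indices. A minimal sketch (the name top3 is our own choice):

#tf.nn.top_k produces two outputs, so its tensors get suffixes :0 and :1
top_values, top_indices = tf.nn.top_k(predictions, k=3, name="top3")
print(top_values.name)  #"top3:0" -- the first output
print(top_indices.name) #"top3:1" -- the second output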
Finally, the last change involves actually running the network. As in the previous step, we need to specify the proper names for cost and accuracy because we don't have direct access to those nodes either. Fortunately, we can simply pass in the strings 'cost:0' and 'accuracy:0' to specify which operations we want to run and return the values of. Alternatively, we could have recovered the nodes with graph.get_tensor_by_name() and passed them in directly, as sketched after the snippet below.
c, acc = session.run(['cost:0', 'accuracy:0'], feed_dict=feed_dict) #Note: We pass in strings 'cost:0' and 'accuracy:0'
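The equivalent lookup-based version would look something like this:

#Recover the nodes by name, then pass the tensors to session.run() directly
cost = graph.get_tensor_by_name("cost:0")
accuracy = graph.get_tensor_by_name("accuracy:0")
c, acc = session.run([cost, accuracy], feed_dict=feed_dict)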
Also note that if we had named our optimizer, we could have passed it into session.run() and continued to train our network. We could even have created a checkpoint of the restored network at this point if we decided it had improved in some way.
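A minimal sketch of what resuming training might look like, assuming the training op had been created as tf.train.GradientDescentOptimizer(0.01).minimize(cost, name="optimizer"). Our network above didn't name its optimizer, so the name here is hypothetical:

#Hypothetical: recover a named training op from the imported graph
optimizer = graph.get_operation_by_name("optimizer")
#Each run applies one more gradient update to the restored parameters
session.run(optimizer, feed_dict=feed_dict)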
There are a variety of ways to save and restore models, and we've really only scratched the surface. Below are a few self-contained examples of the various approaches we've looked at:
Comments

Thanks for your great tutorial. I just came from MATLAB, and I trained the Xception model from scratch on my own dataset to get a pre-trained network for my images, so I have the (checkpoint, .ckpt, .meta) files in a log folder. My question is: how do I extract features of my data using these files? I already restore the model as:

import tensorflow as tf
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-6353.meta')
    saver.restore(sess, tf.train.latest_checkpoint(r'C:\Users…'))

After this step I don't know how to proceed. My dataset has nine classes.
I think you can either get access to tensors via something like:

input = graph.get_tensor_by_name("input:0")

or, if you don't know the names of your tensors, you could probably try:

allKeys = graph.get_all_collection_keys()
print(allKeys)

Then I think you should be able to use get_tensor_by_name() with the appropriate key. I actually haven't done this myself, but I'm going off of: https://www.tensorflow.org/api_docs/python/tf/Graph#get_all_collection_keys
I still can't figure out the solution. Could you provide some portion of code?