Part of the series Learn TensorFlow Now
Over the next few posts, we’ll build a neural network that accurately reads handwritten digits. We’ll go step-by-step, starting with the basics of TensorFlow and ending up with one of the best networks in the ILSVRC 2013 image recognition competition.
The MNIST dataset is one of the simplest image datasets and makes for a perfect starting point. It consists of 70,000 images of handwritten digits. Our goal is to build a neural network that can identify the digit in a given image.
- 60,000 images in the training set
- 10,000 images in the test set
- Size: 28×28 (784 pixels)
- 1 channel (i.e. grayscale, not RGB)
To start, we’ll import TensorFlow and our dataset.
TensorFlow makes it easy for us to download the MNIST dataset and save it locally. Our data has been split into a training set on which our network will learn and a test set against which we’ll check how well we’ve done.
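Note: The labels are represented using one-hot encoding, which means each label is a list of ten values with a 1 in the position of the digit and 0 everywhere else:
- 0 is represented by 1 0 0 0 0 0 0 0 0 0
- 1 is represented by 0 1 0 0 0 0 0 0 0 0
- 9 is represented by 0 0 0 0 0 0 0 0 0 1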
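To make the encoding concrete, here is a minimal plain-Python sketch, independent of TensorFlow. The helper names one_hot and from_one_hot are hypothetical, chosen only for illustration.

```python
def one_hot(digit, num_classes=10):
    """Return a list with a 1 at index `digit` and 0 everywhere else."""
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be in [0, {num_classes}), got {digit}")
    vector = [0] * num_classes
    vector[digit] = 1
    return vector


def from_one_hot(vector):
    """Recover the digit by finding the position of the 1."""
    return vector.index(1)


for digit in (0, 1, 9):
    # Prints each digit next to its one-hot encoding, e.g. "9 -> 0 0 0 0 0 0 0 0 0 1"
    print(digit, "->", " ".join(str(v) for v in one_hot(digit)))
```

Decoding is just the reverse lookup: from_one_hot(one_hot(7)) gives back 7.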
Note: By default, the images are represented as arrays of 784 values. Below is a sample of what this might look like for a given image:
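The following plain-Python sketch shows how a flat 784-value array relates to the 28×28 image, assuming the pixels are stored row by row (the pixel at row r, column c sits at index r * 28 + c) and that brighter pixels have larger values. The helpers to_grid and to_flat, and the made-up “image”, are hypothetical and only for illustration.

```python
ROWS, COLS = 28, 28


def to_grid(flat_pixels):
    """Split a flat list of 784 values into 28 rows of 28 values each."""
    if len(flat_pixels) != ROWS * COLS:
        raise ValueError(f"expected {ROWS * COLS} values, got {len(flat_pixels)}")
    return [flat_pixels[r * COLS:(r + 1) * COLS] for r in range(ROWS)]


def to_flat(grid):
    """Join the 28 rows back into a single list of 784 values."""
    return [value for row in grid for value in row]


# Made-up image: all 0.0 (background) except a vertical stroke of 1.0 in column 14.
flat = [0.0] * (ROWS * COLS)
for r in range(4, 24):
    flat[r * COLS + 14] = 1.0

grid = to_grid(flat)
for row in grid:
    # Draw bright pixels as '#' and dark ones as '.'
    print("".join("#" if value > 0.5 else "." for value in row))
```

Printing the grid draws a crude “1”; the network itself never sees this 2-D layout, only the flat list of 784 numbers.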
There are two steps to follow when training our own neural networks with TensorFlow:
- Create a computational graph
- Run data through the graph so our network can learn or make predictions
Creating a Computational Graph
We’ll start by creating the simplest possible computational graph. Notice in the following code that there is nothing that touches the actual MNIST data. We are simply creating a computational graph so that we may later feed our data to it.
For first-time TensorFlow users there’s a lot to unpack in the next few lines, so we’ll take it slow.
Before explaining anything, let’s take a quick look at the network we’ve created. Below are two visualizations of this network at different levels of detail, each telling a slightly different story about what we’ve created.
Left: A functional visualization of our single-layer network. The 784 input values are each multiplied by weights that feed into our ten output nodes.
Right: The graph created by TensorFlow, including nodes that represent our cost.
The first two lines of our code simply define a TensorFlow graph and tell TensorFlow that all the following operations we define should be included in this graph.
Next, we use tf.placeholder to create two “Placeholder” nodes in our graph. These are nodes for which we’ll provide values every time we run our network. Our placeholders are:
- input, which will contain batches of 100 images, each with 784 values
- labels, which will contain batches of 100 labels, each with 10 values
Next, we use tf.Variable to create two new nodes, layer1_weights and layer1_biases. These represent parameters that the network will adjust as we show it more and more examples. To start, we’ve made layer1_weights completely random and layer1_biases all zero. As we learn more about neural networks, we’ll see that these aren’t the best choices, but they’ll work for now.
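The shapes of these parameters follow from the placeholders: each of the 784 input values needs one weight per output node, and each of the ten output nodes gets one bias. Here is a plain-Python sketch of that starting state. Drawing the weights from a standard normal distribution is our assumption about what “completely random” means; the post’s exact choice isn’t shown here.

```python
import random

NUM_PIXELS = 784   # values per input image
NUM_CLASSES = 10   # one output node per digit

random.seed(0)  # fixed seed so the sketch is reproducible

# 784 x 10 matrix: one weight connecting each input pixel to each output node,
# drawn from a standard normal distribution (an assumed choice of "random").
layer1_weights = [[random.gauss(0.0, 1.0) for _ in range(NUM_CLASSES)]
                  for _ in range(NUM_PIXELS)]

# One bias per output node, all starting at zero.
layer1_biases = [0.0] * NUM_CLASSES

print(len(layer1_weights), "x", len(layer1_weights[0]))
print(layer1_biases)
```

With a batch of 100 images (100 × 784), multiplying by this 784 × 10 matrix yields one row of ten scores per image.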
After creating our weights and biases, we combine them with our input: we use tf.matmul to matrix multiply our input by the weights and + to add the bias to the result. Note that + is simply a convenient shorthand for tf.add. We store the result of this operation in logits, and we’ll consider the output node with the highest value to be our network’s prediction for a given example.
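As a rough illustration of what the matrix multiply, the bias addition, and the “highest value wins” rule compute, here is a plain-Python sketch on made-up toy sizes (2 examples, 3 inputs, 4 classes) instead of the real 100 × 784 inputs and 10 outputs. The helper names matmul, add_bias and argmax are hypothetical; this is not TensorFlow code.

```python
def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both as nested lists."""
    k = len(b)
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions do not match")
    m = len(b[0])
    return [[sum(row[i] * b[i][j] for i in range(k)) for j in range(m)]
            for row in a]


def add_bias(matrix, bias):
    """Add the same bias vector to every row (one row per example)."""
    return [[value + b for value, b in zip(row, bias)] for row in matrix]


def argmax(values):
    """Index of the largest value: the predicted class."""
    return max(range(len(values)), key=values.__getitem__)


# Toy sizes: 2 examples with 3 inputs each, and 4 output classes.
inputs = [[1.0, 0.0, 2.0],
          [0.0, 1.0, 1.0]]
weights = [[0.5, -1.0, 0.0, 0.2],
           [0.1, 0.3, 0.7, -0.4],
           [0.2, 0.4, -0.1, 0.9]]
biases = [0.0, 0.1, 0.0, -0.1]

logits = add_bias(matmul(inputs, weights), biases)  # shape: 2 x 4
predictions = [argmax(row) for row in logits]
print(logits)
print(predictions)
```

Each row of logits holds one score per class, and the prediction for that example is simply the index of its largest score.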
Now that we’ve got our predictions, we want to compare them to the labels and determine how far off we were. We’ll do this by taking the softmax of our output and then using cross entropy as our measure of “loss” or cost. We can perform both of these steps using tf.nn.softmax_cross_entropy_with_logits. This gives us a measure of loss for each example in our batch, so we’ll simply take the mean of these as our final cost.
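To see what this loss computes, here is a plain-Python sketch of softmax followed by cross entropy against a one-hot label, averaged over a batch. It is an illustration, not TensorFlow’s implementation: it works with log-probabilities, subtracting the largest logit first, so that very large logits neither overflow nor produce log(0). The function names and the example batch are hypothetical.

```python
import math


def log_softmax(logits):
    """Log-probabilities from raw scores, computed without overflowing exp()."""
    largest = max(logits)
    log_total = math.log(sum(math.exp(z - largest) for z in logits))
    return [z - largest - log_total for z in logits]


def softmax_cross_entropy(logits, label):
    """Cross entropy between a one-hot label and the softmax of the logits."""
    return -sum(y * lp for y, lp in zip(label, log_softmax(logits)))


def mean_cost(batch_logits, batch_labels):
    """Average the per-example losses into a single cost for the batch."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


# Made-up batch of 2 examples with 3 classes.
batch_logits = [[2.0, 0.5, -1.0],   # leans toward class 0, which is correct
                [0.1, 0.2, 0.3]]    # nearly uniform; correct class is 2
batch_labels = [[1, 0, 0],
                [0, 0, 1]]

for z, y in zip(batch_logits, batch_labels):
    print(softmax_cross_entropy(z, y))
print("cost:", mean_cost(batch_logits, batch_labels))
```

The first example, a confident and correct prediction, has a smaller loss than the second, nearly uniform one; the batch cost is the average of the two.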
The final step is to define an optimizer. This creates a node that is responsible for automatically updating the tf.Variables (weights and biases) of our network in an effort to minimize cost. We’re going to use the most basic, “vanilla” optimizer: tf.train.GradientDescentOptimizer. Note that we have to provide a learning_rate to our optimizer. Choosing an appropriate learning rate is one of the difficult parts of training any new network. For now, we’ll use 0.01, a somewhat arbitrary choice that seems to work reasonably well.
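GradientDescentOptimizer works out the gradients for us, but for this particular single-layer model they can be written by hand: the gradient of the mean softmax cross entropy with respect to each logit is (softmax(logits) - label) / batch size, and plain gradient descent moves each parameter by learning_rate times its gradient in the downhill direction. The plain-Python sketch below applies that update to made-up toy data (4 examples, 2 inputs, 2 classes); all names and data are hypothetical, and only the learning rate of 0.01 comes from the post.

```python
import math

learning_rate = 0.01  # the value used in the post


def softmax(logits):
    """Probabilities from raw scores (max subtracted for numerical stability)."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(inputs, weights, biases):
    """logits = inputs . weights + biases, one row per example."""
    return [[sum(x * w_row[j] for x, w_row in zip(row, weights)) + biases[j]
             for j in range(len(biases))]
            for row in inputs]


def mean_cost(logits, labels):
    """Mean softmax cross entropy over the batch, computed in log space."""
    total = 0.0
    for z, y in zip(logits, labels):
        largest = max(z)
        log_total = math.log(sum(math.exp(v - largest) for v in z))
        total -= sum(t * (v - largest - log_total) for t, v in zip(y, z))
    return total / len(logits)


def gradient_descent_step(inputs, labels, weights, biases, learning_rate):
    """Update weights and biases in place by one step of plain gradient descent."""
    n = len(inputs)
    logits = forward(inputs, weights, biases)
    # d(mean cost)/d(logit) = (softmax(logits) - label) / n for each example and class
    d_logits = [[(p - t) / n for p, t in zip(softmax(z), y)]
                for z, y in zip(logits, labels)]
    for i in range(len(weights)):
        for j in range(len(biases)):
            grad = sum(row[i] * d[j] for row, d in zip(inputs, d_logits))
            weights[i][j] -= learning_rate * grad
    for j in range(len(biases)):
        biases[j] -= learning_rate * sum(d[j] for d in d_logits)


# Made-up toy data: 4 examples, 2 inputs each, 2 classes (one-hot labels).
inputs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.2, 0.8]]
labels = [[1, 0], [1, 0], [0, 1], [0, 1]]
weights = [[0.0, 0.0], [0.0, 0.0]]  # 2 inputs x 2 classes
biases = [0.0, 0.0]

costs = [mean_cost(forward(inputs, weights, biases), labels)]
for step in range(100):
    gradient_descent_step(inputs, labels, weights, biases, learning_rate)
    costs.append(mean_cost(forward(inputs, weights, biases), labels))
print(f"cost before: {costs[0]:.4f}, after 100 steps: {costs[-1]:.4f}")
```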
Running our Neural Network
Now that we’ve created the network, it’s time to actually run it. We’ll repeatedly pass batches of 100 images and labels to our network and watch as the cost decreases.
The first line creates a TensorFlow Session for our graph. The session is used to actually run the operations defined in our graph and produce results for us.
The second line initializes all of our tf.Variables. In our example, this means choosing random values for layer1_weights and setting layer1_bias to all zeros.
Next, we create a loop that will run for 1,000 training steps with a batch_size of 100. The first three lines of the loop simply select 100 images and labels at a time. We store these in feed_dict. Note that the keys of this dictionary, input and labels, correspond to the tf.placeholder nodes we defined when creating our graph. The keys must refer to those same placeholder nodes, and every placeholder we use must have a corresponding entry in feed_dict.
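The exact offset arithmetic in the loop isn’t shown here, so the following plain-Python sketch uses one simple batching scheme (an assumption, not necessarily the post’s): walk through the data in order, wrap back to the start, and skip any partial batch left over at the end. The helper get_batch and the stand-in data are hypothetical.

```python
def get_batch(images, labels, step, batch_size):
    """Return the images and labels for this training step.

    Walks through the data in order and wraps back to the start,
    skipping any partial batch left over at the end.
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    num_batches = len(images) // batch_size
    if num_batches == 0:
        raise ValueError("need at least one full batch of data")
    start = (step % num_batches) * batch_size
    return images[start:start + batch_size], labels[start:start + batch_size]


# Stand-in data: 250 "images" represented by their index, with matching labels.
images = list(range(250))
labels = [i % 10 for i in range(250)]

batch_size = 100
for step in range(5):
    batch_images, batch_labels = get_batch(images, labels, step, batch_size)
    # In the post, the real batches become the values in feed_dict.
    print(step, batch_images[0], batch_images[-1], len(batch_images))
```

With 250 examples and a batch size of 100, the steps alternate between indices 0 to 99 and 100 to 199; under this scheme the last 50 examples are never used, which is the trade-off of dropping partial batches.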
Finally, we run the network using session.run, passing in feed_dict. Notice that we also pass in optimizer and cost. This tells TensorFlow to evaluate these nodes and to store the cost from the current run in c. In the next post, we’ll touch more on this method and on how TensorFlow executes operations based on the nodes we supply to it here.
Now that we’ve put it all together, let’s look at the (truncated) output:
Cost: 12.673884
Cost: 11.534428
Cost: 8.510129
Cost: 9.842179
Cost: 11.445622
Cost: 8.554568
Cost: 9.342157
...
Cost: 4.811098
Cost: 4.2431364
Cost: 3.4888883
Cost: 3.8150232
Cost: 4.206609
Cost: 3.2540445
Clearly the cost is going down, but we still have many unanswered questions:
- What is the accuracy of our trained network?
- How do we know when to stop training? Was 1,000 steps enough?
- How can we improve our network?
- How can we see what its predictions actually were?
We’ll explore these questions in the next few posts as we seek to improve our performance.