Part of the series Learn TensorFlow Now
At the conclusion of the previous post, we realized that our first convolutional net wasn’t performing very well. It had a comparatively high cost (something we hadn’t seen before) and was performing slightly worse than a fully-connected network with the same number of layers.
Test results from 4-layer fully connected network:
Test Cost: 107.98408660641905 Test accuracy: 85.74999994039536 %
Test results from 4-layer Conv Net:
Test Cost: 15083.0833307 Test accuracy: 81.8799999356 %
As a refresher, here’s a visualization of the 4-layer ConvNet we built in the last post:

[Figure: the 4-layer ConvNet, from input through pool2]

So how do we figure out what’s broken?
When writing any typical program we might fire up a debugger or even just use something like printf() to figure out what’s going on. Unfortunately, neural networks make this very difficult for us. We can’t really step through thousands of multiplication, addition and ReLU operations and expect to glean much insight. One common debugging technique is to visualize all of the intermediate outputs and try to see if there are any obvious problems.
Let’s take a look at a histogram of the outputs of each layer before they’re passed through the ReLU non-linearity. (Remember, the ReLU operation simply chops off all negative values).
If you look closely at the above plots you’ll notice that the variance increases substantially at each layer (TensorBoard doesn’t let me adjust the scales of each plot, so it’s not immediately obvious). The majority of outputs at layer1_conv are within the range [-1, 1], but by the time we get to layer4_conv the outputs vary between [-20,000, 20,000]. If we continue adding layers to our network this trend will continue, and eventually our network will run into problems with overflow. In general we’d prefer our intermediate outputs to remain within some fixed range.
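To see why the spread keeps growing, here is a minimal sketch in plain Python (no TensorFlow needed). For simplicity it assumes fully connected ReLU layers of a hypothetical width of 100 instead of our conv layers (a 3x3 convolution is also just a dot product, over 3 * 3 * channels inputs), with weights drawn from a unit normal, which is what `tf.random_normal` produces with its default arguments. `pre_relu_stds` is a hypothetical helper written only for this illustration.

```python
import math
import random

def pre_relu_stds(num_layers, width, weight_scale, seed=0):
    """Push one random input through `num_layers` fully connected ReLU layers and
    return the standard deviation of each layer's pre-ReLU outputs."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    stds = []
    for _ in range(num_layers):
        # Weights ~ N(0, 1) (the tf.random_normal defaults), times weight_scale
        w = [[rng.gauss(0.0, 1.0) * weight_scale for _ in range(width)]
             for _ in range(width)]
        pre = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
        mean = sum(pre) / width
        stds.append(math.sqrt(sum((p - mean) ** 2 for p in pre) / width))
        x = [max(0.0, p) for p in pre]  # ReLU before feeding the next layer
    return stds

stds = pre_relu_stds(num_layers=4, width=100, weight_scale=1.0)
for layer, s in enumerate(stds, start=1):
    print(f"layer{layer}: std of pre-ReLU outputs ~ {s:,.1f}")
```

Each layer multiplies the spread by roughly the same factor, so under these assumptions the growth is exponential in depth rather than linear.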
How does this relate to our high cost? Let’s take a look at the values of our logits and predictions. Recall that these values are calculated via:
# Flatten the last layer (pool2) and multiply it by a set of weights to produce 10 logits
shape = pool2.shape.as_list()
fc = shape[1] * shape[2] * shape[3]  # 7x7x128 = 6,272
reshape = tf.reshape(pool2, [-1, fc])
fc_weights = tf.Variable(tf.random_normal([fc, 10]))  # 6,272×10
fc_bias = tf.Variable(tf.zeros([10]))  # 10
# Logits are ten numbers per image
logits = tf.matmul(reshape, fc_weights) + fc_bias  # 10
# Predictions are ten numbers per image, scaled to add up to 1.00
predictions = tf.nn.softmax(logits)  # 10
The first thing to notice is that, like the previous layers, the values of logits have a large variance, with some values in the hundreds of thousands. The second thing to notice is that once we take the softmax of logits to create predictions, all of our values are reduced to either 1 or 0. Recall that tf.nn.softmax takes logits and ensures that the ten values add up to 1, so that each value represents the probability that a given image shows the corresponding digit. Because softmax exponentiates each logit, what matters is the difference between them: when one logit exceeds the others by thousands, its exponential completely dominates and it receives essentially all of the probability.
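A quick way to see this saturation is to run softmax on logits with this kind of spread. The sketch below uses made-up logits (not values taken from the network) and subtracts the largest logit before exponentiating, which leaves the result unchanged; without that step, `math.exp(20000.0)` would raise an `OverflowError`.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtracting the max logit doesn't change the
    result but keeps math.exp from overflowing on large values."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits with a spread similar to the broken network's
big_logits = [20000.0, 5000.0, -3000.0, 120.0, 0.0, -15000.0, 800.0, 42.0, -7.0, 1.0]
probs = softmax(big_logits)
print(probs)  # [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Even the runner-up logit of 5000 gets a probability of exactly 0.0 in floating point, because its exponential underflows.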
The visualization of predictions tells us that our network is super confident about the predictions it’s making. Essentially, our network is claiming that it is 99% sure of each prediction. Whenever our network makes a mistake, it makes a huge mistake and receives a large cost penalty for it.
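The size of that penalty is easy to estimate. For a one-hot label, the cross-entropy that `tf.nn.softmax_cross_entropy_with_logits` computes equals log(sum(exp(logits))) minus the correct class’s logit, and once one logit dominates that is roughly the gap between the largest logit and the correct one. A minimal sketch with the same kind of made-up logits (`cross_entropy_with_logits` is a hypothetical helper, not the TensorFlow function):

```python
import math

def cross_entropy_with_logits(logits, true_index):
    """Cross-entropy of a one-hot label against softmax(logits), computed as
    logsumexp(logits) - logits[true_index] so large logits don't overflow."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[true_index]

logits = [20000.0, 5000.0, -3000.0, 120.0, 0.0, -15000.0, 800.0, 42.0, -7.0, 1.0]
print(cross_entropy_with_logits(logits, true_index=0))  # 0.0 (confident and right)
print(cross_entropy_with_logits(logits, true_index=1))  # 15000.0 (confident and wrong)
```

A confidently wrong answer costs thousands while a confidently right one costs nothing, which is consistent with an average test cost in the thousands.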
The increasing magnitude of the intermediate outputs translates directly into an increased cost. So how do we fix this? We want to restrict the magnitude of the intermediate outputs of our network so they don’t increase so drastically at each layer.
Smaller Initial Weights
Recall that each convolution operation takes the dot product of our weights with a portion of the input. Basically, we’re multiplying and adding up a bunch of numbers similar to the following:
w0*i0 + w1*i1 + w2*i2 + … + wn*in = output
Where:
- wk – Represents a single weight
- ik – Represents a single input (e.g. a pixel)
- n – The index of the last weight (so there are n + 1 weights)
One way to reduce the magnitude of this expression is to reduce the magnitude of all of our weights by some factor:
0.01*w0*i0 + 0.01*w1*i1 + 0.01*w2*i2 + … + 0.01*wn*in = 0.01*output
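Here is that arithmetic with a handful of hypothetical weights and pixel values, confirming that scaling every weight by 0.01 scales the output by 0.01. `weighted_sum` is a hypothetical helper standing in for a single convolution window.

```python
def weighted_sum(weights, inputs):
    """Compute w0*i0 + w1*i1 + ... + wn*in for one convolution window."""
    return sum(w * i for w, i in zip(weights, inputs))

# Hypothetical weights and pixel values for a tiny 4-element window
weights = [0.5, -1.2, 2.0, 0.3]
inputs = [1.0, 0.25, 0.75, 0.0]

output = weighted_sum(weights, inputs)
scaled_output = weighted_sum([0.01 * w for w in weights], inputs)
print(output, scaled_output)  # second value is 0.01 times the first, up to rounding
```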
Let’s try it and see if it works! We’ll modify the creation of our weights by multiplying them all by 0.01. Therefore layer1_weights would now be defined as:
layer1_weights = tf.Variable(tf.random_normal([3, 3, 1, 64]) * 0.01)
After changing all five sets of weights (don’t forget about the fully connected layer at the end), we can run our network and see the following test cost and accuracy:
Test Cost: 2.3025865221 Test accuracy: 5.01999998465 %
Yikes! The cost has decreased quite a bit, but that accuracy is abysmal… What’s going on this time? Let’s take a look at the intermediate outputs of the network:
If you look closely at the scales, you’ll see that this time the intermediate outputs are decreasing! The first layer’s outputs lie largely within the interval [-0.02, 0.02], while the fourth layer generates outputs that lie within [-0.0002, 0.0002]. This is essentially the opposite of the problem we saw before.
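Why does 0.01 overshoot in the other direction? A back-of-the-envelope model, following He et al.’s analysis of ReLU networks: if the weights are zero-mean with standard deviation σ and a layer has fan_in inputs, the standard deviation of its pre-ReLU outputs gets multiplied by about σ * sqrt(fan_in / 2), the 1/2 coming from ReLU zeroing half of a symmetric input. The sketch plugs in the fan-ins of layers 2 to 4 of our network; it ignores biases, max pooling and layer 1 (which sees raw pixels), so treat the numbers as rough. `per_layer_gain` is a hypothetical helper.

```python
import math

def per_layer_gain(fan_in, weight_std):
    """Rough factor by which the std of pre-ReLU outputs changes in one layer.

    Assumes zero-mean weights with std `weight_std`, zero-mean symmetric
    pre-activations feeding a ReLU, and ignores biases and pooling:
        Var[y_l] = fan_in * Var[w] * Var[y_{l-1}] / 2
    """
    return weight_std * math.sqrt(fan_in / 2.0)

# fan_in of a 3x3 convolution = 3 * 3 * input_channels (layers 2-4 of our network)
fan_ins = {"layer2": 3 * 3 * 64, "layer3": 3 * 3 * 64, "layer4": 3 * 3 * 128}

totals = {}
for weight_std in (1.0, 0.01):
    total = 1.0
    for name, n in fan_ins.items():
        gain = per_layer_gain(n, weight_std)
        total *= gain
        print(f"weight std {weight_std}: {name} (fan_in={n}) gain ~ {gain:.3f}")
    totals[weight_std] = total
    print(f"weight std {weight_std}: layer1 -> layer4 overall ~ x{total:.4g}")
```

Under this model the factor is roughly 17 to 24 per layer for unit-variance weights and below 0.25 for weights scaled by 0.01. It equals 1 only when σ = sqrt(2 / fan_in), which is the idea behind Variance Scaling Initialization.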
Let’s also examine the logits and predictions as we did before:
This time the logits vary over a very small interval, [-0.003, 0.003], and the predictions are nearly uniform. They appear to be centered around 0.10, which seems to indicate that our network is simply predicting each of the ten digits with 10% probability. In other words, our network is learning nothing at all and we’re in an even worse state than before!
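These observations fit together with the test cost. With logits this close to zero, softmax returns roughly 1/10 for every class, and the cross-entropy of a uniform guess is -log(1/10) = ln 10 ≈ 2.3026, which matches the test cost of 2.3025865 reported above. A sketch with made-up logits in the [-0.003, 0.003] range:

```python
import math

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical tiny logits in the range seen above
tiny_logits = [0.003, -0.001, 0.002, -0.003, 0.0, 0.001, -0.002, 0.0025, -0.0005, 0.0015]
probs = softmax(tiny_logits)
print([round(p, 4) for p in probs])  # every value is close to 0.1

# Cross-entropy of a uniform guess over 10 classes: -log(1/10) = log(10)
print(math.log(10))  # 2.302585092994046
```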
Choosing the Perfect Initial Weights
What we’ve learned so far:
- Large initial weights lead to very large outputs in intermediate layers and an over-confident network.
- Small initial weights lead to very small outputs in intermediate layers and a network that doesn’t learn anything.
So how do we choose initial weights that are not too small and not too large? In 2010, Xavier Glorot and Yoshua Bengio published Understanding the difficulty of training deep feedforward neural networks, in which they proposed initializing a set of weights based on how many input and output neurons those weights connect (the layer’s fan-in and fan-out). This initialization scheme is called Xavier Initialization. For more on it, see An Explanation of Xavier Initialization.
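For a sense of the numbers, here is a sketch of the uniform variant of Xavier Initialization, which is what `tf.contrib.layers.xavier_initializer()` uses with its default `uniform=True`: weights drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)), so each weight has variance 2 / (fan_in + fan_out). The fan-in and fan-out are those of our 6,272 x 10 fully connected layer; `xavier_uniform` is a hypothetical helper, not a TensorFlow function.

```python
import math
import random

def xavier_uniform(fan_in, fan_out, rng):
    """Glorot/Xavier uniform initialization.

    Samples U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)),
    which gives each weight variance 2 / (fan_in + fan_out).
    """
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)]
            for _ in range(fan_in)]

fan_in, fan_out = 7 * 7 * 128, 10  # our fully connected layer: 6,272 inputs, 10 outputs
limit = math.sqrt(6.0 / (fan_in + fan_out))
target = 2.0 / (fan_in + fan_out)

w = xavier_uniform(fan_in, fan_out, random.Random(0))
flat = [x for row in w for x in row]
var = sum(x * x for x in flat) / len(flat)  # mean is ~0, so E[w^2] ~ Var[w]
print(f"limit = {limit:.5f}, target variance = {target:.3e}, empirical = {var:.3e}")
```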
It turns out that Xavier Initialization doesn’t work well for layers using the asymmetric ReLU activation function: its derivation assumes the activation function behaves linearly around zero, which isn’t true of ReLU. So while we can use it on our fully connected layer, we shouldn’t use it for our intermediate ReLU layers. However, in 2015 Microsoft Research (Kaiming He et al.) published Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. In this paper they introduced a modified version of Xavier Initialization suited to ReLUs (often called He Initialization), which TensorFlow provides as Variance Scaling Initialization.
The math behind these initialization schemes is out of scope for this post, but TensorFlow makes them easy to use. I recommend simply remembering:
- Use Xavier Initialization in the fully connected layers of your network (or in layers that use softmax or tanh activation functions).
- Use Variance Scaling Initialization in the intermediate layers of your network that use ReLU activation functions.
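To check the recommendation for ReLU layers, here’s a sketch of the scheme from the He et al. paper: zero-mean normal weights with variance 2 / fan_in, the variance that `tf.contrib.layers.variance_scaling_initializer()` targets with its default arguments (`factor=2.0`, `mode='FAN_IN'`). TensorFlow’s exact sampling (such as truncating the normal) is not reproduced here; this uses a plain normal and stand-in fully connected ReLU layers of a hypothetical width of 200 instead of convolutions. `he_normal` and `he_pre_relu_stds` are hypothetical helpers.

```python
import math
import random

def he_normal(fan_in, fan_out, rng):
    """Variance scaling for ReLU layers (He et al.): zero-mean normal weights
    with Var[w] = 2 / fan_in. Plain normal here, no truncation."""
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]

def he_pre_relu_stds(num_layers, width, rng):
    """Std of each layer's pre-ReLU outputs for one random input pushed through
    `num_layers` fully connected ReLU layers initialized with he_normal."""
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    stds = []
    for _ in range(num_layers):
        w = he_normal(width, width, rng)
        pre = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
        mean = sum(pre) / width
        stds.append(math.sqrt(sum((p - mean) ** 2 for p in pre) / width))
        x = [max(0.0, p) for p in pre]  # ReLU
    return stds

stds = he_pre_relu_stds(num_layers=4, width=200, rng=random.Random(0))
print([round(s, 2) for s in stds])
```

With variance 2 / fan_in the standard deviation stays roughly constant from layer to layer instead of growing or shrinking geometrically.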
We can modify the initialization of layer1_weights from tf.random_normal to use tf.contrib.layers.variance_scaling_initializer() as follows:
layer1_weights = tf.get_variable("layer1_weights", [3, 3, 1, 64], initializer=tf.contrib.layers.variance_scaling_initializer())
We can also modify the fully connected layer’s weights to use tf.contrib.layers.xavier_initializer() as follows:
fully_connected_weights = tf.get_variable("fully_connected_weights", [fc, 10], initializer=tf.contrib.layers.xavier_initializer())
There are two small changes to note here. First, we use tf.get_variable instead of calling tf.Variable directly. This allows us to pass in a custom initializer for our weights. Second, we have to provide a unique name for our variable. Typically I just use the same name as the Python variable it’s assigned to.
If we change the rest of the weights in our network the same way and run it, we see the following output:
Cost: 2.49579
Accuracy: 9.00000035763 %
Cost: 1.05762
Accuracy: 77.999997139 %
...
Cost: 0.110656
Accuracy: 94.9999988079 %
Test Cost: 0.0945288215741
Test accuracy: 97.2900004387 %
Much better! This is a big improvement over our previous results and we can see that both cost and accuracy have improved substantially. For the sake of curiosity, let’s look at the intermediate outputs of our network:
This looks much better. The variance of the intermediate values appears to increase only slightly as we move through the layers, and all values are within about an order of magnitude of one another. While we can’t make any claims about the intermediate outputs being “perfect” or even “good”, we can at least rest assured that there are no glaringly obvious problems with them. (Sidenote: This seems to be a common theme in deep learning: we usually can’t prove we’ve done things correctly, we can only look for signs that we’ve done them incorrectly.)
Thoughts on Weights
Hopefully I’ve managed to convince you of the importance of choosing good initial weights for a neural network. Fortunately when it comes to image recognition, there are well-known initialization schemes that pretty much solve this problem for us.
The problems with weight initialization should highlight the fragility of deep neural networks. After all, we would hope that even if we choose poor initial weights, after enough time our gradient descent optimizer would manage to correct them and settle on good values for our weights. Unfortunately that doesn’t seem to be the case, and our optimizer instead settles into a relatively poor local minimum.
Complete Code
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
train_images = np.reshape(mnist.train.images, (-1, 28, 28, 1))
train_labels = mnist.train.labels
test_images = np.reshape(mnist.test.images, (-1, 28, 28, 1))
test_labels = mnist.test.labels

graph = tf.Graph()
with graph.as_default():
    input = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
    labels = tf.placeholder(tf.float32, shape=(None, 10))

    # Conv layers feed ReLUs, so they use Variance Scaling Initialization
    layer1_weights = tf.get_variable("layer1_weights", [3, 3, 1, 64], initializer=tf.contrib.layers.variance_scaling_initializer())
    layer1_bias = tf.Variable(tf.zeros([64]))
    layer1_conv = tf.nn.conv2d(input, filter=layer1_weights, strides=[1,1,1,1], padding='SAME')
    layer1_out = tf.nn.relu(layer1_conv + layer1_bias)

    layer2_weights = tf.get_variable("layer2_weights", [3, 3, 64, 64], initializer=tf.contrib.layers.variance_scaling_initializer())
    layer2_bias = tf.Variable(tf.zeros([64]))
    layer2_conv = tf.nn.conv2d(layer1_out, filter=layer2_weights, strides=[1,1,1,1], padding='SAME')
    layer2_out = tf.nn.relu(layer2_conv + layer2_bias)

    pool1 = tf.nn.max_pool(layer2_out, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')

    layer3_weights = tf.get_variable("layer3_weights", [3, 3, 64, 128], initializer=tf.contrib.layers.variance_scaling_initializer())
    layer3_bias = tf.Variable(tf.zeros([128]))
    layer3_conv = tf.nn.conv2d(pool1, filter=layer3_weights, strides=[1,1,1,1], padding='SAME')
    layer3_out = tf.nn.relu(layer3_conv + layer3_bias)

    layer4_weights = tf.get_variable("layer4_weights", [3, 3, 128, 128], initializer=tf.contrib.layers.variance_scaling_initializer())
    layer4_bias = tf.Variable(tf.zeros([128]))
    layer4_conv = tf.nn.conv2d(layer3_out, filter=layer4_weights, strides=[1,1,1,1], padding='SAME')
    layer4_out = tf.nn.relu(layer4_conv + layer4_bias)

    pool2 = tf.nn.max_pool(layer4_out, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')

    # Flatten pool2 (7x7x128 = 6,272 values per image) for the fully connected layer
    shape = pool2.shape.as_list()
    fc = shape[1] * shape[2] * shape[3]
    reshape = tf.reshape(pool2, [-1, fc])

    # The fully connected layer feeds softmax, so it uses Xavier Initialization
    fully_connected_weights = tf.get_variable("fully_connected_weights", [fc, 10], initializer=tf.contrib.layers.xavier_initializer())
    fully_connected_bias = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(reshape, fully_connected_weights) + fully_connected_bias

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    learning_rate = 0.001
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

    # Add a few nodes to calculate accuracy and optionally retrieve predictions
    predictions = tf.nn.softmax(logits)
    correct_prediction = tf.equal(tf.argmax(labels, 1), tf.argmax(predictions, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()
    num_steps = 5000
    batch_size = 100

    for step in range(num_steps):
        offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
        batch_images = train_images[offset:(offset + batch_size), :]
        batch_labels = train_labels[offset:(offset + batch_size), :]
        feed_dict = {input: batch_images, labels: batch_labels}
        _, c, acc = session.run([optimizer, cost, accuracy], feed_dict=feed_dict)
        if step % 100 == 0:
            print("Cost: ", c)
            print("Accuracy: ", acc * 100.0, "%")

    # Test: walk through the test set in consecutive, non-overlapping batches
    num_test_batches = len(test_images) // batch_size
    total_accuracy = 0
    total_cost = 0
    for step in range(num_test_batches):
        offset = step * batch_size
        batch_images = test_images[offset:(offset + batch_size)]
        batch_labels = test_labels[offset:(offset + batch_size)]
        feed_dict = {input: batch_images, labels: batch_labels}
        c, acc = session.run([cost, accuracy], feed_dict=feed_dict)
        total_cost = total_cost + c
        total_accuracy = total_accuracy + acc

    print("Test Cost: ", total_cost / num_test_batches)
    print("Test accuracy: ", total_accuracy * 100.0 / num_test_batches, "%")