Deepak Subburam | Deep learning
Research notes

posted by Deepak on Jun 16, 2016

Pretty much everyone in the field has had a shot at training some artificial neural net on the MNIST database of handwritten digits: 60,000 training samples (plus 10,000 test samples) of 28x28-pixel grayscale images. The canonical task with this dataset is to predict the class label associated with each sample, one of {0, 1, ..., 9}. Here, I present something a bit more novel: a recurrent neural net (RNN) that learns an autoencoding of the MNIST data, and is thereby able to complete a truncated version of the input.

RNNs operate on sequential input, so I convert each MNIST image into a sequence of 28 inputs, each input being one row of 28 pixel values. The RNN state (i.e. the hidden values) is reset to 0 before processing each image, and is otherwise updated and carried over from one row to the next. The output of the RNN has the same dimension as the input (28) and is interpreted as the predicted next row of the input. So, when an MNIST image is processed, we have a sequence of rows $x_0, \ldots, x_{27}$ as input and a sequence of rows $y_0, \ldots, y_{27}$ as output, where $y_i$ is the prediction for $x_{i+1}$. The cost minimized during training is then the discrepancy between $x_1, \ldots, x_{27}$ and $y_0, \ldots, y_{26}$; I used the cross-entropy loss function here.
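A minimal NumPy sketch of the data layout and the shifted loss described above. It assumes pixel values scaled to [0, 1], sigmoid outputs, and per-pixel binary cross-entropy averaged over all target pixels; the post says only "cross entropy", so the exact form and normalisation are assumptions, and the function names are hypothetical.

```python
import numpy as np

def image_to_sequence(image):
    """Turn one MNIST image (28x28 or a flat 784-vector) into 28 rows of 28 pixels."""
    return np.asarray(image, dtype=np.float64).reshape(28, 28)

def next_row_cross_entropy(x, y, eps=1e-7):
    """Cross-entropy between targets x_1..x_27 and predictions y_0..y_26.

    x: (28, 28) input rows, pixel values assumed to lie in [0, 1].
    y: (28, 28) RNN outputs, assumed to be sigmoid probabilities in (0, 1).
    """
    targets = x[1:]                           # rows x_1 ... x_27
    preds = np.clip(y[:-1], eps, 1.0 - eps)   # rows y_0 ... y_26, clipped to avoid log(0)
    ce = -(targets * np.log(preds) + (1.0 - targets) * np.log(1.0 - preds))
    return ce.mean()                          # averaged over all 27 x 28 target pixels
```

Note that the last output row $y_{27}$ is not scored, since there is no $x_{28}$ for it to predict.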

Input/output image comparison

The above image shows test inputs and the corresponding RNN outputs after training. The RNN had 169 nodes (13x13) in a single hidden layer and used tied weights between the output and input layers; that is, Why (the hidden-to-output matrix) is set to Wxh.T (the transpose of the input-to-hidden matrix). The other parameters trained are the usual hidden-to-hidden matrix Whh and the input and output bias vectors b and c. The results shown are after 16 epochs of training with the adaptive gradient optimizer implemented in TensorFlow. They look good. Still, there may be a nagging suspicion that the network has merely learned something close to the identity function. So let's give it a real test.

Truncated input/extended output image comparison

The image above shows results when the trained RNN is fed only the first half (14 rows) of each test image and is asked to extrapolate the rest. Extrapolation works by feeding the RNN its own predictions as successive rows of input: $x_{i+1}$ is set to $y_i$, and we proceed until we have 28 rows. We now have to conclude that the network did learn actual patterns in the data and is able to complete a half-seen pattern. Of course, it makes mistakes. But when we look at the corresponding test inputs, we find that there is true ambiguity.
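The extrapolation loop can be sketched as follows. The function name complete_image and the step(x_t, h) -> (y_t, h) interface are hypothetical; any single-step RNN update that returns the predicted next row and the new state can be plugged in.

```python
import numpy as np

def complete_image(step, h0, seen_rows, n_rows=28):
    """Extend an image from its first rows by feeding predictions back in as input.

    step(x_t, h) -> (y_t, h_new): hypothetical single-step RNN update.
    h0: initial hidden state (zeros for a freshly reset RNN).
    seen_rows: the visible rows, e.g. shape (14, 28) for half an image.
    """
    rows = [np.asarray(r, dtype=float) for r in seen_rows]
    if not rows:
        raise ValueError("need at least one visible row to start from")
    h = h0
    # Consume the visible rows; after the loop, y predicts the first missing row.
    for x_t in rows:
        y, h = step(x_t, h)
    # Generate: x_{i+1} := y_i, until the image has n_rows rows.
    # (The final step call computes one prediction that is never used.)
    while len(rows) < n_rows:
        rows.append(y)
        y, h = step(y, h)
    return np.stack(rows)
```

With an identity-like step (each row predicted as a copy of the current one), every generated row simply repeats the last visible row, which is why completing a half image is a stronger test than reconstructing a whole one.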

You can reproduce some of the above results by running the command below, after installing my comprehend package from GitHub.

$ python test.py --model RNN --epochs 16 --hidden 169 --batch 20 \
                             --mosaic --verbose