How do we go about understanding what each hidden unit in an RNN is capturing? We are all familiar with visualizations of the first-hidden-layer weights of a vanilla feed-forward network, such as the header image of this blog. Such a visualization shows which features of an input image activate each hidden unit. Here I present an analogous procedure for RNNs.
For the RNN case, we seek to understand what sequence of inputs drives a hidden unit to activate. Say we have the RNN described in my earlier post, with sequential rows of MNIST images as inputs, each input being a row of 28 pixel values (and the output being the predicted next row of 28 pixel values). For each hidden unit, can we find the sequence of L rows, for some L, say L = 10, that most strongly activates it?
The space of all possible sequences of 10 input rows is very large. Moreover, we mainly care about sequences that actually appear in our data, i.e. that come from the input distribution. It therefore makes sense to sample a sufficient number of L-row sequences from the input data itself and see which of them activate the hidden unit in question. For our sequential MNIST RNN, we can generate these samples by taking each MNIST image in the training set and selecting L consecutive rows, starting from a random row index between 0 and 28 - L.
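A minimal NumPy sketch of this sampling step, assuming the training images are already loaded as a float array of shape (N, 28, 28). The function name `sample_row_windows` is hypothetical and is not taken from the comprehend package. It draws one window per image, with the start row chosen uniformly from 0 to 28 - L inclusive so that every window fits inside the image.

```python
import numpy as np

def sample_row_windows(images, L=10, rng=None):
    """Draw one window of L consecutive rows from each image.

    images: array of shape (N, height, width), e.g. (N, 28, 28) for MNIST.
    Returns an array of shape (N, L, width): N sequences of L rows each.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, height, _ = images.shape
    # Start row is uniform over 0 .. height - L inclusive (upper bound is exclusive).
    starts = rng.integers(0, height - L + 1, size=n)
    # Row indices start, start + 1, ..., start + L - 1 for each image: shape (N, L).
    row_idx = starts[:, None] + np.arange(L)[None, :]
    # Advanced indexing picks those rows from the matching image: shape (N, L, width).
    return images[np.arange(n)[:, None], row_idx]
```

Applied to the training images with L = 10, this yields an (N, 10, 28) array of samples; calling it again with a different seed gives more samples per image if one window each is not enough.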
Now we can construct a feature for each hidden unit in a number of ways. We could simply pick, from the samples generated above, the single sequence that maximally activates the hidden unit and display it as the visualization for that unit. The problem with this approach is that such a maximal sequence may contain multiple features, some of which are irrelevant to the hidden unit. Another approach is to take a weighted average of all the sequences, weighting each one by the hidden unit's activation after that sequence; the irrelevant features then get diluted out. This is the approach I implemented. As weights I used the pre-sigmoid activation, floored at 0.0, to get a sharper average. See the figure below, for L = 10.
There are 169 cells in the grid above, one for each hidden unit. In each cell, the top half is the pre-sigmoid-activation-weighted average L = 10 sequence; in other words, the 10-row-high image the hidden unit likes to see. The bottom half shows the subsequent 10 rows that result when the RNN is rolled forward after that input sequence, being fed its own predicted output after each row. This indicates what the RNN expects to see after the image the hidden unit likes to see.
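A minimal NumPy sketch of how each cell could be computed. It assumes a vanilla RNN with sigmoid hidden units, a sigmoid output layer that predicts the next row, and a zero initial hidden state; the parameter names (`W_x`, `W_h`, `b_h`, `W_y`, `b_y`) and all function names are hypothetical, since the exact architecture is described in the earlier post and this is not the comprehend package's code. The top half of a cell is the average of the sampled sequences weighted by max(a, 0), where a is that unit's pre-sigmoid activation after the last row of a sample; the bottom half comes from rolling the RNN forward from that average image.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def rnn_step(x, h, p):
    """One step of the assumed sigmoid RNN (works batched or unbatched).

    Returns (pre, h_new, y): hidden pre-sigmoid activation, new hidden
    state, and the predicted next row.
    """
    pre = x @ p['W_x'] + h @ p['W_h'] + p['b_h']
    h_new = sigmoid(pre)
    y = sigmoid(h_new @ p['W_y'] + p['b_y'])
    return pre, h_new, y

def final_preactivations(seqs, p):
    """Pre-sigmoid activation of every hidden unit after the last row.

    seqs: (N, L, D) with L >= 1. Returns (N, H).
    """
    h = np.zeros((seqs.shape[0], p['W_h'].shape[0]))  # zero initial state
    for t in range(seqs.shape[1]):
        pre, h, _ = rnn_step(seqs[:, t, :], h, p)
    return pre

def weighted_average_features(seqs, pre):
    """Average the sequences, weighting each by a unit's pre-sigmoid
    activation floored at 0.0. Returns (H, L, D): one image per unit."""
    w = np.maximum(pre, 0.0)                      # floor at 0.0
    totals = w.sum(axis=0)                        # (H,)
    totals = np.where(totals > 0, totals, 1.0)    # never-active unit: all-zero image
    # For each unit h: sum over n of w[n, h] * seqs[n], then normalize.
    return np.einsum('nh,nld->hld', w, seqs) / totals[:, None, None]

def roll_forward(seed, p, n_steps=10):
    """Run the seed rows (L >= 1) through the RNN, then feed each
    predicted row back in as the next input. Returns (n_steps, D)."""
    h = np.zeros(p['W_h'].shape[0])
    for x in seed:
        _, h, y = rnn_step(x, h, p)
    predicted = []
    for _ in range(n_steps):
        predicted.append(y)
        _, h, y = rnn_step(y, h, p)               # prediction becomes next input
    return np.stack(predicted)

def unit_cells(seqs, p, n_steps=10):
    """One cell per hidden unit: weighted-average top half stacked on the
    rolled-forward bottom half. Returns (H, L + n_steps, D)."""
    feats = weighted_average_features(seqs, final_preactivations(seqs, p))
    return np.stack([np.vstack([f, roll_forward(f, p, n_steps)]) for f in feats])
```

Here `seqs` would be the sampled L-row windows; with 169 hidden units, `cells` has shape (169, 20, 28) for L = 10, and can be tiled, e.g. into a 13 x 13 grid, for display.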
We can see that many of the hidden units activate after seeing the top half of the digit 0 and go on to predict the bottom half correctly. The digits 9, 8, 7, 3 and 2 are well represented too. The digit 1 appears to be missing, perhaps because a window of L rows is not discriminative for it. That leaves the digits 4, 5 and 6, suggesting that no single hidden unit by itself has captured the structure of any of these digits.
This visualization has now been implemented in my comprehend package on GitHub. Comments welcome.