Deepak Subburam | Deep learning
Research notes

posted by Deepak on Apr 13, 2018

Training RNNs (recurrent neural networks) on long sequences of data is usually done via a technique called truncated backpropagation through time (bptt). Standard backpropagation (backprop), for non-recurrent neural networks, updates (trains) the parameters of the model by taking the derivative of the cost function, applied to the outputs of the network for a batch of training data, w.r.t. the parameters (outermost layer first, then the inner layers one by one using the chain rule, hence backpropagation), and multiplying it by a small negative step-size so the new parameters yield a smaller cost. For RNNs, since the output at any point in a sequence depends on the prior points in the sequence, backprop doesn't stop at the present step but goes all the way back to the start of the sequence (hence bptt). If the sequences are long, we can instead run backprop every k steps, seeding the hidden state of the RNN with its value carried over from the end of the previous k steps (hence truncated).
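
For concreteness, here is a minimal sketch of plain truncated bptt on a single long sequence in PyTorch; the GRU, the linear readout, and the toy dimensions and random data are illustrative assumptions, not anything from the method described below.

import torch
import torch.nn as nn

k = 35                                     # bptt window: backprop every k steps
rnn = nn.GRU(input_size=8, hidden_size=16)
readout = nn.Linear(16, 8)
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(readout.parameters()))

inputs = torch.randn(1000, 1, 8)           # toy sequence: [seq_len, batch=1, features]
targets = torch.randn(1000, 1, 8)

hidden = torch.zeros(1, 1, 16)             # [num_layers, batch, n_hidden]
for start in range(0, inputs.size(0), k):
    optimizer.zero_grad()
    chunk = inputs[start:start + k]
    outputs, hidden = rnn(chunk, hidden)
    loss = nn.functional.mse_loss(readout(outputs), targets[start:start + k])
    loss.backward()
    optimizer.step()
    # Keep the hidden state's value for the next window, but cut the graph so
    # backprop does not reach past this point.
    hidden = hidden.detach()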

The problem is that while we want to train with the outputs of a whole batch at a time (e.g., 100 training data samples at once), in general the batch samples are unlikely to all have the same sequence length. The canonical solution is to add dummy blank entries to the shorter sequences, padding them so that all sequences in the batch end up as long as the longest one, and to compute the cost in a way that the blanks don't matter. See torch.nn.utils.rnn.pack_padded_sequence in PyTorch, for example.
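
For reference, the padding approach looks roughly like the following sketch; the toy sequences and dimensions are mine, only the torch.nn.utils.rnn helpers are actual PyTorch API.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                pad_packed_sequence)

# Three toy sequences of unequal length, longest first (pack_padded_sequence
# expects lengths in decreasing order).
seqs = [torch.randn(n, 8) for n in (5, 3, 2)]
lengths = [s.size(0) for s in seqs]

padded = pad_sequence(seqs)                     # [max_len, batch, features], zero-padded
packed = pack_padded_sequence(padded, lengths)  # the RNN will skip the blanks

rnn = nn.GRU(input_size=8, hidden_size=16)
outputs, hidden = rnn(packed)
outputs, _ = pad_packed_sequence(outputs)       # back to padded form for costing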

I present a more efficient approach: don’t pad sequences; just dynamically swap out sequences as they end, replacing them in their batch slot with the next training data sample. This is tricky but doable in PyTorch, whose imperative approach makes the triple-nested for loop at the heart of the algorithm possible.

Here is the code:

import numpy as np
import torch
import torch.nn as nn


def train_bptt(decoder, contexts, targets, visible_seed, hidden_seed,
               optimizer, batch_size, bptt, cost_fn=nn.PairwiseDistance()):
    """
    Train a nn.Module object using given optimizer and training data sequences,
    using backpropagation through time (bptt), <bptt> sequence steps per
    optimization step.

    Unlike typical implementations, sequences in each batch are not required
    to have the same length, and no <blank> padding is done to enforce that.
    Shorter sequences are swapped out of the batch dynamically, replaced with
    the next sequence in the data. This makes training more efficient.

    Trains one epoch, returns computed loss at each step, as a list.

    decoder:
    nn.Module object that implements a
      .forward(visible, context, hidden)
    method which takes in current visible features, a context vector, and
    current hidden state, and returns the next visible output features.
    i.e. Like a standard RNN which takes in an additional context vector at
    each step.

    contexts, targets:
    lists of sequence data, containing context vectors and target output values
    respectively. contexts[i] must have the same length as targets[i],
    i.e. input sequences must have the same length as output sequences.

    visible_seed, hidden_seed:
    initial feature vector to jumpstart the sequence, and initial hidden state
    for start of the sequence. Of shape [<features>] and
    [<num_layers>, <n_hidden>] respectively.

    optimizer:
    a torch.optim object, e.g. initialized by torch.optim.Adam(...).

    batch_size, bptt:
    Number of sequences to train concurrently, and number of steps to proceed
    in sequence before calling an optimization step.
    """

    # indices holds sequence index in each slot of the batch.
    indices = np.arange(batch_size)

    # positions holds, for each sequence in the batch, the position where
    # we are currently at.
    positions = [0 for i in range(batch_size)]

    visible = torch.stack([visible_seed for i in range(batch_size)]).unsqueeze(0)
    hiddens = torch.stack([hidden_seed for i in range(batch_size)], 1)

    losses = []
    marker = batch_size
    while marker < len(targets):
        optimizer.zero_grad()
        # The following two lists hold output of the decoder, and
        # the values they are to be costed against.
        predicted, actual = [], []
        for counter in range(bptt):
            inputs = torch.stack([contexts[i][p]
                                  for i, p in zip(indices, positions)])
            outputs, hiddens = decoder(visible, inputs, hiddens)
            predicted.append(outputs[0])
            visible = outputs.clone() # can implement teacher forcing here.
            actual.append(torch.stack([targets[i][p]
                                       for i, p in zip(indices, positions)]))

            for b, index, position in zip(list(range(batch_size)),
                                          indices, positions):
                if len(targets[index]) > position + 1:
                    positions[b] += 1
                else:  # load next sequence
                    marker += 1
                    # we wrap around to the start of the dataset, if some long
                    # sequence near the end of the dataset isn't done yet.
                    indices[b] = marker % len(targets)
                    positions[b] = 0
                    visible[0, b] = visible_seed
                    hiddens[:, b] = hidden_seed

        loss = torch.mean(cost_fn(torch.cat(predicted), torch.cat(actual)))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.)
        optimizer.step()
        losses.append(loss.item())

        # The following frees up memory, discarding computation graph
        # in between optimization steps.
        visible = visible.detach()
        hiddens = hiddens.detach()

    return losses

To summarize the three nested loops that make up the algorithm: the code calls for an optimization step every bptt steps, over the batch_size data samples being processed. Whenever a data sample runs out (its sequence has ended), it gets swapped out for the next one; this check happens at every step. Steps proceed until the very last sequence in the training data ends.[1]
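
As a usage sketch, calling train_bptt might look like the following, with a toy decoder that matches the interface described in the docstring; ToyDecoder, its dimensions, and the random sequences are illustrative assumptions (it reuses the imports from the listing above).

class ToyDecoder(nn.Module):
    """Takes visible features plus a context vector at each step."""
    def __init__(self, n_features=4, n_context=3, n_hidden=16):
        super().__init__()
        self.rnn = nn.GRU(n_features + n_context, n_hidden)
        self.out = nn.Linear(n_hidden, n_features)

    def forward(self, visible, context, hidden):
        # visible: [1, batch, n_features]; context: [batch, n_context]
        step = torch.cat([visible, context.unsqueeze(0)], dim=2)
        output, hidden = self.rnn(step, hidden)
        return self.out(output), hidden

decoder = ToyDecoder()
contexts = [torch.randn(n, 3) for n in (12, 7, 9, 20, 5)]  # unequal lengths
targets = [torch.randn(n, 4) for n in (12, 7, 9, 20, 5)]
losses = train_bptt(decoder, contexts, targets,
                    visible_seed=torch.zeros(4),
                    hidden_seed=torch.zeros(1, 16),
                    optimizer=torch.optim.Adam(decoder.parameters()),
                    batch_size=2, bptt=4)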

This code is available on one of my github pages. Cheers.


[1] So data samples at the start of the training data list do get sampled twice, to keep the batch slots filled until the last sequence is exhausted; shuffle your training data between epochs (which is advisable in any case) to avoid bias.