PyTorch – Batching With Recurrent Nets

Implementing RNNs in PyTorch works like a charm thanks to the dynamic computation graph. All you have to do is write a loop that feeds the input to the network step by step and keeps track of the new hidden state. At the end, you can feed the last state (or the average of all states) into a new layer, which can be a classifier or whatever the loss function at hand requires. This was the easy part. However, working with RNNs in single-batch mode is incredibly inefficient when you need to train on a very large dataset. The problem is the sequential nature of RNNs, which does not allow the time steps of a sequence to be processed in parallel. With mini batches, we can at least exploit hardware parallelism across sequences to speed up the pipeline, and we may also get a more stable gradient because it is estimated from multiple inputs.
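For reference, the single-sample loop described above could look like the following minimal sketch. It assumes a token-level sequence classifier; the sizes and the `classify_one` helper are hypothetical, and `nn.GRUCell` is used only to make the loop over time steps explicit.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical sizes chosen for illustration only.
n_vocab, n_dim, n_units, n_classes = 10, 5, 8, 2

emb = nn.Embedding(n_vocab, n_dim)
cell = nn.GRUCell(n_dim, n_units)
classifier = nn.Linear(n_units, n_classes)

def classify_one(tokens: torch.Tensor) -> torch.Tensor:
    """Run one unbatched sequence of token ids through the RNN step by step."""
    h = torch.zeros(1, n_units)           # initial hidden state (batch of 1)
    for t in range(tokens.shape[0]):      # explicit loop over time steps
        x_t = emb(tokens[t:t + 1])        # embedding of token t, shape (1, n_dim)
        h = cell(x_t, h)                  # new hidden state, shape (1, n_units)
    return classifier(h)                  # logits computed from the last state

logits = classify_one(torch.tensor([1, 4, 2, 7]))
print(logits.shape)  # torch.Size([1, 2])
```

Each call processes exactly one sequence, one time step after the other, which is why this approach does not scale to large datasets.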

What is astonishing is that PyTorch provides functionality to help you with this issue, but there is no tutorial or example code that contains all the steps. Sure, there are blogs and snippets on the web that explain it, but a stand-alone, fully working example often makes it easier to retrace the whole process. Indeed, once you know all the details, it is fairly simple to implement, since the PyTorch team did a very good job of hiding the nasty details from users.

So, let’s start by describing the actual problem: during training, an RNN has to deal with sequences of different lengths, which is no problem in single-batch mode. However, if you want to use batching, the first step is to pad all samples to the same length. This can be done by reserving an extra “dummy” entry (“padding_idx”) in the nn.Embedding module and appending it to the end of each input until all inputs in the batch have the same length. But that is only the first step, since the RNN must also ignore all those padded tokens for each input sequence when it computes the hidden states and the gradient w.r.t. the loss function.
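Why padding alone is not enough can be checked with a small sketch: an `nn.GRU` that simply runs over zero-padded steps keeps updating its hidden state, so the final state no longer corresponds to the real end of the sequence. The sizes and random inputs below are arbitrary assumptions for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
gru = nn.GRU(input_size=5, hidden_size=10, batch_first=True)

seq = torch.randn(1, 2, 5)                                # one sequence of length 2
padded = torch.cat([seq, torch.zeros(1, 6, 5)], dim=1)    # zero-padded to length 8

_, h_true = gru(seq)       # final state after the 2 real steps
_, h_padded = gru(padded)  # final state after 2 real + 6 padding steps

# The padding steps still change the hidden state (biases and recurrent
# weights act even on all-zero inputs), so the two states differ.
print(torch.allclose(h_true, h_padded))  # False
```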

This sounds a bit complicated because we have to fiddle with the computational graph, but fortunately there are helper functions that save you from getting your hands dirty. But let us start at the beginning. Assume that we have an input X = [A, B, C] with the sequence lengths X_len = [4, 2, 8]. First, we pad each sequence to a uniform length. We also sort the batch by length in decreasing order, because pack_padded_sequence, which we use later, expects sorted input by default (enforce_sorted=True):

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
X = [torch.ones(4), torch.ones(2), torch.ones(8)]
X.sort(key=lambda x: x.shape[0], reverse=True)  # longest sequence first
X_pad = pad_sequence(X, batch_first=True, padding_value=0).long()
# tensor([[1, 1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 0]])
X_len = torch.tensor([x.shape[0] for x in X])  # tensor([8, 4, 2])

The option “batch_first” ensures that the batch dimension comes first, i.e. the shape is (batch, seq, feature) instead of the default (seq, batch, feature).
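A quick way to see the effect of the option is to compare the shapes produced by `pad_sequence` with and without it. The sequences are the same dummy tensors as above, already sorted by length.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

X = [torch.ones(8), torch.ones(4), torch.ones(2)]
first = pad_sequence(X, batch_first=True)    # (batch, seq)
default = pad_sequence(X, batch_first=False) # (seq, batch), the default layout
print(first.shape)    # torch.Size([3, 8])
print(default.shape)  # torch.Size([8, 3])
```

The flag has to be set consistently on pad_sequence, the RNN module, pack_padded_sequence, and pad_packed_sequence; otherwise the batch and time dimensions get mixed up.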

As we can see, each sequence now has a length of 8, with “0”s as padding wherever required. Since we need the unpadded length of each sequence later, we also compute X_len. With X_pad we can already perform a lookup in an nn.Embedding module:

from torch import nn
emb = nn.Embedding(2, 5, padding_idx=0)  # n_vocab, n_dim
X_emb = emb(X_pad)  # shape (batch, seq, n_dim) = (3, 8, 5)
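The role of `padding_idx=0` can be verified directly. This sketch rebuilds the padded batch from above and checks two documented properties of `nn.Embedding`: the row at `padding_idx` is initialized to zeros, and it receives no gradient. The `.sum().backward()` call is only a stand-in for a real loss.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

X = [torch.ones(8), torch.ones(4), torch.ones(2)]
X_pad = pad_sequence(X, batch_first=True).long()

emb = nn.Embedding(2, 5, padding_idx=0)
X_emb = emb(X_pad)
print(X_emb.shape)  # torch.Size([3, 8, 5])

# Padded positions are looked up at index 0, whose vector starts out as zeros ...
print(bool((X_emb[2, 2:] == 0).all()))  # True

# ... and row 0 gets no gradient, so training never moves it away from zero.
X_emb.sum().backward()
print(emb.weight.grad[0])  # tensor([0., 0., 0., 0., 0.])
```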

Now, we are ready to feed the input to the RNN:

# setup network and initialize hidden states to zero
net = nn.GRU(5, 10, batch_first=True) #n_dim, n_units
hidden = torch.zeros(1, X_emb.shape[0], 10)
# pack batch
X_packed = pack_padded_sequence(X_emb, X_len, batch_first=True)
# forward step
out, hidden = net(X_packed, hidden)
# unpack batch
out, _ = pad_packed_sequence(out, batch_first=True)
# for each sequence, retrieve the output at its last real (unpadded) time step
idx = torch.arange(0, len(X_len)).long()
out_final = out[idx, X_len - 1, :]
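As a side note, with packed input the `hidden` tensor returned by `nn.GRU` already contains the state at each sequence's own last time step, so for a single-layer, unidirectional network the indexing step can be replaced by `hidden[-1]`. The following self-contained sketch repeats the pipeline above and checks both facts, assuming the same toy batch.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
X = [torch.ones(8), torch.ones(4), torch.ones(2)]  # already sorted by length
X_pad = pad_sequence(X, batch_first=True).long()
X_len = torch.tensor([x.shape[0] for x in X])

emb = nn.Embedding(2, 5, padding_idx=0)
net = nn.GRU(5, 10, batch_first=True)

X_packed = pack_padded_sequence(emb(X_pad), X_len, batch_first=True)
out, hidden = net(X_packed)                         # h0 defaults to zeros
out, _ = pad_packed_sequence(out, batch_first=True)

out_final = out[torch.arange(len(X_len)), X_len - 1, :]
# With packed input, hidden[-1] already holds each sequence's last real state.
print(torch.allclose(out_final, hidden[-1]))  # True
```

For a bidirectional or multi-layer GRU, `hidden` has one entry per layer and direction, so `hidden[-1]` alone is not the full answer there.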

These steps can easily be wrapped into a class that hides all the nasty details and makes it straightforward to get the output of an arbitrary recurrent network for a batch of (text) sequences.
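One possible shape for such a wrapper is sketched below. The class name `PackedGRUEncoder` is hypothetical, and the sketch swaps manual sorting for the `enforce_sorted=False` option of `pack_padded_sequence` (available since PyTorch 1.1), which sorts internally and returns the hidden state in the original batch order.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

class PackedGRUEncoder(nn.Module):
    """Hypothetical wrapper: list of 1-D LongTensors (token ids, 0 = padding)
    -> final hidden state of each sequence, in the original input order."""

    def __init__(self, n_vocab: int, n_dim: int, n_units: int):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, n_dim, padding_idx=0)
        self.rnn = nn.GRU(n_dim, n_units, batch_first=True)

    def forward(self, seqs: list) -> torch.Tensor:
        lengths = torch.tensor([len(s) for s in seqs])  # lengths must be on the CPU
        x = self.emb(pad_sequence(seqs, batch_first=True, padding_value=0))
        # enforce_sorted=False: PyTorch sorts the batch and undoes the sort itself.
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, hidden = self.rnn(packed)                     # (1, batch, n_units)
        return hidden[-1]                                # (batch, n_units)

torch.manual_seed(0)
enc = PackedGRUEncoder(n_vocab=20, n_dim=5, n_units=10)
batch = [torch.tensor([3, 4, 5, 6]), torch.tensor([7, 8]),
         torch.tensor([1, 2, 9, 9, 9, 9, 9, 9])]
h = enc(batch)
print(h.shape)  # torch.Size([3, 10])
```

Because the output keeps the input order, the labels of the mini batch need no reordering with this variant.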

However, there is a drawback we need to take care of. Suppose we train a classifier and sample a mini batch with the corresponding labels (X, Y): the procedure described above changes the order of X, while Y stays the same. The problem arises because the sorting step is only applied to X, which makes the solution obvious: we also have to sort Y, using the original (unsorted) lengths of X as the key so that both end up in the same order. The following code is not very nifty, but it works:

Y = [-1, 1, -1]
lengths = [4, 2, 8]  # original (unsorted) lengths of A, B, C
Y = [y for y, _ in sorted(zip(Y, lengths), key=lambda p: p[1], reverse=True)]
# Y == [-1, -1, 1]
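A slightly tidier alternative, sketched here under the same toy data, is to compute the sort permutation once and apply it to inputs and labels alike. The names `order`, `X_sorted`, and `Y_sorted` are illustrative only.

```python
import torch

# Hypothetical mini batch: three sequences A, B, C and their labels.
X = [torch.ones(4), torch.ones(2), torch.ones(8)]
Y = [-1, 1, -1]

# Compute the permutation once (longest sequence first) ...
order = sorted(range(len(X)), key=lambda i: X[i].shape[0], reverse=True)

# ... and apply it to inputs and labels, so each (x, y) pair stays aligned.
X_sorted = [X[i] for i in order]
Y_sorted = torch.tensor([Y[i] for i in order])

print(order)              # [2, 0, 1]
print(Y_sorted.tolist())  # [-1, -1, 1]
```

Keeping `order` around also makes it easy to map predictions back to the original batch order later.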

Bottom line: single-batch use of RNNs is a piece of cake, but its performance neither allows you to train bigger networks or use larger datasets, nor is the inference speed sufficient for real-world use. Although PyTorch's pad/pack/unpack scheme is not very complicated, it still takes some time to get used to. But once you have mastered it, the performance gain is substantial and allows you to use RNNs at a much larger scale.

