
PyTorch: Tackle Sparse Gradient Issues with Multi-Processing

Let’s imagine that we want to train a model that uses an embedding layer for a very large vocabulary. In online mode, we only work with very few words per sample, which means we get a sparse gradient, because most of the words do not need to be touched in any way[1].
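
To make this concrete, here is a minimal sketch of how such a sparse gradient shows up in PyTorch; the vocabulary size, embedding dimension and word IDs are made up purely for illustration:

import torch
import torch.nn as nn

# hypothetical sizes, just for illustration
vocab_size, embed_dim = 1000000, 64

# sparse=True makes the backward pass produce a sparse gradient that only
# touches the rows of the weight matrix that appear in the batch
embedding = nn.Embedding(vocab_size, embed_dim, sparse=True)

word_ids = torch.tensor([3, 17, 42])  # only a handful of words per sample
loss = embedding(word_ids).sum()
loss.backward()

print(embedding.weight.grad.is_sparse)  # True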

The good news is that sparse gradients already work with vanilla gradient descent and Adagrad. However, since the latter eventually decays the learning rate towards zero, we might run into a problem if we need to visit a lot of samples to achieve a good score. This might be the case for our recent model that uses a triplet loss, because not every update has the same impact and using more recent gradient information would be more efficient. Thus, we decided *not* to use Adagrad.

As a result, we can only use stochastic gradient descent. There is nothing wrong with this optimizer, but it can take a lot of time to converge. We could address part of the problem with momentum, since it accelerates learning when the gradient keeps pointing in the same direction, but it turned out that enabling momentum turns the sparse update into a dense one, and that means we lose our only computational advantage.
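
Continuing the sketch above, plain SGD with the default momentum of zero is the setting that keeps the update sparse; the toy sizes are again just for illustration:

import torch
import torch.nn as nn

embedding = nn.Embedding(1000, 16, sparse=True)  # toy sizes for illustration
loss = embedding(torch.tensor([1, 2, 3])).sum()
loss.backward()

# plain SGD (momentum defaults to 0) applies the sparse gradient as-is;
# setting momentum > 0 would turn the update into a dense one
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)
optimizer.step()
optimizer.zero_grad()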

Again, the issue is likely to be relevant for other frameworks as well, and in general sparse operations always seem to lag a little behind their dense counterparts. Since PyTorch is very young, we don’t mind challenges, as noted before.

We would really like to continue using it for our current experiment, but we also need to ensure that the optimizer is not a bottleneck in terms of training time. The good news is that the results we got so far confirm that the model is learning a useful representation. But with the problem of the long tail, we might need more time to perform a sufficiently large number of updates for the rare words.

So, in this particular case we do not have many options, but a single good one would suffice, and indeed there seems to be a way out: asynchronous training with multiprocessing. We still need to investigate more details, but PyTorch provides a drop-in replacement for the “multiprocessing” module, and a quick & dirty example already seems to work.

The idea is to create the model as usual, with the only exception that we call “model.share_memory()” to properly share the model parameters with fork(). Then we spawn N new processes that each get a copy of the model, with tied parameters, but each with its own optimizer and data sampler. In other words, we perform N independent trainings, but all processes update the same model parameters. The example code provided by PyTorch is pretty much runnable out of the box:


## borrowed from PyTorch ##
import torch.multiprocessing as mp

model = TheModel()
model.share_memory()  # required so all workers update the same parameters

# train: function that takes "model" as its only argument
procs = []
for _ in range(4):  # number of worker processes
    p = mp.Process(target=train, args=(model,))
    p.start()
    procs.append(p)

for p in procs:  # wait until all trainers are done
    p.join()

The training function ‘train’ does not require any special code; put differently, if you call it just once, single-threaded, it works as usual. As noted before, there are likely some issues that need to be taken care of, but a first test seemed to work without any problems.
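
For reference, a hypothetical ‘train’ function could look like the sketch below; ‘get_batches’ and ‘loss_fn’ are placeholders for our own data sampler and loss function, not part of any PyTorch API:

import torch.optim as optim

def train(model):
    # each worker builds its own optimizer and data sampler, but since the
    # parameters are shared, every step updates the same model
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    for batch, target in get_batches():  # placeholder data sampler
        optimizer.zero_grad()
        loss = loss_fn(model(batch), target)  # placeholder loss function
        loss.backward()
        optimizer.step()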

Bottom line, the parallel training should help to increase the throughput and perform more updates, which, hopefully, leads to an earlier convergence of the model. We still need to investigate the quality of the trained model, along with some hyper-parameters, like the number of processes, but we are confident that we will find a way to make it work.

[1]
