With residual networks, batch normalization and sophisticated recurrent nets like LSTMs, it seems that getting a network to “fly” is just a matter of data and time. It probably feels like that because most papers do not mention the effort the authors had to put in before the network actually started learning. We have no evidence for this, but we have a hunch that a successful setup can take a lot of time, mostly to figure out the hyper-parameters and to tune the network architecture. But who knows, we could also be totally wrong.
What is definitely useful is to visualize as many aspects of the learning as possible. This includes keeping track of statistics like the gradient norm, the magnitudes of activations, the accuracy and the loss, among many other things. One excellent example is TensorBoard, which allows you to visualize metrics during learning. But sometimes it also helps a lot to visualize the “flow” of the gradient through the network, to see whether the training is healthy or not.
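As a rough illustration, logging such statistics with PyTorch’s TensorBoard writer could look like the following sketch; net, train_step and num_steps are placeholders for your own model and training loop, not part of our actual code:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()              # writes event files to ./runs by default
for step in range(num_steps):         # placeholder training loop
    loss = train_step()               # placeholder: one optimizer step, returns the loss value
    writer.add_scalar("train/loss", loss, step)
    for name, param in net.named_parameters():
        if param.grad is not None:
            writer.add_histogram("grad/" + name, param.grad, step)
writer.close()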
In our case, we tried to train a network that contains two LSTMs, but despite the mitigation of vanishing gradients that LSTMs should provide, there was no real flow in the network and thus the loss did not decrease as expected. For each learning step, we dumped the norm of the whole gradient, but that number does not tell us anything about how the gradient flows through the network layer by layer. So we decided to plot the absolute mean of the gradient for each layer, excluding bias values, to debug which layer is responsible for “blocking” the backward flow of the error signal.
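For reference, a minimal sketch of such a per-step global gradient norm (assuming the model is called net, as in the script below) could look like this:

total_norm = 0.0
for param in net.parameters():
    if param.grad is not None:
        total_norm += param.grad.norm(2).item() ** 2
total_norm = total_norm ** 0.5          # L2 norm over all parameters
print("global grad norm: %.6f" % total_norm)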
We used the following script:
import matplotlib.pyplot as plt

lay_grads, lay_names = [], []
for name, param in net.named_parameters():
    # skip frozen parameters and bias terms
    if param.requires_grad and 'bias' not in name:
        name = name.replace('.weight', '')
        mean = float(param.grad.abs().mean())
        print(" ++ %s=%.6f" % (name, mean))
        lay_grads.append(mean)
        lay_names.append(name)

plt.bar(range(len(lay_grads)), lay_grads, align='center', alpha=0.3)
plt.hlines(0, 0, len(lay_grads) + 1, linewidth=1, color="k")
plt.xticks(range(0, len(lay_grads), 1), lay_names, rotation=90)
plt.xlim(left=-1, right=len(lay_grads) + 1)
plt.ylabel("avg grad"); plt.title("gradient flow"); plt.grid(True); plt.tight_layout()
plt.show()
We borrowed ideas from a post in the PyTorch discussion forum, with some adjustments. For instance, without the tight_layout step, not all labels on the x-axis were visible, and we used the bar style instead of a line plot.
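To give an idea of where this fits into training, here is a hypothetical sketch that assumes the script above is wrapped into a helper called plot_grad_flow(net); loader, criterion and optimizer are placeholders for your own data loader, loss and optimizer:

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()        # gradients are populated here
    plot_grad_flow(net)    # inspect per-layer mean |grad| before the update
    optimizer.step()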
It is hard to come up with heuristics for absolute values, but if a bar in the plot stays at almost zero for several iterations, it is a good indicator that something is wrong: if the gradient is close to zero, the weights of that particular layer won’t get updated much, and thus it is likely that the loss won’t go down, because the gradient is “blocked”.
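If you prefer an automated check over eyeballing the plot, a rough sketch of such a heuristic could look like the following; the threshold eps and the window size are arbitrary choices for illustration, not values we tuned:

from collections import defaultdict

history = defaultdict(list)    # layer name -> mean |grad| values over recent steps

def blocked_layers(net, window=10, eps=1e-7):
    """Return layers whose mean absolute gradient stayed below eps for `window` steps."""
    blocked = []
    for name, param in net.named_parameters():
        if param.requires_grad and 'bias' not in name and param.grad is not None:
            history[name].append(float(param.grad.abs().mean()))
            recent = history[name][-window:]
            if len(recent) == window and max(recent) < eps:
                blocked.append(name)
    return blocked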
Even though the visualization does not offer any concrete solutions, it at least gives you a clue where to start debugging, which can help a lot if you have a 24-layer network.