In the post about rare words, we discussed the problem that low-frequency values/tokens are often extremely valuable but cannot be (fully) used for training. To address the problem, we could learn a separate feature space for those values that is then fused with the ordinary feature space, or we could use some kind of memory. In the case of neural networks, there is an implicit memory that is trained via the loss function, but there is no easy way to control this memory beyond predicting the correct output.
One idea is to add an explicit addressable memory to the network, as in memory networks [arxiv:1503.08895]. A problem with some of these methods, however, is that the memory is restricted to so-called episodes that are bound to a single input sample. In other words, it is not clear how to use the memory to store general information that is intended to generalize beyond single training examples. In a recent paper submitted to ICLR 2017 [“Learning To Remember Rare Events”], the authors present a method to augment a network with a memory that is especially suited to handle “events” that are rarely present in the data. In other words, the memory is able to store both long-term and short-term patterns.
For quite some time, we have tinkered with the idea of adding a memory module to our networks to allow the model to evolve over time and to remember the little things. Especially for preference-based learning, such a component is very valuable, since there are very likely drifts or repeating patterns tied to temporal attributes like Christmas, summer or weekends. The problem is that some of these events do not happen very often, but they are still extremely important. Stated differently, the idea is to have a model that is capable of “life-long” learning, where new memories can be formed and older memories are eventually overwritten.
Let’s start with a brief overview of the method. The memory M consists of three parts: the key storage K, the values V and the age A of each entry. Let’s assume that the values V are positive integers (class labels), which means we query the memory to get the best matching label. The notation looks like this:
K: shape=(size_of_mem, key_size) # list of vectors
V: shape=(size_of_mem, 1) # integer labels
A: shape=(size_of_mem, 1) # age as integers
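In NumPy, such a memory could be set up as follows; the slot count and key size are arbitrary values chosen for illustration, not numbers from the paper:

```python
import numpy as np

# A minimal sketch of the memory layout described above.
size_of_mem, key_size = 128, 16

K = np.random.randn(size_of_mem, key_size)     # keys: one vector per slot
K /= np.linalg.norm(K, axis=1, keepdims=True)  # keep keys l2-normalized
V = np.zeros(size_of_mem, dtype=np.int64)      # values: integer labels
A = np.zeros(size_of_mem, dtype=np.int64)      # ages: steps since last update
```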
The query is a vector with key_size dimensions. As in the paper, each vector is l2-normalized so that the dot product equals the cosine similarity in [-1, +1].
To retrieve the label that is best matching for a query, we perform a nearest neighbor search:
q = raw_query / np.sqrt(np.sum(raw_query**2))
idx = np.argmax(np.dot(q, K.T))
best_label = V[idx]
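Wrapped up as a self-contained function, the lookup looks like this (the toy memory with two slots is made up for illustration):

```python
import numpy as np

def query_memory(raw_query, K, V):
    """Return the label of the nearest memory key (cosine similarity)."""
    q = raw_query / np.linalg.norm(raw_query)  # l2-normalize the query
    idx = np.argmax(K @ q)                     # nearest neighbor over all slots
    return V[idx]

# Toy memory with two (already normalized) keys pointing in opposite directions.
K = np.array([[1.0, 0.0], [-1.0, 0.0]])
V = np.array([3, 7])
print(query_memory(np.array([2.0, 0.5]), K, V))  # closest to the first key -> 3
```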
To enhance an ordinary neural network with the memory module, we remove the softmax output layer and use the l2-normalized representation of the new last layer, usually a fully connected one, as the query to the memory. The aim is to map a query to the correct class label of the input sample. As stated in the paper, the whole approach is fully differentiable except for the nearest neighbor search.
At the beginning, the memory is empty and the network parameters are just random values. We start by choosing a random sample x with label y. To convert x into a query, we forward-propagate it through the network.
(1) q = fprop(x); q' = l2_norm(q)
Next, we determine the nearest neighbor of the query q’ with respect to the memory.
(2) n1 = nearest_neighbor(q', M)
In case V[n1] already contains the correct label y, we update the memory slot with the (normalized) average of both keys.
(3.a) K[n1] = l2_norm(q' + K[n1])
Otherwise, we need to find a new slot n’ for the query, which is usually done by selecting the oldest entry, combined with some randomness.
(3.b) n' = argmax(A + rand); K[n'] = q'; V[n'] = y
The age of all slots is then increased by one:
(4) A += 1
The age of the updated slot is then reset to zero.
(5.a) A[n1] = 0
(5.b) A[n'] = 0
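The update steps (2)–(5) can be sketched as follows. This is a simplification of the paper’s procedure; in particular, the uniform noise used to break ties among old slots is an assumption:

```python
import numpy as np

def update_memory(q, y, K, V, A, rng=None):
    """One memory update for an l2-normalized query q with true label y."""
    if rng is None:
        rng = np.random.default_rng()
    n1 = np.argmax(K @ q)                        # (2) nearest neighbor
    if V[n1] == y:
        merged = q + K[n1]                       # (3.a) average query and key ...
        K[n1] = merged / np.linalg.norm(merged)  # ... and re-normalize
        updated = n1
    else:
        # (3.b) pick the oldest slot, with a little noise to break ties
        updated = np.argmax(A + rng.uniform(0, 1, size=A.shape))
        K[updated] = q
        V[updated] = y
    A += 1                                       # (4) all slots get older ...
    A[updated] = 0                               # (5) ... except the touched one
    return K, V, A
```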
The last question is how to optimize the network parameters to learn a useful memory for the classification task.
We have a query q’, the correct label y, and the nearest neighbor of q’ should return this label. Thus, we compute the k nearest neighbors and find an index n_pos such that V[n_pos] == y and an index n_neg such that V[n_neg] != y. The loss is then defined as:
loss = np.maximum(0, np.dot(q', K[n_neg]) - np.dot(q', K[n_pos]) + margin)
Stated differently, the cosine similarity f(q’, pos) should be higher than f(q’, neg) by the given margin: we want to maximize the similarity for the positive key and minimize it for the negative key.
As a quick reminder, the memory content K is not a trainable parameter, but since the content is derived from the queries, we can shape the memory through the loss function.
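A hedged sketch of this loss computation; how ties are broken within the k nearest neighbors, and what happens when no positive or negative neighbor is found among them, are assumptions on our part:

```python
import numpy as np

def memory_loss(q, y, K, V, k=8, margin=0.1):
    """Triplet-style loss: the positive key should beat the negative by margin."""
    sims = K @ q                     # cosine similarities (keys are normalized)
    knn = np.argsort(-sims)[:k]      # indices of the k nearest slots
    pos = [i for i in knn if V[i] == y]
    neg = [i for i in knn if V[i] != y]
    if not pos or not neg:
        return 0.0                   # assumption: skip degenerate cases
    return max(0.0, sims[neg[0]] - sims[pos[0]] + margin)
```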
To train a good model, one usually needs a sufficiently large training set that provides enough data per tag/label, and it is quite common that the performance for rare tags is pretty bad. With a memory module, a network can encode the limited information that is available and retrieve it later, without the risk that the whole pattern is forgotten during training because it only marginally influences the loss function.
Let’s illustrate the behavior with an example. We start with an empty memory and draw samples from the training set. A rare tag will be drawn with very low frequency, but if there is a pattern in the input data, the chance is good that its items are mapped to the same memory entry, or at least to a small group of entries. As training continues, the memory entries for more frequent tags will be updated, but hopefully the entries assigned to the rare tag won’t be overwritten. The problem is that the age of all non-updated entries is increased at each step, and thus entries of rare events are more likely candidates to be overwritten. The probability that this happens depends both on the memory size and on the number of distinct patterns in the training data. However, since each epoch visits all training samples, the rare events also get updated eventually, and hopefully the pattern is strong enough that the existing entry is updated rather than a new one being chosen.
At the end of the training, we have hopefully learned a memory that correctly predicts tags for new data, both for the frequent tags and for the rare ones for which only limited data was available.
Depending on the actual goal, a static network might suffice. For instance, for some classification tasks, re-training might only be required once or twice a year. For other tasks, however, it is imperative that a network quickly reacts to drifts in the input patterns and adjusts its output. But in contrast to pure on-line learning, the network should not forget old patterns (too) quickly; it should remember them for a while, because re-learning them from scratch when they occur again is very inefficient. Therefore, a memory can also help to store several less frequent patterns, like those of rare events, which is essential to explain some preferences and to deliver good performance.