Linear Attention: Different Kernels and Feature Representations

In a recently proposed method [1], self-attention is rephrased as a kernelized similarity function. Besides the fact that the computational complexity is drastically reduced, it is now possible to use feature transformations other than exp(q*k). The paper uses elu(x)+1 for the transformation because the kernel function k(x,y) used here must be positive, as it is a proxy for the similarity between x and y.

What is really nice is that the authors also provide PyTorch code that is pretty much self-contained [2], which makes it easy to test different feature maps. But before we actually start using the new layer in a network, we wanted to better understand the impact of the feature transformation. This can be done without any neural net training, just by looking at some numbers.

We generate five random numbers from N(0, 1): X = [0.56 0.12 0.06 -0.43 0.03]. We can imagine that X is the inner product of some word with all other words in a sequence. Thus, higher positive scores indicate similarity, while negative values indicate the opposite. Then we used three different transformations:
(1) (torch.nn.functional.leaky_relu(X)+0.05)**2 [leaky relu]
(2) torch.nn.functional.elu(X)+1 [elu]
(3) torch.nn.functional.gelu(X)+0.2 [gelu]

Then we applied the feature transformation on X and normalized the scores:

[0.56 0.12 0.06 -0.43 0.03] -- X --
[0.88 0.07 0.03 0.005 0.016] -- leaky_relu
[0.288 0.207 0.196 0.12 0.19 ] -- elu
[0.437 0.194 0.169 0.041 0.158] -- gelu

The intuition is that larger positive values should contribute more while negative ones should get little or almost no weight. The problem with (2) and (3) is that, due to the added shift, inputs in the range [-bias, 0] still contribute a lot to the final weights. This can be seen with (2), where the input values -0.43 / 0.03 are mapped to 0.12 / 0.19 even though the difference between them is very noticeable. In general, the elu transformation leads to a distribution that is close to uniform, which is not very useful since words usually do not contribute equally to a context. With gelu, at least the problem with negative values is handled. The first transformation probably best reflects the assumption that negative and very small inner product scores should not contribute (much) to the attention, although it may put a little too much emphasis on larger positive values.
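
The numbers above can be reproduced with a few lines. This is only a sketch of the toy experiment: it applies each map directly to the five scores, as done in the text, and the names feature_maps and normalized_weights are ours. leaky_relu is used with PyTorch's default negative slope of 0.01, which is an assumption about the original setup.

```python
import torch
import torch.nn.functional as F

# Scores from the text, imagined as inner products of one word with five others.
X = torch.tensor([0.56, 0.12, 0.06, -0.43, 0.03])

# The three transformations compared in the text.
# Assumption: leaky_relu uses PyTorch's default negative slope (0.01).
feature_maps = {
    "leaky_relu": lambda x: (F.leaky_relu(x) + 0.05) ** 2,
    "elu": lambda x: F.elu(x) + 1,
    "gelu": lambda x: F.gelu(x) + 0.2,
}


def normalized_weights(x, phi):
    """Apply a feature map and normalize the result so the weights sum to one."""
    scores = phi(x)
    return scores / scores.sum()


weights = {name: normalized_weights(X, phi) for name, phi in feature_maps.items()}
for name, w in weights.items():
    print(name, [round(v, 3) for v in w.tolist()])
```

Swapping in another candidate map only requires one more entry in the dictionary, which makes it cheap to check how flat or peaked the resulting weights are before training anything.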

Bottom line, the experiments in the paper [1] confirm that the ELU transformation is a reasonable choice, but in our own experiments we noticed some strange behavior for the final representation when transformation (2) is used and thus we stick with (1). However, it would be interesting to know why this happens as this probably leads to a deeper understanding of the transformation step and its implications.

[1] 2006.16236: Fast Autoregressive Transformers with Linear Attention

BERT: The Fallback For Everything?

Without a doubt BERT and variations of it are extremely powerful and successful for natural language processing (NLP) problems. And thanks to its demanding resource requirements, the research is continuing and lots of new ideas are explored. But the theoretical perspective [1,2] and the derived insights are also very valuable because they allow a look into the black box and at why the approach is so successful. A key ingredient is definitely the self-attention mechanism, which has been analyzed and modified a lot since its proposal in mid-2017.

After BERT gained traction, more and more methods and projects used it as a building block to solve various problems. The recipe goes like this: rephrase your problem so you can feed the input data into a BERT encoder, aggregate the output if required, and train with your own data. In other words, why come up with your own network architecture when BERT is already a good baseline? The method RepBERT [3] is a perfect example of it. The text input is tokenized and fed into BERT f(X): [CLS] X [SEP], the output is averaged to get a fixed-length vector and an off-the-shelf max-margin loss is used to score pairs in the batch. There are many other examples, but this one is particularly easy to describe. Furthermore, the prospect of being state-of-the-art (SOTA) with just your data and BERT as the workhorse is of course tempting.

Bottom line, just like everybody enhanced ConvNets with BatchNorm years ago, there is a noticeable trend to try BERT on lots of existing problems to see if it improves over existing methods, and maybe even becomes SOTA. There is nothing wrong with saving creativity and energy by using a pretrained model to solve problems, and this also helps non-expert users to train and deploy ML solutions more easily. But the question is whether the mantra ‘BERT can do it, so no need to come up with something new’, if it sticks, hinders alternative research and new directions. We frankly don’t know, but we are a little tired of the trend where it sometimes feels like coming up with the funny name for the method probably took longer than the actual write-up.

[1] D19-1443: Unified Understanding for Transformer’s Attention via the Lens of Kernel
[2] 1911.03584: On the Relationship between Self-Attention and Convolutional Layers
[3] 2006.15498: RepBERT – Contextualized Text Embeddings For First-Stage Retrieval

Eureka: But Are We Done Here?

Especially in the world of neural nets, patience is a virtue that is definitely required. Very recently we saw a lecture series about unsupervised / contrastive learning and it reminded us again that this problem is far from being solved, neither for vision nor for natural language processing. What makes it so frustrating is that you cannot directly see the progress during training. In the case of supervised learning, you see the loss going down and the accuracy increasing, and even for reinforcement learning, after being patient, you can see the agent (hopefully) navigating through a maze, driving a car or whatever. But with unsupervised learning, you need another task where the output is used to solve a different problem, like classification with a linear classifier on top of the features learned by some contrastive learning method. If this works, excellent; otherwise, a painful debugging session probably starts to analyze the feature space, and it may well turn into a needle-in-the-haystack search. Again, the problem is the lack of general metrics that allow us to call a moment a true Eureka moment, instead of just knowing that the loss is going down and that it therefore seems to work.

For example, we trained dozens of networks with different loss functions, including [1], and the progress during training gave us confidence that the representation also generalizes to other tasks. But in the end, each and every loss function, despite low values and other heuristics that indicated that it worked, learned representations that always make silly mistakes that seem totally avoidable. Not to mention the ‘one modification changes everything’ problem, like in the good old times when a tiny adjustment of the weight initialization made the difference between learning and no learning at all. We also recently had such a moment where gradient clipping with a very low norm was required to make progress at all, but after a minor change in the architecture the threshold needed to be evaluated again and the new threshold was orders of magnitude larger.

We can only guess what people do to keep frustration from winning. A good team with creative minds is probably helpful, both to discuss things and to think outside the box. It’s like what Schmidhuber said in an AMA about promising students: To continually make progress, no matter how small: “[..]run into a dead end, and backtrack. Another dead end, another backtrack. But they don’t give up.”


The Uniformity Loss As A Regularizer

Especially for unsupervised learning it is often essential that representations are distinct from each other. Even if we consider clustering, it is often not desirable that all samples collapse into a single point. Why? For clustering, the inter-class distance is maximized, but if all samples collapse onto the cluster center, no differentiation at the intra-class level is possible since everything is encoded exactly the same, up to some eps difference. If we just want to separate classes, this is no problem, but if we want to use the representation for retrieval, the feature space is useless. For margin-like losses this is often addressed by no longer pulling positive samples together once their distance is below a threshold, but this does not guarantee that samples are spread across a region of the manifold.

In [1] we experimented with the soft nearest neighbor loss [arxiv:1902.01889], which is a good example of such clustering. The goal is to jointly separate a set of classes, and often, as a side effect, dense cluster regions are learned during training. Again, this is beneficial if we just want a classifier but likely not optimal for downstream tasks. In the original paper the authors demonstrated that maximizing the loss leads to a spread of the features across the learned space, but with the uniformity loss [arxiv:2005.10242] we don’t actually need this, because the uniformity loss does exactly this. To demonstrate the benefits, we conducted a simple experiment: we used a dataset with labels and a model with two losses that are both minimized:

(1) soft nearest neighbor to cluster related content
(2) uniformity loss to spread the features across the whole space

The output of the network is l2 normalized, which means all embeddings live on the hypersphere. This is required because a uniform distribution is not possible in an unbounded feature space. To verify the effectiveness of (2) we disabled it and studied the tSNE plot several times, and it always represented each cluster as a fairly dense region. When (2) is enabled, the content of each cluster is more spread out, with a larger average distance to the center.
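
To make the setup concrete, here is a minimal sketch of the two losses on l2-normalized embeddings. It is not our training code: the function names, the temperature and t=2 are illustrative assumptions, the soft nearest neighbor term follows the formulation of arxiv:1902.01889 as we understand it, and the uniformity term follows the pairwise Gaussian potential of arxiv:2005.10242.

```python
import torch
import torch.nn.functional as F


def uniformity_loss(z, t=2.0):
    """Log of the mean Gaussian potential over all distinct pairs of unit vectors."""
    sq_dist = torch.pdist(z, p=2).pow(2)
    return sq_dist.mul(-t).exp().mean().log()


def soft_nearest_neighbor_loss(z, labels, temperature=1.0):
    """Negative log of the neighbor probability mass that falls on same-label samples."""
    sq_dist = (2.0 - 2.0 * z @ z.t()).clamp_min(0.0)  # squared distance of unit vectors
    eye = torch.eye(len(z), dtype=torch.bool, device=z.device)
    logits = (-sq_dist / temperature).masked_fill(eye, float("-inf"))  # no self-matches
    probs = logits.softmax(dim=1)
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye
    p_same = (probs * same).sum(dim=1)
    valid = same.any(dim=1)  # skip samples without a same-label partner in the batch
    return -p_same[valid].clamp_min(1e-12).log().mean()


torch.manual_seed(0)
labels = torch.tensor([0, 0, 0, 1, 1, 1])
raw = torch.randn(6, 8, requires_grad=True)   # stand-in for the network output
z = F.normalize(raw, dim=1)                   # embeddings on the unit hypersphere

loss_snn = soft_nearest_neighbor_loss(z, labels)   # (1) cluster related content
loss_uni = uniformity_loss(z)                      # (2) spread the features
loss = loss_snn + loss_uni                         # both terms are minimized
loss.backward()
```

A fully collapsed batch gives a uniformity loss of exactly zero, its maximum, so term (2) pushes against the dense clusters that term (1) alone tends to produce.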

The downstream task we considered was a simple retrieval of the nearest neighbors, checking manually whether the pairs match semantically. We also checked the labels of the neighbors as a metric for accuracy. Since the task was fairly simple, there were very few errors, but the semantic relation differed a lot depending on whether (2) was activated or not. Without it, as expected, a lot of samples collapsed and no clear distinction was possible.

Bottom line, the uniformity loss is very general and even if it acts as an opposing force to most losses, it helps to regularize the representation by separating intra- and inter-class distances, which is very important for fine-grained classification or retrieval.


Pondering About Network Architectures, Data and Objectives

In case of unsupervised learning, the juicy part is a loss function that allows you to learn regularities from your data. For instance, to predict the next word, or a masked token, if the order is correct or if something is related or not. Depending on your goal, this loss function might be even more important than your actual neural net, at least if the baseline is powerful enough to capture the most important patterns of the dataset.

To better understand the relations between architecture, data and loss function, we decided to create a dataset that makes it easy to separate content into classes, without ambiguities, and we further reduced the content to specific POS tags, namely (proper) nouns. The goal is to learn contextualized embeddings with no focus on the order of the words. In other words, something like a topic model, but one that only uses local information combined with attention.

Since we know that word2vec is able to learn to ‘cluster’ related context, with the restriction that there is only one context, we started with a cbow-style loss. The idea is to use a more powerful encoder, like [1], and then to use a local window around a center word to relate words. We also used the skip-gram loss, the classical next-word prediction, and the prediction of masked tokens. The training was always done with a uniform + alignment setting [2] to force the embeddings to cover the whole feature space and to keep each pair of embeddings distinct.

It is not obvious, at least not to us, if all the loss functions lead to equally powerful representations. The problem is that, for instance, the classical ‘next word’ objective only considers the right context but does not look back. However, we compared the results and all models made a lot of silly mistakes, regardless of the loss function. We tuned the encoder to provide more computational power and also replaced it with a more powerful encoder, but the results were always the same.

Now the question is: what is the problem? The data clearly allows separating the content and the encoder should also be powerful enough, so what’s missing? Maybe depth? It is well known that good representations require more layers to combine aspects into a more abstract representation. But this particular problem is not really challenging, and if word2vec is able to learn useful relations, our network should not struggle. So, do we need more data? For a supervised model it definitely suffices, but such a model stops learning immediately after the loss is down to ‘zero’ and thus the embeddings are likely not universal. Maybe it is the loss function? Well, that could be, but why does it work with a different setup? Frankly, we have no answer yet.

Let’s do a quick recap: We try to train a not-so-deep neural net to provide contextualized embeddings with a dataset that is fairly manageable. The modules are pretty standard, with no homemade layers, and the loss function is also off the shelf. The results are far from being useless, but there are too many silly mistakes that prevent us from using the encoder in a real application. What follows is something that each researcher probably knows, the long way of ‘debugging’, and probably we all would love to have a complete checklist or a recipe that we can follow.

The list of questions includes:
– do we need more layers / units?
– do we need (additional) regularizers?
– is there enough data?
– is the objective challenging enough?
– is it possible at all to solve the problem (perfectly)?

Bottom line, if different loss functions and network architectures always lead to very similar results, you start wondering if there is a fundamental problem that needs different thinking to be solved. And we don’t want to get started on bugs in the software. Again, we wonder how much time and effort other teams spend before their approach works reliably. In contrast to other engineering disciplines, neural nets still feel to us like a combination of art and science.


Finding Metrics For Unsupervised Learning

Compared to supervised learning, there are far fewer metrics to apply during training for the evaluation of the representation. In the case of classification, the accuracy score on the test set is usually a good indicator, maybe in combination with a confusion matrix to identify problematic pairs of labels that often get confused. There are a few problems: First, unsupervised learning often tries to learn the regularities of the input data and there is no automatic metric to capture this. Second, without any labels, visualization of the results can only be done in a strictly ‘geometric’ way, and further analysis is then required to take a look into the identified groups. This is nothing new, but in the case of deep neural nets, where the training might take hours or even days, periodic inspection of snapshots is required to avoid wasting resources, including energy. But how to do that?

We are still tinkering with the alignment + uniform loss from [1] and of course we plot the different loss values and also other heuristics, but what do they mean? The first problem is that loss values are computed for mini-batches and an evaluation of the whole dataset might be costly or even impossible. For instance, even with a small dataset, there are N=15,952 contextualized word embeddings, which means K=N * (N-1) / 2 pairs. That is more than 120 million, which is costly even without the backprop step. But in the end, the goal is to estimate the spread on the hypersphere and the actual loss values are not always useful for that.

Let us take the following loss value: uniform=-2.614. Reversing the formula loss_uniform=exp(-2*dist).mean().log() tells us that the average distance is about 1.3, which does not seem bad since an optimal value would be 2 (max cosine distance). However, the average can be misleading, so let’s check a sample. For that, we randomly sampled an item and determined its neighborhood. Here we consider the squared distance(!). The mean distance is about 2.0 and the maximal distance is 3, which is not bad, but if we consider how many ‘collapsed’ items are nearby (with a distance below a threshold ~0.1), the count is 340, and the same pattern can be seen if we repeat the experiment. Bottom line, despite the low loss values and the average spread in the learned feature space, we only get an average assessment, which is not really useful.
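
Both checks are easy to script. The sketch below is illustrative only: the helper names are ours, t=2 matches the formula above, and the embeddings are synthetic (a tight cluster of 50 points plus 200 spread points), not our data. Reversing the formula gives the distance that every pair would need to share to produce the loss. By Jensen's inequality this is a lower bound on the true mean squared distance, which is one reason why the sampled neighborhood can look different from the 'average'.

```python
import torch
import torch.nn.functional as F


def effective_sq_distance(loss_uniform, t=2.0):
    """Distance d with log(exp(-t * d)) == loss, i.e. if all pairs had the same distance."""
    return -loss_uniform / t


def collapsed_neighbors(z, index, threshold=0.1):
    """Count the items whose squared distance to z[index] is below the threshold."""
    sq_dist = (z - z[index]).pow(2).sum(dim=1)
    sq_dist[index] = float("inf")  # do not count the query item itself
    return int((sq_dist < threshold).sum())


print(effective_sq_distance(-2.614))  # 1.307

# Synthetic stand-in for learned embeddings (assumption, not our dataset).
torch.manual_seed(0)
center = F.normalize(torch.randn(1, 16), dim=1)
cluster = F.normalize(center + 0.01 * torch.randn(50, 16), dim=1)  # collapsed region
spread = F.normalize(torch.randn(200, 16), dim=1)                  # well-spread items
z = torch.cat([cluster, spread])

n_close = collapsed_neighbors(z, index=0)
print(n_close)
```

Repeating collapsed_neighbors for a handful of random indices is much cheaper than evaluating all pairs and exposes local collapse that the averaged loss hides.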

The evaluation of a test set might help to see a gap between the two distributions, but the early stopping strategy is not reliable for unsupervised learning: we cannot be sure whether the network has converged yet, and if we stop when the gap between train and test starts to grow, the features learned so far might not be useful. In the end, we might need to come up with customized metrics for the problem at hand, or we cheat by using labels that are not used by the training, just for the evaluation. None of this gives us a satisfying indicator to know ‘if we are done here’, or if we need to train the net (a little) longer. So, all we can do is to conduct a manual inspection at fixed intervals to actually understand what the net has learned so far.


How To Learn Word Labels

In the last years, unsupervised learning got a lot of attention, especially for NLP. However, the problem is far from being solved. Models like the Transformer provide excellent performance, but the price is the required data and the training time. Furthermore, despite the success, we are skeptical that any existing method will ever really ‘understand’ the data without an inductive bias. As noted before, with our limited resources, we won’t even try, so we are happy with getting new insights by combining existing methods with crazy ideas.

The current path we follow is a variant of contrastive learning [1] to learn contextualized embeddings. But despite the promising results with our data, we realize that something is missing, some prior knowledge or an inductive bias. To explicitly inject prior knowledge is possible, but then the representation is likely limited because of exactly this bias. But what about a weak bias that still allows us to learn a generic representation?

To be a little less cryptic: We assume that the meaning of a word depends on the context, which is reasonable, and furthermore that a word can be assigned a label that defines its type. The type can be some kind of POS tagging, but of course also any other latent ID. It should be noted that this ID is not given, but learned during the training. Details follow soon. In other words, depending on the data, the result might be a set of (overlapping) topics, word types, syntactic properties like upper and lower case, but almost everything is possible, like the tense of the word if applicable, or a mixture of different concepts.

Let us discuss some details. We want to learn a set of prototypes to assign labels to embeddings. But in contrast to clustering or topic models, we treat the problem as learning a categorical distribution with backprop(!) for each (contextualized) word. There are different ways to achieve this; we use the gumbel softmax:

Let E be the embedding of a sentence E=(w1,..,wn) where wi is the embedding for the i-th word and the length of the sentence is n. A simple way to provide some context is a BiLSTM: H=lstm(E), where H is of shape (n, n_lstm). For the concepts we need two weight matrices: U, V, where U is of shape (n_lstm, n_labels) and V is of shape (n_labels, n_lstm). In PyTorch this goes like that:

H, _ = rnn(E) # (n, n_lstm)
scores = U(H) # (n, n_labels)
scores = torch.nn.functional.gumbel_softmax(scores, hard=True, dim=-1) # (n, n_labels), one-hot
M = V(scores) # (n, n_lstm)
out = torch.cat((H, M), dim=1) # (n, 2*n_lstm)

[At test time, the actual labels can be determined with ‘labels = scores.argmax(dim=-1)’, applied to the raw scores U(H), so no sampling is done]

In this code snippet scores holds a one-hot encoding of the decision and thus the matrix multiplication with V just selects the right prototype, which is then concatenated with the hidden state. At the end, each hidden state consists of the LSTM output and the prototype of the chosen ID. We can reduce the dim to n_lstm if we apply a projection with a matrix W of shape (2*n_lstm, n_lstm).
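
Put together as a module, the idea could look like the following sketch. The class name PrototypeLabeler, the batched input layout and the hyperparameters are assumptions made for illustration; only the use of torch.nn.functional.gumbel_softmax with hard=True and the concatenation plus projection follow the description above. n_lstm is taken as the size of the concatenated forward and backward state, so it has to be even here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrototypeLabeler(nn.Module):
    """BiLSTM encoder that attaches a learned, discrete label prototype to every word."""

    def __init__(self, dim_emb, n_lstm, n_labels):
        super().__init__()
        # Assumption: n_lstm is the concatenated (forward + backward) state size.
        self.rnn = nn.LSTM(dim_emb, n_lstm // 2, bidirectional=True, batch_first=True)
        self.U = nn.Linear(n_lstm, n_labels, bias=False)  # hidden state -> label scores
        self.V = nn.Linear(n_labels, n_lstm, bias=False)  # one-hot label -> prototype
        self.W = nn.Linear(2 * n_lstm, n_lstm)            # projection back to n_lstm

    def forward(self, E, tau=1.0):
        H, _ = self.rnn(E)                                 # (batch, n, n_lstm)
        logits = self.U(H)                                 # (batch, n, n_labels)
        if self.training:
            # Sample a one-hot label; gradients flow through the soft relaxation.
            onehot = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)
        else:
            # No sampling at test time: take the most likely label.
            onehot = F.one_hot(logits.argmax(dim=-1), logits.shape[-1]).to(H.dtype)
        M = self.V(onehot)                                 # selected prototype per word
        out = self.W(torch.cat((H, M), dim=-1))            # (batch, n, n_lstm)
        return out, onehot.argmax(dim=-1)


torch.manual_seed(0)
model = PrototypeLabeler(dim_emb=32, n_lstm=64, n_labels=10)
E = torch.randn(2, 7, 32)          # two sentences with seven word embeddings each
out, labels = model(E)
print(out.shape, labels.shape)
```

Returning the integer labels next to the embeddings makes it easy to inspect which words share an ID in which context.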

The idea of this rather minor modification is that the context becomes more explicit if we can assign an integer label, which allows an easier interpretation. If this works out, ambiguous words should have different labels depending on the context, while clear-cut words likely always have the same label. But again, it is also possible that the ID is related to topics or other properties, and with mixing we might lose the interpretability again. Nevertheless, the overhead is negligible, so we have nothing to lose but a lot to win.


The Best of Three Worlds: Uniformity + Alignment + Contrastive Learning

With the modest computational power at hand, we won’t even try for a big shot model that solves almost every NLP problem. Instead we are happy with a model that allows us to interact with our data in a very simple query-like interface. In other words, it would suffice to derive a simple sentence encoding for retrieval tasks and a contextualized word embedding to find and group related content. The transformer architecture is very promising for both tasks, but it is very hungry when it comes to data and CPU-time. Because of that we still follow a different path which requires much less computational power and also data.

The idea can be roughly summarized as follows: Use a standard component, like a GRU or [1], to encode a sequence of words to get some context and then use a contrastive learning method to train an embedding. Of course the details are important, because if such a system were always a good baseline, it would be more widely in use. The idea is neither new nor really innovative and has been tried before, like in [2]. But the problem is that the vocabulary is always fixed, and even with byte pair encoding (BPE) this limitation is a showstopper for us. This is why we use a dynamic word embedding based on 1D convolutions, which allows us to embed arbitrary words.
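
To illustrate what such a dynamic embedding can look like, here is a minimal character-level sketch. It is not our actual layer: the class name, the use of UTF-8 bytes as the character inventory, the single convolution and the max-pooling are all assumptions chosen to keep the example short.

```python
import torch
import torch.nn as nn


class CharCNNWordEmbedding(nn.Module):
    """Embeds arbitrary words from their bytes, so there is no fixed word vocabulary."""

    def __init__(self, dim_char=16, dim_word=64, kernel_size=3, max_len=24):
        super().__init__()
        self.max_len = max_len
        # 256 byte values shifted by one, index 0 is reserved for padding.
        self.char_emb = nn.Embedding(257, dim_char, padding_idx=0)
        self.conv = nn.Conv1d(dim_char, dim_word, kernel_size, padding=kernel_size // 2)

    def encode(self, words):
        """Turn a list of strings into a padded (n_words, max_len) tensor of byte ids."""
        ids = torch.zeros(len(words), self.max_len, dtype=torch.long)
        for i, word in enumerate(words):
            data = list(word.encode("utf-8"))[: self.max_len]
            ids[i, : len(data)] = torch.tensor(data, dtype=torch.long) + 1
        return ids

    def forward(self, words):
        ids = self.encode(words)
        x = self.char_emb(ids).transpose(1, 2)            # (n_words, dim_char, max_len)
        h = torch.relu(self.conv(x))                      # (n_words, dim_word, max_len)
        h = h.masked_fill((ids == 0).unsqueeze(1), 0.0)   # ignore padded positions
        return h.max(dim=2).values                        # (n_words, dim_word)


torch.manual_seed(0)
embedder = CharCNNWordEmbedding()
vectors = embedder(["attention", "xyzzyplugh", "Schmidhuber"])
print(vectors.shape)
```

Since the embedding is computed from the spelling, a word that was never seen during training still gets a vector, which is exactly the property a fixed vocabulary cannot provide.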

We played a lot with the combo alignment + uniform [3], but never got satisfying results without providing at least a bit of supervision to the model. Recently we stumbled upon graph neural nets again, where an LSTM [4] is used as an aggregator to embed the neighborhood of a node. The interesting part is that the input is shuffled since the neighboring nodes have no order. In terms of a sentence, the idea is to select a window, left + right, around each word and also to shuffle the words before they are fed to the GRU as some kind of regularization. We treat the result as an alternative view of a sentence. The shape of the result is the same as the vanilla GRU embedding, which can be seen as follows:

For each word w_i, we take the word embedding w_i and the words to the left and right, w_i-1, w_i-2 and w_i+1, w_i+2, and store them in x = (w_i, w_i-1, w_i-2, w_i+1, w_i+2). Then x is shuffled and fed into the GRU, which produces h with a shape of (1, dim_of_gru) if we take the last state. The procedure is repeated for every word and thus the final shape of H is (n_words, dim_of_gru), which equals the vanilla GRU. In the paper there is also a combine step, but this does not change the final shape of the output. We call this view y, the augmented version of x. With the combo align + uniform, the training can be done straightforwardly without the necessity to draw negative samples.
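
The shape argument can be checked with a short sketch. The helper shuffled_windows is a hypothetical name, the zero padding at the sentence borders is an assumption (the text does not say how borders are handled), and the combine step of the paper is left out.

```python
import torch
import torch.nn as nn


def shuffled_windows(E, radius=2):
    """For every word, gather the words within `radius` and shuffle their order."""
    n, dim = E.shape
    size = 2 * radius + 1
    # Assumption: positions outside the sentence are filled with zero vectors.
    padded = torch.cat([E.new_zeros(radius, dim), E, E.new_zeros(radius, dim)])
    windows = padded.unfold(0, size, 1).permute(0, 2, 1)     # (n, size, dim)
    perm = torch.rand(n, size).argsort(dim=1)                # one permutation per word
    return torch.gather(windows, 1, perm.unsqueeze(-1).expand(-1, -1, dim))


torch.manual_seed(0)
n_words, dim, dim_of_gru = 6, 8, 12
E = torch.randn(n_words, dim)                 # word embeddings of one sentence
gru = nn.GRU(input_size=dim, hidden_size=dim_of_gru, batch_first=True)

windows = shuffled_windows(E)                 # (n_words, 5, dim), each window is one sequence
_, h_last = gru(windows)                      # last state: (1, n_words, dim_of_gru)
Hy = h_last.squeeze(0)                        # (n_words, dim_of_gru), same as a vanilla GRU
print(Hy.shape)
```

Treating every window as its own short sequence lets the GRU process all words of the sentence in one batched call.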

However, in our case this does not suffice to train an embedding that actually learns a useful context. This is why we enhanced the loss with a contrastive loss in combination with a cbow-style encoding of the output x. To be more concrete, we derive a third view u, which is the average of the left + right words given a center word, exactly like in word2vec with the CBOW training objective: x=(x1, x2, x3) then u=(x1+x3)/2.

Let’s do a quick recap: Let W be a single sentence of n words: W=(w1, .., wn). Then

x = ffn(vanilla_gru(W, bidirectional=True))
y = ffn(window_shuffle_gru(W))
u = ffn(window_cbow(x))

where the shape of (x,y,u) equals (n, dim). The ffn-encoder projects all representations onto the unit hypersphere. The final loss is then:

loss_align = ((x - y)**2).sum(1).mean()
loss_nce = contrastive_loss(x, u)
loss_uniform = 0.5 * exp(-dist_x*2).mean().log() + 0.5 * exp(-dist_y*2).mean().log()
loss = loss_align + loss_nce + loss_uniform

In words: align is responsible for putting the views close to each other, which means the representation needs to learn regularities that are independent of the view. The contrastive part makes the co-occurrence of a window with its own center word more plausible than with an arbitrary center word, and finally, the uniform part ensures that the representation is spread along the whole hypersphere.
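
As a runnable sketch of the combined objective: the three views below are random stand-ins instead of real GRU outputs, contrastive_loss is written as a plain InfoNCE over the sentence (our assumption about one reasonable choice), and the temperature and t=2 are illustrative values.

```python
import torch
import torch.nn.functional as F


def align_loss(x, y):
    """Mean squared distance between the two views of every word."""
    return (x - y).pow(2).sum(dim=1).mean()


def uniform_loss(z, t=2.0):
    """Log of the mean Gaussian potential over all distinct pairs."""
    return torch.pdist(z, p=2).pow(2).mul(-t).exp().mean().log()


def cbow_view(x):
    """u_i = average of the left and right neighbor of word i (zeros beyond the borders)."""
    padded = F.pad(x, (0, 0, 1, 1))            # one zero row before and after the sentence
    return 0.5 * (padded[:-2] + padded[2:])


def contrastive_loss(x, u, temperature=0.1):
    """InfoNCE: window u_i should match its own center x_i better than any other center."""
    logits = u @ x.t() / temperature
    targets = torch.arange(len(x))
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
n, dim = 6, 16
raw = torch.randn(n, dim, requires_grad=True)
x = F.normalize(raw, dim=1)                                       # stand-in for the GRU view
y = F.normalize(x.detach() + 0.1 * torch.randn(n, dim), dim=1)    # stand-in for the shuffled view
u = F.normalize(cbow_view(x), dim=1)                              # cbow-style third view

loss_align = align_loss(x, y)
loss_nce = contrastive_loss(x, u)
loss_uniform = 0.5 * uniform_loss(x) + 0.5 * uniform_loss(y)
loss = loss_align + loss_nce + loss_uniform
loss.backward()
```

The three terms live on different scales, so in practice each of them probably needs its own weight; the plain sum is only the simplest starting point.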

Bottom line: What we have seen dozens of times is that parts of the representation space collapse, which means the word embeddings are no longer distinct. This is undesirable even for related words, because words should always be clearly distinguishable. Furthermore, even if a training converged, we have seen a lot of impostors as nearest neighbors for a lot of words. With this combined loss, the issue could not be fully solved, but with a smaller data set we could visually confirm, with tSNE, that the representations were spread without collapsing regions and that related contexts were also better grouped together.

[2] “An efficient framework for learning sentence representations”
[4] “Zero-Shot Learning with Common Sense Knowledge Graphs”

For Attention Being Mediocre is Okay

To learn contextualized embeddings, using a BiLSTM is a safe choice. It won’t work perfectly, but with the left-right, right-left combination it provides a good baseline. However, the sequential nature of RNNs is a problem since the training cannot be easily parallelized. A transformer might come to the rescue, but the price is the computational overhead. So, is there an alternative? With 1D convolutions it is possible to encode a local context, but in contrast to RNNs there is no state that summarizes the whole sequence.

Quick recap: it would be great if we could derive a context with a simple matrix multiplication with a complexity lower than O(n**2), where n is the length of the sequence. But if the problem were easy, it would have been solved already. On the other hand, even tiny steps towards the goal are welcome, so we did some quick research with a focus on lightweight alternatives.

We found [1], which is not a drop-in replacement since the method is used for the decoder part of a transformer network, but this is not set in stone. The idea is elegant and simple, which is why we favor it for our small-resource research environment. What follows is a short summary: We focus on the core part, since highway layers are widely used in practice and need no further explanation. In section 3 of [1] the average attention network (AAN) is introduced. The method uses the feed-forward network from the transformer, abbreviated as ffn. Let y = (y1, .., ym) be the input to an arbitrary hidden layer. Then a cumulative average function is used to contextualize each embedding yj: g_j = ffn(1/j * sum(k=1, end=j, yk))

The idea is to provide more and more context for an input embedding yi. Let m be 4, then the ‘mask’ (Figure 3 in [1]) yields (y1, 1/2*(y1+y2), 1/3*(y1+y2+y3), 1/4*(y1+y2+y3+y4)). The idea is best understood by taking a look at the figure, but let’s try an informal summary: In contrast to self-attention, information only flows in one direction, which means y1 spreads its info to all embeddings but does not get any info itself. y2 only gets information from y1 and spreads its info to the remaining embeddings y3, y4, but not to y1, and so forth.

The procedure can be efficiently encoded into a mask matrix, which allows batching since the steps can be condensed into a single matrix multiplication: X = M*Y, where M is the mask matrix and Y the input. The shape of M is (m, m) and Y is of shape (m, dim).

We could easily modify the mask matrix directly, but we opt for a simpler solution. We use torch.flip(tensor, (0, 1)) to mirror the matrix along both dimensions, which allows us to emulate a left-right + right-left processing similar to a bi-directional RNN.

Then the full operation in PyTorch-style:

mask_l = create_mask_matrix(input.shape[1]) # Eq.1
mask_r = torch.flip(mask_l, (0, 1)) # mirror rows and columns: lower -> upper triangular

x_l = torch.matmul(mask_l.unsqueeze(0), input) # left->right
x_r = torch.matmul(mask_r.unsqueeze(0), input) # right->left
x = ffn(x_l + x_r)

With mask_r, y1 spreads no info but gets info from all embeddings to its right in the sequence, while for mask_l, y1 spreads its info to all other embeddings but gets no info from them. The flow can be optimized, but our goal right now is clarity, not performance.
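
For completeness, a small self-contained sketch of the mask construction and the bidirectional averaging. create_mask_matrix is our reading of Eq. 1 as a row-normalized lower triangular matrix, the ffn is left out, and the right-to-left mask is obtained by flipping both dimensions as described in the text. The function cumulative_average is an addition of ours: the dense mask product still touches all m x m entries, while a running sum gives the same left-to-right result in a single pass.

```python
import torch


def create_mask_matrix(m):
    """Row j holds 1/(j+1) in its first j+1 columns: a cumulative average over y_1..y_j."""
    mask = torch.tril(torch.ones(m, m))
    return mask / mask.sum(dim=1, keepdim=True)


def bidirectional_average(inputs):
    """inputs: (batch, m, dim) -> left->right plus right->left cumulative averages."""
    mask_l = create_mask_matrix(inputs.shape[1])
    mask_r = torch.flip(mask_l, dims=(0, 1))      # upper triangular: average over y_j..y_m
    x_l = torch.matmul(mask_l.unsqueeze(0), inputs)
    x_r = torch.matmul(mask_r.unsqueeze(0), inputs)
    return x_l + x_r


def cumulative_average(inputs):
    """Same result as mask_l @ inputs, computed with a running sum."""
    counts = torch.arange(1, inputs.shape[1] + 1, dtype=inputs.dtype).view(1, -1, 1)
    return inputs.cumsum(dim=1) / counts


y = torch.tensor([1.0, 2.0, 3.0, 4.0]).view(1, 4, 1)   # one sequence, m=4, dim=1
print(create_mask_matrix(4))
print(bidirectional_average(y).flatten())
```

For the toy input the left-to-right averages are (1, 1.5, 2, 2.5) and the right-to-left averages are (2.5, 3, 3.5, 4), so every position now sees the whole sequence once the two are added.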

A quick note: we use the idea in an encoder-only network to replace an RNN or self-attention block with the AAN module. With the mask matrix, it just consists of two matrix multiplications instead of passing the whole sequence to an RNN, or performing an O(n**2) operation with self-attention.

Bottom line, using the average for the attention weights might not be optimal, but it is more expressive than a simple bag-of-words approach on the raw embeddings and much faster than a full-fledged transformer. And as demonstrated by experiments, a large transformer might not be required to solve certain tasks. Thus, it is a meet-in-the-middle solution that combines expressiveness with very little computational overhead.

Update 2020-06-24: The flip operation only worked for vectors, not matrices which is fixed now.

[1] P18-1166: “Accelerating Neural Transformer via an Average Attention Network”

Contrastive Learning: Proud To Be Single

Recently we discussed different aspects of contrastive learning (CL), and one big advantage is that you just need a notion of a ‘positive pair’ to learn a feature representation. With the alignment and uniformity loss [1] it is particularly easy to train arbitrary models with your data. However, even if data augmentation is often sufficient to create such a pair (x,y), wouldn’t it be nice if we could just use a single instance for the training? It sounds a bit odd, but there is a method that uses exactly this approach [2].

The intuition of the approach is both reasonable and simple: A supervised model learns class templates in the final softmax layer by trying to summarize all relevant features of instances with the same label. This often works great if, at test time, an instance is from one of those labels, but it might fail if the instance is from a new concept. The model then tries to combine features from several labels, but this also likely fails because the templates capture only the high-level concepts.

To address this issue, we assume that even if a pair (x,y) of samples is from the same category, both x and y are likely to be distinct and thus the feature representations should also differ, at least a bit. In other words, we want to learn to discriminate x not only against y, but against all samples in the data set. Then, the model is forced to learn all the fine-grained details that allow it to distinguish between instances. But since the model capacity is often limited, it is reasonable to assume that the model needs to reuse some features to derive the final representation.

To achieve this, a memory bank is used that is nothing more than a memory that stores for each sample in the data set its representation from the last iteration. In terms of PyTorch, we use x.detach() to indicate that backprop should not be done through the memory bank, but just for the current batch of samples. What follows is a sketch of the method:
N: number of samples in the dataset
I: distinct identifier for each sample for the mem lookup

M = unit_norm(torch.randn(N, dim)) # rand init + l2 norm
I, X = sample_batch()
H = net(X) # encode the batch

# determine loss
loss = 0
for i, h in zip(I, H):
 log_scores = torch.log_softmax(torch.matmul(h.view(1, -1), M.t()), dim=1)
 loss += -log_scores[0, i]
loss /= len(I)

# update memory bank
for i, h in zip(I, H):
 M[i] = unit_norm(M[i] + h.detach())

In words: the memory is initialized with L2 normalized random values and there is a slot for every distinct sample in the dataset: |dataset|=|memory|. Then, we sample each randomized batch from the dataset, along with the row index for each sample to allow a memory lookup. The loss is pretty standard, except that we are now using the new representation h/H and the old representation M. The label is the row index of each sample which means the learning step should maximize the correlation between h and M[i] and ‘decorrelate’ with all other memory slots. That is it and for such a simple method, it works astonishingly well. Depending on the ultimate goal to achieve, a weighted kNN classifier could be used as described by [2]. In case of a new dataset, one could easily create a new memory bank to derive a classifier for a new set of labels.
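
The loop above can also be written in vectorized form. The following is a sketch under stated assumptions, not the reference implementation of [2]: the linear layer is a stand-in encoder, the toy data is random, the temperature argument is our addition (1.0 reproduces the loop above), and the batch is l2 normalized before it is compared with the memory.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def instance_loss(H, I, M, temperature=1.0):
    """Cross-entropy where the 'class' of each sample is its own row in the memory bank."""
    logits = H @ M.t() / temperature
    return F.cross_entropy(logits, I)


@torch.no_grad()
def update_memory(M, H, I):
    """Blend the new representation into the stored one and re-normalize."""
    M[I] = F.normalize(M[I] + H, dim=1)


torch.manual_seed(0)
N, dim_in, dim = 32, 10, 8
data = torch.randn(N, dim_in)                   # toy dataset (assumption)
net = nn.Linear(dim_in, dim)                    # stand-in for the real encoder
opt = torch.optim.SGD(net.parameters(), lr=0.1)
M = F.normalize(torch.randn(N, dim), dim=1)     # rand init + l2 norm

I = torch.randperm(N)[:8]                       # distinct row indices of the batch
H = F.normalize(net(data[I]), dim=1)            # current representation on the unit sphere
loss = instance_loss(H, I, M)

opt.zero_grad()
loss.backward()
opt.step()
update_memory(M, H.detach(), I)                 # no backprop through the memory bank
print(float(loss))
```

For large datasets the full product with M becomes the bottleneck, which is where negative sampling over a subset of memory rows comes in.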

Let us do a quick recap: there are some old friends, like the unit sphere to learn the representation, along with the loss function, negative sampling for larger datasets and not to forget the memory bank. As an addition, we further regularized the model with the uniformity loss that forces ‘equal distribution’ of features onto the unit sphere.

As usual, the challenge for us is to adapt the method for NLP tasks and thus textual data. We used a simple 3-class dataset and trained a representation without labels to study whether the learned features allow us to recover the class labels. Since we used a very simple BiLSTM-based model, the final results were mixed. Whenever a label describes one or more concepts that are very fine-grained (distinct), a (linear) separation was visible in t-SNE plots, but for more complex labels there was a visible overlap, and an inspection of the nearest neighbors confirmed that the model was not able to learn all required concepts to linearly separate the data. Thus, the plan is to investigate the results with more sophisticated network architectures.

[2] “Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination”