In-Context Learning: Learning To Add
The ability of LLMs to solve a broad range of tasks via in-context learning (ICL), without any weight updates, is quite impressive. But why and when ICL emerges is not well understood. In [1], experiments are conducted to shed some light on these questions. In contrast to ICL, in-weights learning (IWL) means that a neural net relies only on its weights to make a correct prediction, which is the ‘default’ training mode.
The goal of our little experiment is to compare the scores of both methods on a toy task: in our case, the addition of two non-negative numbers. The setup is very similar to [2], but we train just a single function. The input x is a concatenation of two one-hot vectors representing x1 and x2, and the output y is a one-hot vector of the result x1+x2. Input and output dimension are the same: 2*range(x). Causal attention is used to avoid cheating.
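To make the encoding concrete, here is a minimal NumPy sketch; the helper names are our own, not taken from the original setup. Note that the largest sum, 2*(n-1), still fits into the 2*n output dimensions:

```python
import numpy as np

def encode_pair(x1, x2, n):
    """Concatenated one-hot input for x1 and x2, each in [0, n)."""
    x = np.zeros(2 * n, dtype=np.float32)
    x[x1] = 1.0
    x[n + x2] = 1.0
    return x

def encode_sum(x1, x2, n):
    """One-hot target for x1 + x2; the largest sum 2*(n-1) still fits."""
    y = np.zeros(2 * n, dtype=np.float32)
    y[x1 + x2] = 1.0
    return y
```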
A training sequence consists of (x, y) example pairs and a final (x,) for which the correct y has to be predicted. The number of pairs can vary; if no pairs are given, ICL turns into IWL, since no example is available and only the weights are used for the prediction.
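Such a sequence could be assembled from the hypothetical helpers above; with an empty example list, the sequence degenerates to the lone query, i.e. the pure-IWL case:

```python
def make_sequence(pairs, query, n):
    """Stack k (x, y) example tokens followed by the final query x."""
    tokens = []
    for x1, x2 in pairs:
        tokens.append(encode_pair(x1, x2, n))
        tokens.append(encode_sum(x1, x2, n))
    x1, x2 = query
    tokens.append(encode_pair(x1, x2, n))  # the (x,) the model must answer
    return np.stack(tokens)                # shape: (2 * len(pairs) + 1, 2 * n)
```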
Note: for such a toy problem, a model with enough capacity can memorize all input-to-output mappings, which means a 100% training score but likely 0% on the test set. We use weight decay to limit the model capacity.
The baseline model is a very small Transformer with just one layer and 32 dims. We provide three examples and one ‘query’ per training step and train the model with RMSprop. The dataset consists of all possible pairs of two numbers to add; 20% is held out for testing the model on unknown data. The test score is checked periodically during training to catch memorization / overfitting issues.
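A hedged PyTorch sketch of what such a baseline might look like; the value range n, learning rate, and weight-decay strength are illustrative guesses, not the exact values used:

```python
import torch
import torch.nn as nn

n = 16                       # assumed value range for x1, x2; not given in the post
d_in = 2 * n                 # concatenated one-hot input / one-hot output

class TinyICL(nn.Module):
    """One-layer Transformer encoder, 32 dims, with a causal mask."""
    def __init__(self, d_model=32):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.layer = nn.TransformerEncoderLayer(
            d_model, nhead=4, dim_feedforward=64, batch_first=True)
        self.proj_out = nn.Linear(d_model, d_in)

    def forward(self, seq):  # seq: (batch, time, d_in)
        mask = nn.Transformer.generate_square_subsequent_mask(seq.size(1))
        h = self.layer(self.proj_in(seq), src_mask=mask)
        return self.proj_out(h[:, -1])  # logits for the final query token

model = TinyICL()
# weight decay limits capacity, as noted above; the values are guesses
opt = torch.optim.RMSprop(model.parameters(), lr=1e-3, weight_decay=1e-4)
```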
As expected, even with regularization the model quickly learns to solve the task. With some tuning, the test score on the unknown data eventually reaches ~85%, which confirms that the model learnt something beyond memorization. For each test input, the training set is shuffled and three pairs are sampled as examples, so the test score can vary a bit due to sampling. For the IWL test score, no examples are provided and the prediction relies solely on the input x.
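The evaluation just described might look roughly like this, reusing the hypothetical helpers from the earlier sketches; with k=0 it degenerates to the IWL measurement:

```python
import random
import torch

def icl_test_score(model, train_pairs, test_pairs, n, k=3):
    """For each held-out query, sample k example pairs from the shuffled
    training set and check the model's prediction; k=0 is the IWL case."""
    correct = 0
    for query in test_pairs:
        examples = random.sample(train_pairs, k)
        seq = torch.from_numpy(make_sequence(examples, query, n)).unsqueeze(0)
        pred = model(seq).argmax(-1).item()
        correct += int(pred == sum(query))
    return correct / len(test_pairs)
```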
First, we analyzed the errors of the ICL setup. We repeated the training with different seeds, but two examples were always predicted wrong if they ended up in the test set:

x = (min, min) -> y = min
x = (max, max) -> y = 2*max

For example, 0+0=0 and 4+4=8.
What’s so special about those numbers? Well, it is actually not about the numbers themselves, but about how often their sums occur in the data.
Too vague? Okay, let’s do some counting. If we enumerate all ways to add two numbers from 0 to 2,

0 + 0 = 0
0 + 1 = 1
0 + 2 = 2
1 + 0 = 1
1 + 1 = 2
1 + 2 = 3
2 + 0 = 2
2 + 1 = 3
2 + 2 = 4
it is obvious that the smallest and the largest sum (0 and 4) are each produced by exactly one pair, while the sums in between can be produced by several pairs. In terms of training this means the network sees those min/max combinations only once per pass over the training set. So it’s not a training issue but a data distribution issue.
With a simple script (see the sketch below), the frequencies can be plotted for visual inspection, and the result is not really surprising: the histogram of sums forms a ‘triangle’ with its maximum in the middle of the output range. It confirms the intuition that some sums are very rare and thus harder to predict correctly, since the network didn’t “see” them very often. A possible remedy is to scale the loss according to the frequency, which is a common pattern for imbalanced data in other problems too.
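The counting script might look like this; the inverse-frequency loss weights at the end only illustrate the rebalancing idea, the exact weighting scheme is our own:

```python
from collections import Counter
import matplotlib.pyplot as plt

n = 16                                    # same assumed range as above
counts = Counter(x1 + x2 for x1 in range(n) for x2 in range(n))

# Sums 0 and 2*(n-1) are produced by one pair each; sum n-1 by n pairs.
plt.bar(list(counts.keys()), list(counts.values()))
plt.xlabel("sum")
plt.ylabel("number of pairs producing it")
plt.show()

# Rebalancing idea: weight the loss inversely to each sum's frequency.
weights = {s: max(counts.values()) / c for s, c in counts.items()}
```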
What about IWL? We set the number of examples to zero and repeated the test. The result was a bit disappointing: the IWL score was sometimes slightly higher and never below the ICL score. In other words, for such a toy problem the extra examples do not seem to improve the score and sometimes even hurt it.
Drawing conclusions is a bit hard, since the task is likely too simple: IWL already works well enough that ICL is not needed at all and thus does not ’emerge’. The long-tail issue of the input/output pairs is unlikely to prevent ICL from emerging, but it of course affects the learning dynamics of the model.
References
[1] arXiv:2311.08360, “The Transient Nature of Emergent In-Context Learning in Transformers”
[2] arXiv:2310.03016, “Understanding In-Context Learning in Transformers and LLMs by Learning to Learn Discrete Functions”