With PyTorch it’s easy to implement arbitrary loss functions thanks to the dynamic computational graph. You can use loops and other Python control flow, which is extremely useful once you start implementing more complex loss functions. We follow the principle “readable first, performance later”, which means we don’t care whether the code is efficient, as long as it is easy to read and, of course, easy to debug. Only when we are sure, after extensive testing, that we implemented the loss function correctly do we think about optimization. The recipe from Andrej Karpathy gives a good hint: overfit on a small data set first, before training on the whole dataset. This is very important, because if we cannot get the loss to go down on a tiny data set, something is wrong and we need to debug our code.
We recently stumbled on a paper co-authored by Hinton [arxiv:1902.01889] which notes that a batched version of the loss is provided for TensorFlow (in the CleverHans package). However, since we are using PyTorch, we started from scratch to learn something about the challenges.
Let’s start with some notation. We have a batch X with the corresponding labels Y, where X has shape (n, dim) and Y has shape (n,). Y contains integer labels, while X contains the features from some hidden layer.
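To make the shapes concrete, here is a small sketch of such a batch; the sizes n, dim, and the number of classes are arbitrary choices for illustration:

```python
import torch

n, dim = 8, 4        # hypothetical batch size and feature dimension
num_classes = 3      # hypothetical number of labels

batch_x = torch.randn(n, dim)                  # features from some hidden layer, shape (n, dim)
batch_y = torch.randint(0, num_classes, (n,))  # integer labels, shape (n,)
```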
For a sample (x, y), the distance between x and some other sample x’ is used to define the probability of sampling x’ as a neighbor of x, see chapter 2 in [arxiv:1902.01889]. To turn distances into ‘similarity’ values, a Gaussian kernel is used: exp(-dist(x, x’)**2/T), where dist is the Euclidean distance between x and x’ and T is a tunable temperature: the larger T is chosen, the more weight larger distances still receive.
For a sample x_i of label y_i, the loss is calculated as:
same_i = sum(j=1 to n, j != i and y_i == y_j) exp(-dist(x_i, x_j)**2/T)
rest_i = sum(k=1 to n, k != i) exp(-dist(x_i, x_k)**2/T)
prob_i = same_i/rest_i
loss_i = -log(prob_i)
The (j != i) and (k != i) part is also known as “leave one out”: we skip the sample itself to avoid a self-reference, where exp(-dist(x_i, x_i)**2/T) would be 1 because the distance is zero.
The idea is pretty simple: if two samples share the same label, the distance between them should be smaller than the distance to any sample with a different label. This is the (y_i == y_j) part in the same_i formula, which selects all samples with the same label. The negative log step should be familiar from the softmax loss.
Let’s consider two cases before we start coding:
(1) A pair (x_i, x_j) shares the same label and the two samples are pretty close. Thus, exp(-dist(x_i, x_j)**2/T) is close to 1, since dist(x_i, x_j) is close to zero. If the other distances (x_i, x_k) are reasonably large, the terms exp(-large**2/T) are close to zero. So same_i looks like [0.99, 0.99], because all entries with the same label are pretty close. And if the other entries are well separated, rest_i looks like [0.99, 0.01, 0.99, 0.01, 0.01], where the 0.99 values are entries with the same label and all other entries have different labels. If we sum these up and divide, we get a probability of ~0.98 and thus a loss of ~0.015, as expected.
(2) Now, if the pair (x_i, x_j) is further apart and a sample x_k with a different label is closer to x_i, same_i looks like [0.01, 0.99], since one entry has a larger distance. And since an impostor now has a smaller distance, rest_i looks like [0.99, 0.01, 0.01, 0.99, 0.01]. The values in rest_i are the same, because only two entries switched places, but the sum of same_i is different now. This means the probability is only ~0.5 and the loss is ~0.70.
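Plugging the two similarity lists into the formulas confirms the numbers above:

```python
import math

# case (1): the same-label pair is close, impostors are far away
same_1 = [0.99, 0.99]
rest_1 = [0.99, 0.01, 0.99, 0.01, 0.01]
prob_1 = sum(same_1) / sum(rest_1)
loss_1 = -math.log(prob_1)
print(prob_1, loss_1)   # ~0.985 and ~0.015

# case (2): one same-label partner drifted away, one impostor moved close
same_2 = [0.01, 0.99]
rest_2 = [0.99, 0.01, 0.01, 0.99, 0.01]
prob_2 = sum(same_2) / sum(rest_2)
loss_2 = -math.log(prob_2)
print(prob_2, loss_2)   # ~0.498 and ~0.70
```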
Bottom line: the loss is supposed to be low if all samples with the same label are densely located in the feature space, or at least if each pair with the same label is closer together than to any sample with a different label. This is a soft version of the classical k-nearest-neighbor method.
What follows is a naive implementation of the loss:
idx = torch.arange(batch_x.shape[0])
loss = 0.0
for i in range(batch_x.shape[0]):
    label = batch_y[i]
    # same label as entry i, but not entry i itself ("leave one out")
    batch_same = batch_x[(batch_y == label) & (idx != i)]
    # everything except entry i
    batch_rest = batch_x[idx != i]
    # squared Euclidean distances from x_i
    dists_1 = ((batch_x[i] - batch_same)**2).sum(dim=1)
    dists_2 = ((batch_x[i] - batch_rest)**2).sum(dim=1)
    prob = torch.exp(-dists_1/temp).sum() / torch.exp(-dists_2/temp).sum()
    assert prob.item() <= 1
    loss = loss - torch.log(prob)
loss = loss / batch_x.shape[0]
That’s not scary at all, thanks to some nice features in PyTorch. First, with masks it is straightforward to filter entries in the batch, and since we can also use the logical AND operation, it is possible to combine selection filters. For batch_same, we select all entries with the same label, except entry i, while for batch_rest everything except entry i is selected. The calculation of the squared distances is also easy. Then we turn the distances into ‘similarities’, sum them up as in the example, and divide the sums to get the final probability of sampling an entry with the same label. At the end, we average the loss and we are done.
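To illustrate the mask combination in isolation, here is a tiny hand-made example (the labels are arbitrary):

```python
import torch

batch_y = torch.tensor([0, 1, 0, 1, 0])
idx = torch.arange(batch_y.shape[0])   # running index 0..n-1
i = 0
label = batch_y[i]

mask_same = (batch_y == label) & (idx != i)  # same label as entry i, excluding i itself
mask_rest = idx != i                         # everything except entry i

print(mask_same)  # tensor([False, False,  True, False,  True])
print(mask_rest)  # tensor([False,  True,  True,  True,  True])
```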
No doubt the code can be optimized, but what we wanted to show is that with PyTorch the implementation of complex loss functions, ones that use advanced indexing for example, is straightforward. Fully utilizing GPU devices or implementing efficient batching is a different story we will tell later.