
Let’s Make Some Noise

Sometimes it is a good idea to try a new direction when you are stuck. In other words, we needed some new inspiration and thought it would be worth turning to a very different domain, in our case audio. Furthermore, for quite some time we had toyed with the idea of tagging a specific voice in an audio signal by somehow learning a representation of the speaker, so it felt like the way to go.

A possible scenario looks like this: We record a movie via DVB-S and extract the audio stream. Then we convert the raw audio into a more suitable representation and classify all time frames, or time windows, with our learned model as +1/-1. At the end, we have time markers where the trained voice has been detected: [at min 3.1, at min 37.3, ..]. So much for the theory, now let’s turn to reality.

For us it was settled that PyTorch is our framework of choice. Thus, as a first step we needed audio support. We hoped that, in the spirit of torchvision, there is also a torchaudio, and we were not disappointed. The “load” function allows us to load arbitrary audio files in raw format and returns the data as a tensor. However, this format requires a lot of computational resources, since every second is encoded as rate (e.g. 44,100) float values per channel. Thus, the shape of the tensor is (rate * seconds, channels), which is huge for a full-length movie.
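As a minimal sketch of the loading step (the file name is hypothetical, and the exact shape and return convention depend on the torchaudio version):

import torchaudio

# Load raw audio; torchaudio returns the samples as a float tensor plus
# the sample rate. Depending on the torchaudio version the shape is
# (rate * seconds, channels) or (channels, rate * seconds).
waveform, rate = torchaudio.load("movie_audio.wav")  # hypothetical file name
print(waveform.shape, rate)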

So we are interested in a more compact representation. As a first step, we converted stereo signals to mono (“transforms.DownmixMono”), which reduces the shape to (rate * seconds, 1). But since this is still a lot of data, we did some research to get an overview of popular transformations and decided to use MEL spectrograms, also because there is an interface in the torchaudio package (“transforms.MEL”). With default values from papers, and re-sampling to 22,050 Hz, each second of raw audio is now encoded as a (128, 22) matrix. In this setting, the rows are the frequency axis and the columns are the time axis. We further apply a log transformation to the data to avoid exploding gradients, since the magnitude of the spectrogram data can be very high.
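A rough sketch of this transformation, using the MelSpectrogram transform from more recent torchaudio versions; the FFT size and the hop length of 1024 are assumptions that yield roughly 22 time frames per second at 22,050 Hz:

import torch
import torchaudio

# Assumed parameters: 128 mel bands, hop length 1024, 22,050 Hz sample rate.
to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_fft=2048, hop_length=1024, n_mels=128)

waveform, rate = torchaudio.load("speaker.wav")   # hypothetical mono file at 22,050 Hz
spec = to_mel(waveform)                           # shape: (1, 128, time)
log_spec = torch.log(spec + 1e-6)                 # log transform to tame large magnitudes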

Now the question is how to encode this information into a new representation that models the similarity between frames. Several approaches are possible. For instance, we could train an ordinary one-vs-rest classifier that outputs +1 if the frame is spoken by the speaker and -1 otherwise. But we opted for a triplet-based method to better model local neighborhoods. The drawback is that we cannot directly classify unseen frames; instead we need some kind of nearest neighbor lookup to decide if a frame is a positive match. Thus, it makes sense that the positive data from training forms a memory component that, in combination with a threshold, acts like a classifier.

Next, we need to design our network architecture. With the chosen MEL transformation, we could easily train a feed-forward neural net, since the input dimension would be just 128*22=2816, but dense layers are not invariant to shifts in frequency[arxiv:1709.04396] and thus a minor change in the input can lead to a larger change in the feature space. Therefore, we decided to follow the early papers that use convolution over the time axis to learn a representation, that is, a 1d convolution. The architecture is heavily inspired by convnets from vision, with the exception that pooling and convolution operate over one axis (time) instead of two.

Thanks to PyTorch we have everything we need, and a prototype consists of just a few lines of Python. Here is a sketch of the network:

import torch
from torch.nn import Conv1d
from torch.nn import MaxPool1d
from torch.nn import Linear
from torch.autograd import Variable
from torch.nn import functional as F

# One MEL frame: (batch, 128 frequency bins, 22 time steps)
x = Variable(torch.randn(1, 128, 22))
c1 = Conv1d(in_channels=128, out_channels=32, kernel_size=3)
c2 = Conv1d(in_channels=32, out_channels=32, kernel_size=3)
m1 = MaxPool1d(2)
l1 = Linear(32, 16, bias=False)
h = c2(m1(c1(x)))                              # (batch, 32, time')
h = F.adaptive_avg_pool1d(h, 1).squeeze(-1)    # global average pooling -> (batch, 32)
out = l1(h)                                    # final embedding -> (batch, 16)

First, there is a convolution, followed by max-pooling, followed by another convolution and, at the end, a global average pooling that returns the mean of each filter map, followed by an affine transformation that represents the final embedding space. Additional blocks like normalization and non-linear activation functions are omitted for clarity. Such an architecture has a lot of benefits: first, we can stack blocks of conv/norm/relu/pool to form a deep network; second, the network has very few trainable parameters; and last but not least, the forward step is computationally very efficient.
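To illustrate the stacking, here is a minimal sketch of such a deeper variant; the depth, channel sizes and the module name SpeakerNet are our assumptions, not the exact setup used here:

import torch
from torch import nn

class SpeakerNet(nn.Module):
    # Stacked conv/norm/relu/pool blocks, global average pooling, linear embedding.
    def __init__(self, n_mels=128, embed_dim=16):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(n_mels, 32, kernel_size=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 32, kernel_size=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
        )
        self.embed = nn.Linear(32, embed_dim, bias=False)

    def forward(self, x):                      # x: (batch, n_mels, time)
        h = self.features(x)                   # (batch, 32, time')
        h = h.mean(dim=2)                      # global average pooling
        return self.embed(h)                   # (batch, embed_dim)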

The training of the network is also pretty straightforward. The data set consists of spoken audio material by the person to recognize as positive examples and arbitrary audio from other persons as negative examples. Without a doubt the selection of “the rest” impacts the performance of the network, since if all samples are already sufficiently far away from the speaker samples, no learning is done. This issue requires more research, but even our naive selection of negative samples led to a solid performance.

Next, all audio files are pre-processed and split into frames of ~2 seconds on which the transformation is applied. The order of the frames is not preserved, since the “classification” works on single frames. A learning step consists of sampling an anchor, a positive sample and an arbitrary negative sample. Each input to the network represents a single time frame, with the possibility to feed a batch of frames to the network. We L2-normalize all network outputs and use the cosine similarity to determine the triplet loss (with a margin of 0.3):
loss = torch.clamp(0.3 + dot(anchor, negative) - dot(anchor, positive), min=0)
In other words, if the negative sample is sufficiently far away from the anchor (>= margin), no learning is required; otherwise the parameters are adjusted to push the negative sample away from the anchor.
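Spelled out as a runnable sketch (the function name and the batched layout are our assumptions):

import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=0.3):
    # anchor, positive, negative: batches of embeddings, shape (batch, dim).
    # L2-normalize so that the dot product equals the cosine similarity.
    anchor = F.normalize(anchor, dim=1)
    positive = F.normalize(positive, dim=1)
    negative = F.normalize(negative, dim=1)
    sim_pos = (anchor * positive).sum(dim=1)
    sim_neg = (anchor * negative).sum(dim=1)
    # Only triplets that still violate the margin contribute to the loss.
    return torch.clamp(margin + sim_neg - sim_pos, min=0).mean()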

However, it can be challenging to find good negative samples, since at later stages of the training most samples are already well separated and thus have a loss of zero. This means we need to find violators outside the batch to further improve the model. This can be computationally expensive, since we need to calculate the loss on many samples until we find enough of them. However, the procedure is required to ensure that we learn a good model and that the learning converges.
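A simple way to do this mining, sketched under the assumption that the anchor and the candidate negatives are already L2-normalized embeddings (names and defaults are ours):

import torch

def mine_negatives(anchor, candidates, sim_pos, margin=0.3, k=16):
    # anchor: (dim,) embedding, candidates: (n, dim) embeddings of negative frames,
    # sim_pos: cosine similarity between the anchor and its positive sample.
    sims = candidates @ anchor                    # cosine similarities to the anchor
    violating = (margin + sims - sim_pos) > 0     # negatives that still produce a loss
    idx = torch.nonzero(violating).squeeze(1)
    return candidates[idx[:k]]                    # keep at most k hard negatives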

When the model is trained, the positive samples are fed to the network and the representation is stored as some kind of “memory”. As a baseline, new frames are classified by performing a nearest neighbor lookup (cosine similarity) on the memory, and a frame is marked as “positive” if the mean of the top-5 scores from the memory is above a threshold, like 0.7. Astonishingly, this baseline is pretty robust and already allows us to reliably mark relevant time windows of audio material without too many false positives.
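As a sketch of this baseline (assuming the memory and the frame embedding are L2-normalized; the function name and defaults are ours):

import torch

def classify_frame(embedding, memory, threshold=0.7, k=5):
    # memory: (n, dim) embeddings of the positive training frames,
    # embedding: (dim,) embedding of the unseen frame.
    sims = memory @ embedding              # cosine similarities against the memory
    top_k = sims.topk(k).values
    return top_k.mean().item() >= threshold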

Bottom line, regardless of the domain, the machine learning pipeline stays pretty much the same: we have a problem, data, cleansing, optionally a transformation, and we need a good network architecture and a proper loss function to learn a good model. The next steps are more experiments to evaluate the model and to come up with a better way to classify unseen data based only on positive examples.


Think Local, Not Global

In contrast to some academic datasets, real-life datasets are often unbalanced, with a lot of noise and more variance than some models can cope with. There are ways to adapt models to those problems, but sometimes a classification is too restrictive. What do we mean by that? With a lot of variance for a single label, it might be possible to squeeze all samples into some corner of a high-dimensional space, but it is probably easier to learn the relative similarity of samples.

With a triplet loss, we can move two samples with the same label closer together while we push a sample with a different label further away: maximum(0, margin + f(anchor, neg) - f(anchor, pos)). The approach has the advantage that we can distribute samples across the whole feature space by forming “local clusters”, instead of concentrating them in one dense region of the space. The quality of the model depends on the sampling of the triplets, since for a lot of triplets the margin is already preserved and no learning is done. But since the sampling can be done asynchronously, it is usually no problem to find enough violating triplets to continue the training.

If we take, for instance, items from an electronic program guide (EPG), there might be a huge variance for the category “sports”, but only few words to describe the content. Thus, it might be sufficient that a nearest neighbor search, in the feature space, returns any item with the same label to perform a successful “classification”. This has the benefit that we can capture lots of local patterns with a shallower network, instead of one global pattern in the case of a softmax classification, which likely requires a more complex network. The price of the model is that we cannot simply classify new items with a single forward pass, but the model is probably easier to train because of the relative similarity. Plus, if the similarity of the anchor and the positive sample is higher than the similarity of the anchor and the negative one, with respect to the margin, there is no loss at all and no gradient step is required. In the case of an NLL loss, for example, there is always a gradient step, even if it is very small.

Furthermore, nearest neighbor searches[arxiv:1703.03129] can be done with matrix multiplications that are very efficient these days thanks to modern CPU and GPU architectures. It is also not required to scan the whole dataset, because it is likely that we can perform a bottom-up clustering of the feature space to handle near duplicates.
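For instance, a batched lookup reduces to a single matrix multiplication (a sketch, assuming L2-normalized embeddings):

import torch

def nearest_neighbors(queries, database, k=5):
    # queries: (m, dim), database: (n, dim), both L2-normalized.
    sims = queries @ database.t()           # (m, n) cosine similarity matrix
    scores, indices = sims.topk(k, dim=1)   # top-k neighbors per query
    return scores, indices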

The next question is how to combine a triplet loss with attention to aggregate data from related items, for example items with similar tags, to group embeddings of similar words. We are using the ideas from the post to encode items with a variable number of words into a fixed-length representation.
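A minimal sketch of such an aggregation, assuming each item is a bag of word embeddings; the learned attention vector and the module name are hypothetical, not taken from the original setup:

import torch
from torch import nn

class AttentionPooling(nn.Module):
    # Aggregates a variable number of word embeddings into one fixed-length vector.
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))  # hypothetical learned attention vector

    def forward(self, words):                               # words: (n_words, dim)
        scores = torch.softmax(words @ self.query, dim=0)   # attention weight per word
        return (scores.unsqueeze(1) * words).sum(dim=0)     # (dim,) item embedding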