Recently, we read a paper that also mentioned winner-takes-all (WTA) circuits, and since we moved from Theano to PyTorch, we wanted to give the idea a try. This type of unit is similar to maxout, but instead of reducing the output dimension, the dimension is kept and the losing entries are filled with zeros. Thus, a layer consists of groups of neurons, and in each group only the “fittest” neuron survives, while the others are set to zero. For example, let’s assume that we have 128 neurons that should form 32 groups with 4 units each. In PyTorch this is done with a linear layer: wta = nn.Linear(dim_in, 32*4). Next comes the implementation of the forward step, which is straightforward. We assume that the input tensor has the shape (batch_size, dim_in).
def forward(self, input):
    h = wta(input)  # projection: (batch, 32*4)
    h = h.view(-1, 32, 4)  # reshape: (batch, 32, 4)
    val, _ = h.max(2)  # maximum value per group: (batch, 32)
    val = val[:, :, None]  # reshape for broadcasting: (batch, 32, 1)
    pre = val * (h >= val).float()  # zero out every non-maximal entry per group
    return pre.view(-1, 32*4)  # reshape back: (batch, 32*4)
That’s it. Definitely not rocket science, just a bit of juggling with tensor shapes. One small caveat: if two units in a group happen to share the exact maximum, both survive, but with continuous activations this practically never happens.
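To make the snippet above self-contained, here is a minimal runnable sketch that wraps the forward step in an nn.Module. The class name WTALayer, the input dimension, and the batch size are our choices for illustration; the group and unit counts match the 32×4 example from the text:

```python
import torch
import torch.nn as nn

class WTALayer(nn.Module):
    """Winner-takes-all layer: groups of units where only the group maximum survives."""
    def __init__(self, dim_in, groups=32, units=4):
        super().__init__()
        self.groups = groups
        self.units = units
        self.wta = nn.Linear(dim_in, groups * units)

    def forward(self, input):
        h = self.wta(input)                       # projection: (batch, groups*units)
        h = h.view(-1, self.groups, self.units)   # reshape: (batch, groups, units)
        val, _ = h.max(2)                         # maximum per group: (batch, groups)
        val = val[:, :, None]                     # (batch, groups, 1) for broadcasting
        pre = val * (h >= val).float()            # zero out every non-maximal entry
        return pre.view(-1, self.groups * self.units)

# sanity check: exactly one nonzero entry per group (barring exact ties)
layer = WTALayer(dim_in=64)
out = layer(torch.randn(8, 64))
nonzero_per_group = (out.view(8, 32, 4) != 0).sum(2)
```

Note that the winner keeps its actual (possibly negative) activation value; the layer only suppresses the other units in the group, it does not rectify.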