Flaming Winners

Recently, we read a paper that mentioned winner-takes-all (WTA) circuits, and since we moved from Theano to PyTorch, we wanted to give the idea a try. This type of layer is similar to maxout, but instead of reducing the output dimensionality, the dimensionality is kept and the losing units are filled with zeros. A layer thus consists of groups of neurons, and in each group only the “fittest” neuron survives, while the others are set to zero. For example, let’s assume that we have 128 output neurons and they should form 32 groups with 4 units each. In PyTorch, the projection is just a linear layer: wta = nn.Linear(dim_in, 32*4). Next comes the implementation of the forward step, which is straightforward. We assume that the shape of the input tensor is (batch_size, dim_in).

def forward(self, input):
    h = wta(input)                 # projection: (batch, 32*4)
    h = h.view(-1, 32, 4)          # reshape: (batch, 32, 4)
    val, _ = h.max(2)              # maximal value per group: (batch, 32)
    val = val[:, :, None]          # reshape: (batch, 32, 1)
    pre = val * (h >= val).float() # zero out all losers, winners keep their value
    return pre.view(-1, 32*4)      # reshape: (batch, 32*4)

That’s it. Definitely not rocket science, just a bit of juggling with the shapes of the tensors.
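To see the layer in action, here is a minimal self-contained sketch. The class name WTALayer, the input dimension of 64, and the batch size are our own choices for illustration; the forward pass is the same as above, with the linear layer stored on the module.

```python
import torch
import torch.nn as nn

class WTALayer(nn.Module):
    """Winner-takes-all layer: `groups` groups of `units` neurons each;
    in every group, only the unit with the maximal value survives."""

    def __init__(self, dim_in, groups=32, units=4):
        super().__init__()
        self.groups, self.units = groups, units
        self.wta = nn.Linear(dim_in, groups * units)

    def forward(self, input):
        h = self.wta(input)                      # (batch, groups*units)
        h = h.view(-1, self.groups, self.units)  # (batch, groups, units)
        val, _ = h.max(2)                        # winner per group: (batch, groups)
        val = val[:, :, None]                    # (batch, groups, 1)
        pre = val * (h >= val).float()           # losers -> 0, winner keeps its value
        return pre.view(-1, self.groups * self.units)

layer = WTALayer(dim_in=64)
out = layer(torch.randn(8, 64))
# Each group of 4 should contain at most one non-zero entry
# (exactly one, unless the winning value happens to be 0 or tied).
nonzero_per_group = (out.view(8, 32, 4) != 0).sum(2)
print(out.shape, nonzero_per_group.max().item())
```

Note that `.float()` is used instead of `.type(torch.FloatTensor)`, so the mask stays on whatever device the input lives on.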
