BahdanauAttention

A PyTorch attention module, part of the torch.nn module of the indicLP library.

Getting Started

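A Bahdanau (additive) attention layer has been implemented for PyTorch, since attention plays a central role in many NLP models, such as encoder-decoder models for machine translation. The layer scores every encoder output against a query (typically the current decoder state), normalizes the scores with a softmax, and returns the weighted sum of the encoder outputs as a context vector.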
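As a reading aid, the computation performed by the example class can be written out as follows. The notation is ours rather than the library's: q is the query (decoder state), h_i is the i-th row of values (encoder outputs), n is the sequence length, W_1, W_2 and v correspond to W1, W2 and V in the code, and b_1, b_2 are the biases that torch.nn.Linear adds by default. Here W_1 is d_dec × d_dec, W_2 is d_dec × d_enc and v has d_dec entries.

```latex
\begin{aligned}
e_i &= v^\top \tanh\left(W_1 q + b_1 + W_2 h_i + b_2\right), \qquad i = 1, \dots, n \\
\alpha_i &= \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)} \\
c &= \sum_{i=1}^{n} \alpha_i h_i
\end{aligned}
```

The scores e_i are what _get_weights returns, the α_i are the softmax output in forward, and c is the context vector that forward returns.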

Example
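import torch


class BahdanauAttention(torch.nn.Module):
    """Additive (Bahdanau) attention over a single, unbatched sequence.

    query:  (decoder_dim,) decoder state
    values: (seq_len, encoder_dim) encoder outputs
    returns a context vector of shape (encoder_dim,)
    """

    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        # Scoring vector v that collapses each hidden state to a scalar score
        self.V = torch.nn.Parameter(torch.rand(self.decoder_dim))
        # Projections of the query (W1) and of the encoder outputs (W2)
        self.W1 = torch.nn.Linear(self.decoder_dim, self.decoder_dim)
        self.W2 = torch.nn.Linear(self.encoder_dim, self.decoder_dim)

    def forward(self, query, values):
        # One unnormalized score per encoder position: shape (seq_len,)
        weights = self._get_weights(query, values)
        # Normalize the scores over the sequence dimension
        weights = torch.nn.functional.softmax(weights, dim=0)
        # Weighted sum of encoder outputs: (seq_len,) @ (seq_len, encoder_dim)
        return weights @ values

    def _get_weights(self, query, values):
        # Tile the query so it lines up with every encoder position
        query = query.repeat(values.size(0), 1)
        weights = self.W1(query) + self.W2(values)
        return torch.tanh(weights) @ self.V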
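The example above processes one unbatched sequence at a time. If batches with padding are needed, a minimal sketch might look like the following. BatchedBahdanauAttention and its mask argument are hypothetical names chosen for illustration and are not part of indicLP; the sketch assumes query has shape (batch, decoder_dim), values has shape (batch, seq_len, encoder_dim), and mask is a boolean tensor of shape (batch, seq_len) that is True at real (non-padded) positions.

```python
import torch


class BatchedBahdanauAttention(torch.nn.Module):
    """Hypothetical batched additive attention (illustrative, not part of indicLP)."""

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.W1 = torch.nn.Linear(decoder_dim, decoder_dim)
        self.W2 = torch.nn.Linear(encoder_dim, decoder_dim)
        self.V = torch.nn.Parameter(torch.rand(decoder_dim))

    def forward(self, query, values, mask=None):
        # (batch, 1, decoder_dim) broadcasts against (batch, seq_len, decoder_dim)
        hidden = torch.tanh(self.W1(query).unsqueeze(1) + self.W2(values))
        # Matrix-vector product over the last dim gives scores of shape (batch, seq_len)
        scores = hidden @ self.V
        if mask is not None:
            # Padded positions get -inf so the softmax assigns them zero weight
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)  # (batch, seq_len)
        # (batch, 1, seq_len) @ (batch, seq_len, encoder_dim) -> (batch, encoder_dim)
        context = torch.bmm(weights.unsqueeze(1), values).squeeze(1)
        return context, weights


torch.manual_seed(0)
attn = BatchedBahdanauAttention(encoder_dim=8, decoder_dim=4)
query = torch.randn(2, 4)
values = torch.randn(2, 5, 8)
# Second sequence has only 3 real tokens; the last 2 positions are padding
mask = torch.tensor([[True] * 5, [True, True, True, False, False]])
context, weights = attn(query, values, mask)
print(context.shape)  # torch.Size([2, 8])
print(weights.shape)  # torch.Size([2, 5])
```

Broadcasting with unsqueeze(1) takes the place of the explicit repeat in the unbatched version, and padded positions receive exactly zero weight. Every row of mask should contain at least one True entry; a fully masked row makes the softmax produce NaN.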

Reference Materials

The following are some reference materials for the BahdanauAttention module: