Additive attention in PyTorch – Implementation

Attention mechanisms revolutionized machine learning in applications ranging from NLP through computer vision to reinforcement learning. Attention is the key innovation behind the recent success of Transformer-based language models such as BERT [1]. In this blog post, I will look at the first instance of attention that sparked the revolution: additive attention (also known as Bahdanau attention), proposed by Bahdanau et al. [2]

The idea of attention is quite simple: it boils down to weighted averaging. Let us consider machine translation as an example. When generating a translation of a source text, we first pass the source text through an encoder (an LSTM or an equivalent model) to obtain a sequence of encoder hidden states \mathbf{s}_1, \dots, \mathbf{s}_n. Then, at each step of generating a translation (decoding), we selectively attend to these encoder hidden states, that is, we construct a context vector \mathbf{c}_i that is a weighted average of encoder hidden states:

    \[\mathbf{c}_i = \sum\limits_j a_{ij}\mathbf{s}_j\]

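We choose the weights a_{ij} based on both the encoder hidden states \mathbf{s}_1, \dots, \mathbf{s}_n and the decoder hidden states \mathbf{h}_1, \dots, \mathbf{h}_m, and normalize them so that, for each decoding step i, they form a categorical probability distribution p(\mathbf{s}_j|\mathbf{h}_i) over the source positions j:

    \[a_{ij} = \text{softmax}_j(f_{att}(\mathbf{h}_i, \mathbf{s}_j)) = \frac{\exp(f_{att}(\mathbf{h}_i, \mathbf{s}_j))}{\sum\limits_{k=1}^{n} \exp(f_{att}(\mathbf{h}_i, \mathbf{s}_k))}\]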

Intuitively, this corresponds to assigning each word of a source sentence (encoded as \mathbf{s}_j) a weight a_{ij} that tells how relevant the word encoded by \mathbf{s}_j is for generating the i-th word of a translation (based on \mathbf{h}_i). The weighting function f_{att}(\mathbf{h}_i, \mathbf{s}_j) (also known as the alignment function or score function) is responsible for this credit assignment.
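To make the arithmetic concrete, here is a minimal sketch in plain Python of the two steps described so far: turning scores into weights with a softmax, and averaging the encoder states with those weights. The states and scores are made-up numbers, and the helper names softmax and context_vector are mine, chosen only for illustration.

```python
import math

def softmax(scores):
    """Normalize raw scores into positive weights that sum to 1."""
    m = max(scores)  # subtracting the max keeps exp() numerically stable
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def context_vector(weights, states):
    """Weighted average of encoder states: c_i = sum_j a_ij * s_j."""
    dim = len(states[0])
    return [sum(w * s[k] for w, s in zip(weights, states)) for k in range(dim)]

# Three made-up 2-dimensional encoder states s_1, s_2, s_3.
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Made-up scores f_att(h_i, s_j) for a single decoding step i.
scores = [0.0, 0.0, math.log(2.0)]

weights = softmax(scores)              # approximately [0.25, 0.25, 0.5]
c_i = context_vector(weights, states)  # approximately [0.75, 0.75]
```

The third state gets twice the weight of the other two because its score is higher by log 2, so the context vector is pulled towards it.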

There are many possible implementations of f_{att}, including multiplicative (Luong) attention or key-value attention. In this blog post, I focus on the historically first and arguably the simplest one — additive attention.
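For contrast, here is a rough sketch of a multiplicative score in its bilinear ("general") form, \mathbf{h}_i^\top \mathbf{W} \mathbf{s}_j. The function name multiplicative_scores and the matrix name W are my own choices for illustration, and W is filled with random numbers here, whereas in a real model it would be a learned parameter.

```python
import torch

def multiplicative_scores(query, values, W):
    # query:  [decoder_dim]              decoder hidden state h_i
    # values: [seq_length, encoder_dim]  encoder hidden states s_1..s_n
    # W:      [decoder_dim, encoder_dim] bilinear weight (assumed name)
    # Computes h_i^T W s_j for every j in one matrix product.
    return values @ (W.T @ query)  # [seq_length]

torch.manual_seed(0)
query = torch.rand(50)
values = torch.rand(10, 100)
W = torch.rand(50, 100)

scores = multiplicative_scores(query, values, W)  # [10]
weights = torch.softmax(scores, dim=0)            # attention weights a_ij
```

Unlike additive attention, there is no hidden layer or nonlinearity here: the score is a single bilinear product of the two states.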

Additive attention uses a single-layer feedforward neural network with hyperbolic tangent nonlinearity to compute the scores from which the weights a_{ij} are obtained:

    \[f_{att}(\mathbf{h}_i, \mathbf{s}_j) = \mathbf{v}_a{}^\top \text{tanh}(\mathbf{W}_1 \mathbf{h}_i + \mathbf{W}_2 \mathbf{s}_j)\]

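where \mathbf{W}_1 and \mathbf{W}_2 are the weight matrices of the linear layers applied to the decoder and encoder hidden states, respectively, and \mathbf{v}_a is a learned weight vector that projects the output of the tanh layer to a scalar score.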

PyTorch Implementation of Additive Attention

import torch


class AdditiveAttention(torch.nn.Module):
    def __init__(self, encoder_dim=100, decoder_dim=50):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        # v_a: learned vector that maps each tanh output to a scalar score
        self.v = torch.nn.Parameter(torch.rand(self.decoder_dim))
        # W_1 acts on the decoder state, W_2 on the encoder states
        self.W_1 = torch.nn.Linear(self.decoder_dim, self.decoder_dim)
        self.W_2 = torch.nn.Linear(self.encoder_dim, self.decoder_dim)

    def forward(self,
      query, # [decoder_dim]
      values # [seq_length, encoder_dim]
    ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values # [encoder_dim]

    def _get_weights(self,
      query, # [decoder_dim]
      values # [seq_length, encoder_dim]
    ):
        # Unnormalized scores f_att(h_i, s_j) for every j; forward applies the softmax
        query = query.repeat(values.size(0), 1) # [seq_length, decoder_dim]
        weights = self.W_1(query) + self.W_2(values) # [seq_length, decoder_dim]
        return torch.tanh(weights) @ self.v # [seq_length]

Here _get_weights corresponds to f_{att}, query is a decoder hidden state \mathbf{h}_i and values is a matrix of encoder hidden states \mathbf{s}. To keep the illustration clean, I ignore the batch dimension.
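If you do need the batch dimension, one possible way to add it is sketched below. This is my own variant rather than a canonical implementation: the class name BatchedAdditiveAttention is hypothetical, it assumes all sequences in a batch have the same length (no padding mask), and it returns the weights alongside the context vectors.

```python
import torch

class BatchedAdditiveAttention(torch.nn.Module):
    def __init__(self, encoder_dim=100, decoder_dim=50):
        super().__init__()
        self.v = torch.nn.Parameter(torch.rand(decoder_dim))
        self.W_1 = torch.nn.Linear(decoder_dim, decoder_dim)
        self.W_2 = torch.nn.Linear(encoder_dim, decoder_dim)

    def forward(self, query, values):
        # query:  [batch, decoder_dim]
        # values: [batch, seq_length, encoder_dim]
        # unsqueeze(1) lets the projected query broadcast over seq_length
        hidden = torch.tanh(self.W_1(query).unsqueeze(1) + self.W_2(values))
        scores = hidden @ self.v                 # [batch, seq_length]
        weights = torch.softmax(scores, dim=-1)  # normalize over source positions
        # Batched weighted average of the encoder states
        context = torch.bmm(weights.unsqueeze(1), values).squeeze(1)
        return context, weights                  # [batch, encoder_dim], [batch, seq_length]

attention = BatchedAdditiveAttention(encoder_dim=100, decoder_dim=50)
query = torch.rand(4, 50)
values = torch.rand(4, 10, 100)
context, weights = attention(query, values)
```

Broadcasting replaces the explicit query.repeat call of the unbatched version, and the softmax runs over the last dimension so that each row of weights sums to one.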

In practice, the attention mechanism handles queries at each time step of text generation.
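A rough sketch of what such a decoder could look like is given below. Treat it as an illustration under my own assumptions: the class AttentionDecoder, its step method, the vocabulary and embedding sizes, the start token id 0 and the greedy choice of the next token are all hypothetical. A compact copy of the attention module is repeated so that the snippet runs on its own.

```python
import torch

class AdditiveAttention(torch.nn.Module):
    # Compact copy of the module from this post, repeated for self-containment.
    def __init__(self, encoder_dim=100, decoder_dim=50):
        super().__init__()
        self.v = torch.nn.Parameter(torch.rand(decoder_dim))
        self.W_1 = torch.nn.Linear(decoder_dim, decoder_dim)
        self.W_2 = torch.nn.Linear(encoder_dim, decoder_dim)

    def forward(self, query, values):
        scores = torch.tanh(self.W_1(query) + self.W_2(values)) @ self.v  # [seq_length]
        return torch.softmax(scores, dim=0) @ values  # [encoder_dim]

class AttentionDecoder(torch.nn.Module):
    # Hypothetical decoder; all names and sizes are illustrative assumptions.
    def __init__(self, vocab_size=20, embed_dim=16, encoder_dim=100, decoder_dim=50):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.attention = AdditiveAttention(encoder_dim, decoder_dim)
        self.rnn = torch.nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        self.out = torch.nn.Linear(decoder_dim, vocab_size)

    def step(self, token, h, c, encoder_hidden_states):
        # token: scalar LongTensor; h, c: [decoder_dim]
        context_vector = self.attention(h, encoder_hidden_states)  # [encoder_dim]
        rnn_input = torch.cat([self.embedding(token), context_vector])
        # LSTMCell is called with an explicit batch dimension of size 1
        h, c = self.rnn(rnn_input.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
        h, c = h.squeeze(0), c.squeeze(0)
        return self.out(h), h, c  # logits over the vocabulary, new states

decoder = AttentionDecoder()
encoder_hidden_states = torch.rand(10, 100)
h, c = torch.zeros(50), torch.zeros(50)
token = torch.tensor(0)  # assumed start-of-sequence id
generated = []
for _ in range(5):
    logits, h, c = decoder.step(token, h, c, encoder_hidden_states)
    token = logits.argmax()  # greedy choice of the next token
    generated.append(token.item())
```

The attention module is queried once per generated token, each time with the current decoder state, so every step gets its own context vector.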

In a decoding loop of this kind, context_vector corresponds to \mathbf{c}_i, while h and c are the LSTM's hidden and cell states, which are not crucial for our present purposes.

Finally, it is now trivial to access the attention weights a_{ij} and plot a nice heatmap.

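import matplotlib.pyplot as plt
import seaborn
import torch

attention = AdditiveAttention(encoder_dim=100, decoder_dim=50)
encoder_hidden_states = torch.rand(10, 100)  # 10 source positions
decoder_hidden_states = torch.rand(13, 50)  # 13 decoding steps
weights = torch.zeros(13, 10)  # row i holds a_{i1}, ..., a_{in}
for step in range(decoder_hidden_states.size(0)):
    context_vector = attention(decoder_hidden_states[step], encoder_hidden_states)
    # _get_weights returns unnormalized scores, so apply softmax to obtain a_{ij}
    scores = attention._get_weights(decoder_hidden_states[step], encoder_hidden_states)
    weights[step] = torch.nn.functional.softmax(scores, dim=0)
seaborn.heatmap(weights.detach().numpy())
plt.show()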
A heatmap of attention weights

Here each cell corresponds to a particular attention weight a_{ij}, with rows indexing decoding steps i and columns indexing source positions j. For a trained model and meaningful inputs, we could observe patterns there, such as those reported by Bahdanau et al.: the model learning the differing order of nouns and adjectives in English and French. Let me end with this illustration of the capabilities of additive attention.

Illustration of the capabilities of additive attention

  1. Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. Annual Conference of the North American Chapter of the Association for Computational Linguistics.
  2. Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio (2015). Neural Machine Translation by Jointly Learning to Align and Translate. International Conference on Learning Representations.
