Attention¶
Kernel Regression¶
The Nadaraya-Watson estimator is a fundamental nonparametric regression technique that uses kernel functions to weight observations based on their proximity to the query point. This approach can be viewed as an early form of attention mechanism, where the model "attends" to training examples based on their relevance to the current query.
Given a dataset of input-output pairs \(\{(\mathbf{x}_i, y_i)\}_{i=1}^n\), the Nadaraya-Watson estimator predicts the output at a new query point \(\mathbf{q}\) as:

\[
\hat{y}(\mathbf{q}) = \frac{\sum_{i=1}^n K(\mathbf{q}, \mathbf{x}_i)\, y_i}{\sum_{j=1}^n K(\mathbf{q}, \mathbf{x}_j)},
\]
where \(K(\mathbf{q}, \mathbf{x}_i)\) is a kernel function that measures the similarity between the query point \(\mathbf{q}\) and the training point \(\mathbf{x}_i\).
Some common kernels are the Gaussian kernel \(K(\mathbf{q}, \mathbf{x}) = \exp\left(-\tfrac{1}{2}\|\mathbf{q} - \mathbf{x}\|^2\right)\), the boxcar kernel \(K(\mathbf{q}, \mathbf{x}) = \mathbf{1}\left(\|\mathbf{q} - \mathbf{x}\| \le 1\right)\), and the Epanechnikov kernel \(K(\mathbf{q}, \mathbf{x}) = \max\left(0, 1 - \|\mathbf{q} - \mathbf{x}\|\right)\).
This can be rewritten in a form that highlights its connection to attention mechanisms:

\[
\hat{y}(\mathbf{q}) = \sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{x}_i)\, y_i,
\]
where the attention weights \(\alpha(\mathbf{q}, \mathbf{x}_i)\) are defined as:

\[
\alpha(\mathbf{q}, \mathbf{x}_i) = \frac{K(\mathbf{q}, \mathbf{x}_i)}{\sum_{j=1}^n K(\mathbf{q}, \mathbf{x}_j)}.
\]
These attention weights satisfy \(\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{x}_i) = 1\), forming a probability distribution over the training points.
Let us consider the Gaussian kernel:

\[
K(\mathbf{q}, \mathbf{x}_i) = \exp\left(-\frac{1}{2}\|\mathbf{q} - \mathbf{x}_i\|^2\right).
\]
Then the kernel-weighted estimator becomes a softmax-weighted sum:

\[
\hat{y}(\mathbf{q}) = \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}\|\mathbf{q} - \mathbf{x}_i\|^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}\|\mathbf{q} - \mathbf{x}_j\|^2\right)}\, y_i.
\]
This means that if a key \(\mathbf{x}_i\) is close to the query \(\mathbf{q}\), then it receives a larger weight \(\alpha(\mathbf{q}, \mathbf{x}_i)\), i.e., the corresponding output \(y_i\) contributes more to the prediction.
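To make this concrete, here is a minimal sketch of the Nadaraya-Watson estimator with the Gaussian kernel in PyTorch (the function name nadaraya_watson and the toy data are our own, purely for illustration):

import torch

def nadaraya_watson(queries, x_train, y_train):
    # queries: (m, d), x_train: (n, d), y_train: (n,)
    # Squared Euclidean distances between each query and each training point.
    sq_dists = torch.cdist(queries, x_train) ** 2      # Shape: (m, n)
    # Gaussian-kernel attention weights: softmax of -0.5 * distance^2 over training points.
    weights = torch.softmax(-0.5 * sq_dists, dim=-1)   # Shape: (m, n), rows sum to 1
    # Weighted average of the training outputs.
    return weights @ y_train                           # Shape: (m,)

# Toy usage: estimate a noisy sine curve.
x_train = torch.rand(50, 1) * 5
y_train = torch.sin(x_train).squeeze(-1) + 0.1 * torch.randn(50)
queries = torch.linspace(0, 5, 100).unsqueeze(-1)
y_hat = nadaraya_watson(queries, x_train, y_train)     # Shape: (100,)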
We can generalize the Gaussian kernel by replacing the negative squared distance with a general similarity score:

\[
\hat{y}(\mathbf{q}) = \sum_{i=1}^n \frac{\exp\big(a(\mathbf{q}, \mathbf{k}_i)\big)}{\sum_{j=1}^n \exp\big(a(\mathbf{q}, \mathbf{k}_j)\big)}\, y_i,
\]
where \(a(\mathbf{q}, \mathbf{k}_i)\) is a similarity function between the query \(\mathbf{q}\) and the key \(\mathbf{k}_i\). One of the most common choices is the dot product:

\[
a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i.
\]
This leads to the attention mechanism.
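To see the mechanism in isolation, here is a tiny sketch of attention pooling with dot-product scores for a single query (all tensors are random and purely illustrative):

import torch

q = torch.randn(4)                    # a single query
keys = torch.randn(10, 4)             # keys k_1, ..., k_10
values = torch.randn(10, 4)           # values v_1, ..., v_10

scores = keys @ q                     # a(q, k_i) = q^T k_i, shape (10,)
alpha = torch.softmax(scores, dim=0)  # attention weights, sum to 1
output = alpha @ values               # sum_i alpha(q, k_i) * v_i, shape (4,)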
Attention Mechanism¶
Given an input matrix \(X\), there are three weight matrices \(W_q, W_k, W_v\) to learn such that

\[
Q = X W_q, \qquad K = X W_k, \qquad V = X W_v.
\]
So the attention mechanism is:

\[
\mathrm{Attention}(\mathbf{q}, K, V) = \sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i,
\]
where \(\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i)}{\sum_{j=1}^n \exp(\mathbf{q}^\top \mathbf{k}_j)}\) is the attention weight.
In attention, the query matches all keys softly, each with a weight between 0 and 1. The values associated with the keys are multiplied by these weights and summed.
So to summarize the self-attention mechanism with the input \(X\), we have the following steps:
- Project the input \(X\) to three matrices \(Q, K, V\) with weight matrices \(W_q, W_k, W_v\).
- Compute the attention weights \(\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i)}{\sum_{j=1}^n \exp(\mathbf{q}^\top \mathbf{k}_j)}\) for each query \(\mathbf{q}\) and each key \(\mathbf{k}_i\).
- Multiply the values \(\mathbf{v}_i\) by the attention weights and sum them up to get the output.
In summary, for the full matrices \(Q, K, V\) we have

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V,
\]

where the softmax is applied row-wise and the scaling factor \(\sqrt{d}\) (with \(d\) the dimension of the queries and keys) keeps the dot products from growing too large; this scaled dot-product form is what the implementation below uses.
PyTorch Implementation of Self-Attention¶
We can implement the self-attention mechanism in PyTorch as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        """
        Args:
            embed_dim (int): Dimensionality of the input embeddings (and output dimensions of the linear transforms).
        """
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        # Linear projections for queries, keys, and values.
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, X):
        # Compute linear projections (queries, keys, values).
        Q = self.W_q(X)  # Shape: (batch_size, sequence_length, embed_dim)
        K = self.W_k(X)  # Shape: (batch_size, sequence_length, embed_dim)
        V = self.W_v(X)  # Shape: (batch_size, sequence_length, embed_dim)
        # Compute scaled dot-product attention scores.
        scores = torch.matmul(Q, K.transpose(-2, -1))  # Shape: (batch_size, sequence_length, sequence_length)
        scores = scores / math.sqrt(self.embed_dim)
        attention_weights = F.softmax(scores, dim=-1)  # Shape: (batch_size, sequence_length, sequence_length)
        # Multiply the attention weights with the values to get the final output.
        output = torch.matmul(attention_weights, V)  # Shape: (batch_size, sequence_length, embed_dim)
        return output
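As a quick sanity check, the module above might be used as follows (the batch size, sequence length, and embedding dimension are arbitrary):

X = torch.randn(2, 5, 16)        # (batch_size, sequence_length, embed_dim)
attention = SelfAttention(embed_dim=16)
output = attention(X)
print(output.shape)              # torch.Size([2, 5, 16])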
Multi-Head Attention¶
Attention treats each word’s representation as a query to access and incorporate information from a set of values. Attention is parallelizable across positions and avoids the bottleneck of compressing an entire sequence into a single fixed-length vector, as in recurrent encoder-decoder models.
Multi-head attention extends the basic attention mechanism by allowing the model to jointly attend to information from different representation subspaces at different positions. Instead of performing a single attention function with \(d\)-dimensional keys, values, and queries, multi-head attention performs the attention function in parallel \(h\) times, with different, learned linear projections to \(d_k\), \(d_k\), and \(d_v\) dimensions. These parallel attention outputs, or "heads," are then concatenated and linearly transformed to produce the final output. This approach enables the model to capture different aspects of the input sequence simultaneously.
Specifically, for each head \(i\), we have linear weights \(W_q^{(i)}, W_k^{(i)}, W_v^{(i)}\) that map the input \(X\) to the head:

\[
\mathrm{head}_i = \mathrm{Attention}\big(X W_q^{(i)},\, X W_k^{(i)},\, X W_v^{(i)}\big).
\]
Then the multi-head attention output is obtained by concatenating all the heads and projecting them to the output space with a weight matrix \(W_o\):

\[
\mathrm{MultiHead}(X) = \mathrm{Concat}\big(\mathrm{head}_1, \ldots, \mathrm{head}_h\big)\, W_o.
\]
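Below is a minimal from-scratch sketch of this construction, assuming embed_dim is divisible by num_heads and fusing the per-head projections \(W_q^{(i)}, W_k^{(i)}, W_v^{(i)}\) into single linear layers (the class name MultiHeadSelfAttention is ours, for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Fused projections for all heads (equivalent to stacking the per-head W_q^(i), W_k^(i), W_v^(i)).
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_o = nn.Linear(embed_dim, embed_dim, bias=False)  # output projection W_o

    def forward(self, X):
        batch_size, seq_len, embed_dim = X.shape
        # Project and split into heads: (batch_size, num_heads, seq_len, head_dim).
        def split(t):
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Q, K, V = split(self.W_q(X)), split(self.W_k(X)), split(self.W_v(X))
        # Scaled dot-product attention within each head.
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        heads = attention_weights @ V  # (batch_size, num_heads, seq_len, head_dim)
        # Concatenate the heads and apply the output projection.
        concat = heads.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.W_o(concat)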
PyTorch has a built-in module for the multi-head attention mechanism, torch.nn.MultiheadAttention(embed_dim, num_heads), where embed_dim is the dimension of the input embeddings and num_heads is the number of heads.
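For self-attention, the same tensor is passed as query, key, and value; for example (with batch_first=True so inputs have shape (batch, sequence, embed_dim)):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
X = torch.randn(2, 5, 16)            # (batch_size, sequence_length, embed_dim)
output, attn_weights = mha(X, X, X)  # self-attention: query = key = value = X
print(output.shape)                  # torch.Size([2, 5, 16])
print(attn_weights.shape)            # torch.Size([2, 5, 5]), averaged over heads by default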