
Multi-Head Attention

The limitation of single-head attention and how multiple heads overcome it.

The previous chapter introduced single-head attention, where each token examines the available context, measures relevance, and gathers information from other tokens. The representation of each token is now shaped by its context rather than fixed at the start.

However, there's a limitation to what a single attention head can capture, and understanding this limitation explains why transformers use multiple heads running in parallel rather than just one.

The Limitation of One Head

Single-head attention uses one set of projection matrices: W_Q, W_K, and W_V. These transform each token's embedding into a Query, a Key, and a Value. With one W_Q and one W_K, there is only one way to measure similarity between tokens, through the dot product of their projected Query and Key.

The result is one attention distribution based on this single similarity measure. The core issue is that a token often needs to relate to other tokens for several reasons at once, for example to a nearby modifier for syntactic structure and to a distant antecedent for meaning. With one set of projections, these different notions of similarity are forced into a single attention pattern rather than being represented separately.
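To make this concrete, here is a minimal NumPy sketch of single-head scoring. The toy dimensions, random weights, and variable names beyond W_Q and W_K are illustrative choices, not values from a real model:

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8             # toy sizes, not a real model

X   = rng.normal(size=(seq_len, d_model))   # token embeddings
W_Q = rng.normal(size=(d_model, d_k))       # the single Query projection
W_K = rng.normal(size=(d_model, d_k))       # the single Key projection

Q, K = X @ W_Q, X @ W_K
scores = Q @ K.T / np.sqrt(d_k)             # one similarity measure for all purposes
scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(weights.shape)   # (4, 4): exactly one attention pattern per token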

Multi-Head Attention

Multi-head attention addresses this limitation by running multiple attention operations in parallel. Each operation is called a head, and each head has its own independent set of projection matrices: its own W_Q, W_K, and W_V.

With separate projection matrices, each head defines its own way of measuring similarity through its own W_Q and W_K. As a result, tokens can attend to each other based on multiple notions of similarity at once.

Each head computes its own attention distribution, so within each head, weights still sum to 1 and positions still compete for attention weight. But across heads, a position that receives weak attention in one can receive strong attention in another, and the final result combines outputs from all heads.

How It Works

Now that we understand the goal of capturing multiple independent relationships, let's walk through the mechanics of how the model achieves this.

We start by projecting the input X into distinct Query, Key, and Value matrices for each head using independent weights (W_Q^i, W_K^i, W_V^i). Since each head has its own unique projection, it maps the input into a different subspace. This allows each head to define "relevance" in its own way, enabling the model to simultaneously attend to information based on completely different criteria.

Q_i = X × W_Q^i
K_i = X × W_K^i
V_i = X × W_V^i
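A minimal NumPy sketch of this projection step; the toy dimensions, random weights, and the choice to stack the per-head matrices along a head axis are illustrative assumptions (the per-head width d_k is explained later in the chapter):

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, h = 4, 8, 2
d_k = d_model // h                           # per-head width, discussed below

X = rng.normal(size=(seq_len, d_model))

# One independent set of projection weights per head, stacked along a head axis.
W_Q = rng.normal(size=(h, d_model, d_k))
W_K = rng.normal(size=(h, d_model, d_k))
W_V = rng.normal(size=(h, d_model, d_k))

Q = np.einsum('sd,hdk->hsk', X, W_Q)         # Q_i = X × W_Q^i for every head i at once
K = np.einsum('sd,hdk->hsk', X, W_K)
V = np.einsum('sd,hdk->hsk', X, W_V)
print(Q.shape)                               # (2, 4, 4): h separate Query matrices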

With these specific projections, each head runs the standard attention function. It computes similarity scores using its own Queries and Keys, producing a unique attention distribution. All h heads perform this computation simultaneously.

head_i = Attention(Q_i, K_i, V_i)
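A sketch of this per-head computation, assuming Q, K, and V have already been projected to shape (h, seq_len, d_k) as above; the attention helper and the toy shapes are illustrative, with the softmax written in its numerically stable form:

import numpy as np

def attention(Q, K, V):
    # Scaled dot-product attention for a single head.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return weights @ V                             # (seq_len, d_k)

rng = np.random.default_rng(0)
h, seq_len, d_k = 2, 4, 4
Q = rng.normal(size=(h, seq_len, d_k))             # stand-ins for X × W_Q^i, etc.
K = rng.normal(size=(h, seq_len, d_k))
V = rng.normal(size=(h, seq_len, d_k))

heads = [attention(Q[i], K[i], V[i]) for i in range(h)]   # head_i, computed independently
print(heads[0].shape)                                     # (4, 4)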

Finally, the independent streams are merged. Concatenating the outputs from all heads restores the full dimensionality, yet keeps the information from each head isolated in its own segment. The subsequent multiplication with W_O allows these segments to interact. This projection mixes the distinct features discovered by different heads, integrating them into a single, unified representation.

MultiHead(X) = Concat(head_1, ..., head_h) × W_O
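Putting the three steps together, a compact end-to-end sketch; the multi_head_attention helper, toy sizes, and random weights are illustrative, and the per-head width d_k = d_model / h anticipates the sizing discussed next:

import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    h, _, d_k = W_Q.shape
    heads = []
    for i in range(h):                               # each head projects and attends on its own
        Q, K, V = X @ W_Q[i], X @ W_K[i], X @ W_V[i]
        weights = softmax(Q @ K.T / np.sqrt(d_k))    # head-specific attention distribution
        heads.append(weights @ V)
    concat = np.concatenate(heads, axis=-1)          # (seq_len, h * d_k): segments side by side
    return concat @ W_O                              # W_O mixes information across the segments

rng = np.random.default_rng(0)
seq_len, d_model, h = 4, 8, 2
d_k = d_model // h

X   = rng.normal(size=(seq_len, d_model))
W_Q = rng.normal(size=(h, d_model, d_k))
W_K = rng.normal(size=(h, d_model, d_k))
W_V = rng.normal(size=(h, d_model, d_k))
W_O = rng.normal(size=(h * d_k, d_model))

out = multi_head_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape)                                     # (4, 8): back to d_model per token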

If each head operated on the full model dimension d_model, adding more heads would increase the computational cost by a factor of h. To prevent this, standard implementations scale the projected size of each head down: d_k = d_model / h.

For a model with d_model = 768 and 12 heads, this means each head projects into just 64 dimensions. Because the heads are narrower, the total number of parameters across all h heads matches that of a single full-size head, giving us multiple perspectives without a massive increase in compute.

h × (d_model × d_model/h) = d_model × d_model
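A quick check of this identity with the numbers above (d_model = 768, 12 heads), assuming bias-free projections:

d_model, h = 768, 12
d_k = d_model // h                      # 64 dimensions per head

per_head  = d_model * d_k               # one projection matrix for one head: 768 × 64
all_heads = h * per_head                # all 12 heads together
full_size = d_model * d_model           # a single full-width projection

print(d_k, all_heads, full_size)        # 64 589824 589824 -> the counts match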

Summary
  • Single-head attention uses one set of weights, forcing the model to capture only a single type of relationship at a time
  • Multi-head attention projects the input into h independent subspaces, allowing it to attend to multiple distinct patterns simultaneously
  • The independent result vectors from each head are concatenated and projected through W_O to integrate them into a unified representation
  • By splitting the model dimension across heads (d_k = d_model/h), the parameter count and asymptotic compute remain similar to a single full-size head

Each token has now integrated information from its available context. The next chapter introduces the feed-forward network, which processes these enriched vectors independently.