From Heads to Factors: A Deep Dive into Tensor Product Attention and the T6 Transformer

During decoding, a Transformer must cache a key–value pair for every head, every layer and every past token: a memory bill that grows linearly with context length.

(Figure: self-attention overview.)

1 Why the KV Cache Became Public Enemy #1

Modern language models rarely choke on parameter counts; they choke on context.
During autoregressive generation the model must keep two tensors—keys and values—for every past token so that future queries can look back.
With \(H\) heads of size \(d_{h}\) and a sequence of \(T\) tokens, the cached footprint for one layer is

\[\boxed{\text{Memory}_{\text{MHA}} = 2\,H\,d_{h}\,T} \tag{1}\]
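
A quick back-of-the-envelope check of Equation (1). The layer count, head sizes and context length below are illustrative LLaMA-7B-like numbers chosen for the example, not figures from the paper:

```python
# KV-cache size per Equation (1), summed over layers, assuming fp16 storage
# and a single sequence (batch size 1).
def kv_cache_bytes(num_layers: int, num_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """2 * H * d_h * T cached numbers (keys + values) per layer, in bytes."""
    return num_layers * 2 * num_heads * head_dim * seq_len * bytes_per_elem

# Illustrative shape: 32 layers, 32 heads of size 128, 128k-token context.
total = kv_cache_bytes(num_layers=32, num_heads=32, head_dim=128, seq_len=128_000)
print(f"{total / 2**30:.1f} GiB")   # 62.5 GiB for one sequence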

On an A100 GPU you can exhaust 40 GB of device memory long before you exhaust the compute budget.
Earlier fixes (Multi-Query, Grouped-Query and Multi-head Latent Attention) claw back memory by sharing or summarising keys and values, at some cost in modelling power.
Tensor Product Attention (TPA), introduced in “Tensor Product Attention Is All You Need” (2025), compresses each of Q, K and V with a low-rank factorisation, keeping almost all of the modelling capacity while paying only a fraction of the memory bill.


2 A One-Equation Recap of Multi-Head Attention

For a single head, scaled dot-product attention is

\[\mathrm{Attn}(Q,K,V) \;=\; \operatorname{softmax}\!\left(\frac{QK^{\!\top}}{\sqrt{d_{h}}}\right)V . \tag{2}\]

Multi-Head Attention (MHA) simply repeats that calculation \(H\) times with independent projections: expressive, but memory-hungry, as Equation (1) shows.
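
For reference, here is Equation (2) in a few lines of PyTorch, a minimal single-head sketch rather than an optimised kernel:

```python
import torch
import torch.nn.functional as F

def single_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Equation (2): softmax(Q K^T / sqrt(d_h)) V for one head.
    q, k, v have shape (seq_len, d_h)."""
    d_h = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_h ** 0.5
    return F.softmax(scores, dim=-1) @ v

# MHA runs this H times with independent projections of the same input
# and concatenates the H outputs.
```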


3 Tensor Product Attention — Compress Keys, Keep Power

(Figure: outer-product factorisation. A rank-1 outer product; TPA stores a sum of such products rather than a full matrix.)

Instead of caching a full \(H \times d_{h}\) key matrix for every token, TPA writes it as a low-rank sum of outer products

\[K \;=\; \sum_{r=1}^{R} a_{r}\;\otimes\;b_{r}^{\!\top}, \tag{3}\]

where \(a_{r} \in \mathbb{R}^{H}\) mixes the heads, \(b_{r} \in \mathbb{R}^{d_{h}}\) carries the token's content, and \(R\) is the chosen rank; values are factorised the same way.

The cache now needs only the skinny vectors \(a_{r}\) and \(b_{r}\), so per-token memory becomes

\[\boxed{\text{Memory}_{\text{TPA}} = 2\,R\,(H+d_{h})\,T}, \qquad R \ll H . \tag{4}\]

With LLaMA-style dimensions (\(H{=}16,\; d_{h}{=}128\)) and a key/value rank of \(R{=}2\), Equation (4) stores \(2\cdot 2\cdot(16+128)=576\) numbers per token instead of \(2\cdot 16\cdot 128=4096\), shrinking the key–value slice roughly 7×.
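
Here is a small sketch of what the cache actually holds under this factorisation, using the illustrative shapes above (PyTorch; not the paper's implementation):

```python
import torch

H, d_h, R = 16, 128, 2           # heads, head dimension, key/value rank (illustrative)

# Per-token factors that live in the cache (keys shown; values are analogous).
a = torch.randn(R, H)            # a_r in R^H    : head-mixing factors
b = torch.randn(R, d_h)          # b_r in R^{d_h}: token-content factors

# Equation (3): rebuild the full H x d_h key slice on the fly when attending.
K = torch.einsum('rh,rd->hd', a, b)     # sum_r  a_r ⊗ b_r
assert K.shape == (H, d_h)

# Cached numbers per token, Equation (1) vs Equation (4).
print(2 * H * d_h)          # 4096 for full multi-head keys + values
print(2 * R * (H + d_h))    # 576 for the TPA factors
```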

A Spectrum, Not a Point

| Configure the factors like… | …and you recover |
|-----------------------------|------------------|
| \(R=H,\; a_{r}=e_{r}\) | Multi-Head Attention |
| \(R=1,\; a_{1}=\mathbf{1}\) | Multi-Query Attention |
| Repeating blocks in \(a_{r}\) | Grouped-Query Attention |
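
The two extreme rows of the table are easy to verify directly. A toy check, continuing the sketch above:

```python
import torch

H, d_h = 4, 8

# MHA limit: R = H and a_r = e_r (one-hot), so each head keeps its own key row.
b = torch.randn(H, d_h)                              # one b_r per head
K_mha = torch.einsum('rh,rd->hd', torch.eye(H), b)
assert torch.allclose(K_mha, b)

# MQA limit: R = 1 and a_1 = 1 (all ones), so every head shares one key row.
b_shared = torch.randn(1, d_h)
K_mqa = torch.einsum('rh,rd->hd', torch.ones(1, H), b_shared)
assert torch.allclose(K_mqa, b_shared.expand(H, -1))
```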

4 Position Embeddings That Stay Low-Rank

Rotary Position Embeddings (RoPE) encode position by rotating queries and keys in paired two-dimensional subspaces, so attention scores depend on relative offsets rather than absolute positions.
Because that rotation is block-diagonal and linear along the head dimension, it can be applied directly to the \(b_{r}\) factors in Equation (3), so TPA stays compressed, unlike earlier low-rank schemes that needed an uncompressed “RoPE slice.”
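
A minimal sketch of why this works: RoPE is a linear, position-dependent rotation of the \(d_{h}\) axis, so rotating the cached \(b_{r}\) factors yields the same keys as rotating the reconstructed matrix. The helper below is a standard "rotate-half" RoPE written for this example, not the paper's code:

```python
import torch

def rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Standard 'rotate-half' rotary embedding applied to the last dimension."""
    half = x.shape[-1] // 2
    freqs = base ** (-torch.arange(half, dtype=x.dtype) / half)
    cos, sin = torch.cos(pos * freqs), torch.sin(pos * freqs)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

H, d_h, R, pos = 16, 128, 2, 42
a, b = torch.randn(R, H), torch.randn(R, d_h)

rotated_then_summed = torch.einsum('rh,rd->hd', a, rope(b, pos))    # rotate the factors
summed_then_rotated = rope(torch.einsum('rh,rd->hd', a, b), pos)    # rotate the full keys
assert torch.allclose(rotated_then_summed, summed_then_rotated, atol=1e-5)
```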


5 T6: The First Transformer Built on TPA

(Figure: the T6 transformer stack.)

The authors swap TPA blocks into a LLaMA-like architecture (norm-first, SwiGLU MLPs) and call the model T6.

| Model | Params | Context length | Val. ppl ↓ | KV memory |
|-------|--------|----------------|------------|-----------|
| Baseline MHA | 353 M | 8 k | 5.78 | 1 × |
| T6 (TPA) | 353 M | 16 k | 5.52 | ≈ 0.12 × |

Zero-shot evaluation (ARC, BoolQ, HellaSwag…) shows consistent +1–2 % accuracy over MHA, MQA, GQA and MLA.
Because the KV cache is slimmer, the same A100 now handles 65 k-token contexts at inference with speed almost identical to MQA.


6 What Practitioners Can Do Right Now


7 Limitations & Open Questions


8 Conclusion — A Smarter Coordinate System for Attention

Tensor Product Attention reframes the long-context dilemma: re-encode information in a basis where it’s cheaper to store.
Because MHA, MQA and GQA are limiting cases, TPA offers a tunable continuum rather than discrete compromises.
The T6 experiments show you can have low perplexity and huge context windows and small caches—all on today’s hardware.


Reference

Zhang et al. (2025). Tensor Product Attention Is All You Need. arXiv:2501.06425.
https://arxiv.org/abs/2501.06425