From Heads to Factors: A Deep Dive into Tensor Product Attention and the T6 Transformer
During generation a Transformer must keep a key and a value for every head, in every layer, for every past token: a memory bill that rises linearly with context length.
1 Why the KV Cache Became Public Enemy #1
Modern language models rarely choke on parameter counts; they choke on context.
During autoregressive generation the model must keep two tensors—keys and values—for every past token so that future queries can look back.
With H heads of size (d_{h}) and a sequence of T tokens, the cached key–value footprint for one layer is

\[\text{Memory}_{\text{MHA}} \;=\; 2\,H\,d_{h}\,T . \tag{1}\]
On an A100 GPU you can exhaust 40 GB of RAM long before you fill the compute budget.
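To see how fast that happens, here is a back-of-the-envelope calculation in Python. The layer count, head count, head dimension and fp16 precision are illustrative LLaMA-7B-style assumptions, not figures from the paper.

```python
# Rough KV-cache budget for a LLaMA-7B-style model (assumed dims, fp16).
layers, heads, d_h, bytes_per_value = 32, 32, 128, 2

per_token = 2 * layers * heads * d_h * bytes_per_value   # keys + values, all layers
print(per_token)                     # 524288 bytes ≈ 0.5 MB per cached token
print(per_token * 32_000 / 2**30)    # ≈ 15.6 GiB for a single 32k-token prompt
```

Add the model weights and a couple of such prompts in flight, and a 40 GB card is full while the arithmetic units sit mostly idle.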
Earlier fixes (Multi-Query, Grouped-Query and Multi-head Latent Attention) claw back memory by sharing or compressing keys and values, at some cost in modelling power.
Tensor Product Attention (TPA), introduced in “Tensor Product Attention Is All You Need” (2025), instead factorises each token's queries, keys and values into low-rank components computed from that token's hidden state, keeping almost all capacity while paying only a fraction of the memory bill.
2 A One-Equation Recap of Multi-Head Attention
For a single head the scaled-dot-product attention is
\[\mathrm{Attn}(Q,K,V) \;=\; \operatorname{softmax}\!\left(\frac{QK^{\!\top}}{\sqrt{d_{h}}}\right)V . \tag{2}\]

Multi-Head Attention (MHA) simply repeats that calculation H times with independent projections: expressive, but memory-hungry via Equation (1).
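As a reference point, Equation (2) is a few lines of PyTorch (a minimal single-head sketch, ignoring masking and batching):

```python
import torch
import torch.nn.functional as F

def single_head_attention(Q, K, V):
    """Scaled dot-product attention, Eq. (2). Q, K, V: (T, d_h) tensors."""
    d_h = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_h ** 0.5   # (T, T) similarity matrix
    return F.softmax(scores, dim=-1) @ V            # weighted sum of values
```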
3 Tensor Product Attention — Compress Keys, Keep Power
A rank-1 outer product; TPA stores a sum of such products rather than a full matrix.
Instead of caching a full H × d_h key matrix for every token, TPA writes it as a low-rank sum of outer products
\[K \;=\; \sum_{r=1}^{R} a_{r}\;\otimes\;b_{r}^{\!\top}, \tag{3}\]where
- (a_{r}\in\mathbb{R}^{H}) are head-factors,
- (b_{r}\in\mathbb{R}^{d_{h}}) are token-factors,
- (R \ll H) is the chosen rank.
The cache now needs only the skinny vectors (a_{r}, b_{r}), so per-token memory becomes
\[\boxed{\text{Memory}_{\text{TPA}} = 2\,R\,(H+d_{h})\,T}, \qquad R \ll H . \tag{4}\]

With LLaMA-style numbers ((H{=}16, d_{h}{=}128)) and a key/value rank of (R{=}2), the per-token key–value slice shrinks from (2\cdot 16\cdot 128 = 4096) numbers to (2\cdot 2\cdot(16+128) = 576), roughly a 7 × saving; the sketch below makes the bookkeeping concrete.
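A minimal sketch of that bookkeeping, using random stand-in factors where TPA would use small learned projections of each token's hidden state (all sizes are illustrative):

```python
import torch

H, d_h, R, T = 16, 128, 2, 4096   # heads, head dim, rank, cached tokens

# Per-token factors; in TPA these come from linear projections of the
# token's hidden state, here they are random stand-ins.
a = torch.randn(T, R, H)          # head factors  a_r in R^H
b = torch.randn(T, R, d_h)        # token factors b_r in R^{d_h}

# Reconstruct the full per-token key matrices K_t = sum_r a_r (x) b_r  (Eq. 3)
K = torch.einsum('trh,trd->thd', a, b)       # (T, H, d_h)

numbers_mha = 2 * H * d_h * T                # Eq. (1): cache full K and V
numbers_tpa = 2 * R * (H + d_h) * T          # Eq. (4): cache only the factors
print(numbers_mha / numbers_tpa)             # ≈ 7.1 for R = 2
```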
A Spectrum, Not a Point
| Configure the factors like… | …and you recover |
|---|---|
| (R=H,\;a_{r}=e_{r}) | Multi-Head Attention |
| (R=1,\;a_{1}=\mathbf{1}) | Multi-Query Attention |
| Repeating blocks in (a_{r}) | Grouped-Query Attention |
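A toy example of the first two rows, with tiny illustrative sizes:

```python
import torch

H, d_h = 4, 8                                  # tiny illustrative sizes
b = torch.randn(H, d_h)                        # token factors, one row per rank

# R = 1, a_1 = 1 (all-ones head factor): every head sees the same key row,
# i.e. Multi-Query Attention.
K_mqa = torch.ones(H, 1) @ b[:1]               # (H, d_h), identical rows

# R = H, a_r = e_r (one-hot head factors): each head keeps its own key row,
# i.e. standard Multi-Head Attention.
K_mha = torch.eye(H) @ b                       # (H, d_h), independent rows

assert torch.allclose(K_mqa[0], K_mqa[-1])     # MQA: all head rows coincide
```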
4 Position Embeddings That Stay Low-Rank
Rotary Position Embeddings (RoPE) encode position by rotating query and key vectors in 2-D coordinate pairs, so attention scores depend only on relative offsets.
Because each rotation acts block-diagonally on the (d_{h}) dimension, it can be folded straight into the token-factors (b_{r}) of Equation (3), so TPA stays compressed, unlike earlier low-rank schemes that need an uncompressed “RoPE slice.”
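A quick numerical check of that claim, using a simplified half-split rotary rotation (the `rope` helper below is a sketch, not the paper's implementation):

```python
import torch

def rope(x, pos, base=10000.0):
    """Rotate the last dimension of x (..., d_h) by position-dependent angles."""
    half = x.shape[-1] // 2
    freqs = base ** (-torch.arange(half, dtype=x.dtype) / half)
    angle = pos * freqs
    cos, sin = torch.cos(angle), torch.sin(angle)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

H, d_h, R, pos = 16, 128, 2, 42
a, b = torch.randn(R, H), torch.randn(R, d_h)

K = torch.einsum('rh,rd->hd', a, b)                           # Eq. (3)
rotated_full    = rope(K, pos)                                # rotate the full H x d_h key
rotated_factors = torch.einsum('rh,rd->hd', a, rope(b, pos))  # rotate only the b_r factors

assert torch.allclose(rotated_full, rotated_factors, atol=1e-5)
```

Because the rotation only mixes coordinates within the (d_{h}) axis, it commutes with the sum over ranks, which is exactly why the cache can stay factorised.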
5 T6: The First Transformer Built on TPA
The authors swap TPA blocks into a LLaMA-like architecture (norm-first, SwiGLU MLPs) and call the model T6.
| Model | Params | Context length | Val. ppl ↓ | KV memory |
|---|---|---|---|---|
| Baseline MHA | 353 M | 8 k | 5.78 | 1 × |
| T6 (TPA) | 353 M | 16 k | 5.52 | ≈ 0.12 × |
Zero-shot evaluation (ARC, BoolQ, HellaSwag…) shows consistent +1–2 % accuracy over MHA, MQA, GQA and MLA.
Because the KV cache is slimmer, the same A100 now handles 65 k-token contexts at inference with speed almost identical to MQA.
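Plugging a 65 k-token context into Equations (1) and (4) shows why; the head count, head dimension, rank and fp16 precision below are the illustrative values from Section 3, not the exact T6-353M configuration:

```python
# Per-layer KV cache at a 65k-token context (fp16, illustrative dims).
H, d_h, R, T, bytes_per_value = 16, 128, 2, 65_536, 2

mha = 2 * H * d_h * T * bytes_per_value        # Eq. (1)
tpa = 2 * R * (H + d_h) * T * bytes_per_value  # Eq. (4)
print(mha / 2**20, tpa / 2**20)                # ≈ 512 MiB vs ≈ 72 MiB per layer
```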
6 What Practitioners Can Do Right Now
- Serve longer prompts on fixed hardware → factorise K & V at a small rank (e.g. 2), fine-tune briefly, and fit roughly 7–8 × longer contexts in the same cache budget.
- Train mid-size models faster → full TPA converges in fewer steps than MHA for the same FLOPs.
- Edge/mobile → pair low-rank TPA with 4-bit weight quantisation; both parameters and cache fit on laptop-class GPUs.
7 Limitations & Open Questions
- GPU kernels are ready; CPU/Metal back-ends lag.
- Rank (R) is chosen empirically (the paper settles on small single-digit ranks); auto-tuning it is open research.
- Ultra-long contexts (>100 k) may need ALiBi or gradient-injection tricks.
- Vision/audio Transformers still untested with TPA.
8 Conclusion — A Smarter Coordinate System for Attention
Tensor Product Attention reframes the long-context dilemma: re-encode information in a basis where it’s cheaper to store.
Because MHA, MQA and GQA are limiting cases, TPA offers a tunable continuum rather than discrete compromises.
The T6 experiments show you can have low perplexity and huge context windows and small caches—all on today’s hardware.
Reference
Zhang et al. (2025). Tensor Product Attention Is All You Need. arXiv:2501.06425.
https://arxiv.org/abs/2501.06425