TQ3_0 (TurboQuant 3-bit) KV Cache Quantization for llama.cpp
This is a fork of llama.cpp that adds TQ3_0 — a 3-bit KV cache quantization type implementing Google Research’s TurboQuant pipeline (ICLR 2026), which combines PolarQuant random rotation with QJL residual correction. It achieves near-lossless quality (4.6% PPL degradation) at 4.6× K-cache compression by combining a Walsh-Hadamard rotation with optimal scalar codebook quantization.
What Is TQ3_0?
TQ3_0 (TurboQuant 3-bit, revision 0) implements the TurboQuant pipeline for KV cache compression:
Walsh-Hadamard Rotation — A fast orthogonal transform that Gaussianizes block values
QJL Residual Signs — 1-bit sign of the quantization residual for error correction
Each block of 32 values is stored as:
Field
Size
Description
qs[8]
8 bytes
32 × 2-bit codebook indices (4 per byte)
qr[4]
4 bytes
32 × 1-bit QJL residual signs
gamma
2 bytes (FP16)
Per-block scale factor
Total
14 bytes
= 3.5 bits per value
Why It Works
The key insight from TurboQuant: a random orthogonal rotation makes any input distribution approximately Gaussian (by the Central Limit Theorem). Once Gaussianized, a fixed 4-level codebook quantizer achieves near-optimal MSE without any data-dependent calibration.
We use a per-block Walsh-Hadamard Transform (WHT32) as the rotation:
Deterministic — No stored state, no random seeds, perfectly reproducible
Self-inverse — The same transform is used for both encoding and decoding
O(n log n) — Only 160 add/subtract ops for 32 values (5 butterfly stages)
Fixed sign flips — A random ±1 preconditioning pattern breaks input structure
TQ3_0 is more compact than Q4_0 (14 vs 18 bytes per block) while using a fundamentally different encoding that preserves inner product structure — critical for attention computation.
Quality Benchmarks
All benchmarks on AMD Radeon 8060S (Strix Halo APU, 128GB UMA), ROCm 7.2.
Perplexity (wikitext-2)
KV Cache Type
PPL
Δ from F16
Model
F16 (baseline)
15.49
—
Qwen3.5-0.8B
TQ3_0
16.20
+4.6%
Qwen3.5-0.8B
4.6% perplexity degradation at 4.6× K-cache compression — near-lossless quality.
Throughput — GPU (Radeon 8060S, ROCm 7.2, ngl=99)
K-Cache
Bits/Val
pp512 (t/s)
Δ
tg128 (t/s)
Δ
F16
16.0
7,656
—
181.8
—
Q8_0
8.5
7,626
-0.4%
179.3
-1.4%
Q4_0
4.5
7,320
-4.4%
179.1
-1.5%
TQ3_0
3.5
7,358
-3.9%
177.9
-2.1%
TQ3_0 matches Q4_0 speed while using 22% less memory (14 vs 18 bytes per block).
Flash Attention is auto-disabled when using TQ3_0 K-cache. The system falls back to the standard attention path with dequantize-to-F32 + cuBLAS MUL_MAT. No manual configuration needed.
V-cache can be any supported type (F16, Q8_0, Q4_0, etc.) independently of K-cache.
The --flash-attn auto default works correctly — it detects TQ3_0 and disables FA.
Technical Implementation
Algorithm
The TQ3_0 quantization pipeline implements the TurboQuant paper’s core algorithm at the per-block level:
Forward (quantize):
1. Copy 32 input values to temp buffer
2. Apply random sign flips (fixed ±1 pattern)
3. Walsh-Hadamard butterfly transform (5 stages of add/sub)
4. Normalize by 1/√32
5. Find amax → compute scale d = amax / 1.51
6. Quantize each value to nearest centroid: {-1.51, -0.453, +0.453, +1.51}
7. Pack 2-bit index (4 per byte) → qs[8]
8. Compute residual sign → qr[4]
9. Store scale as FP16 → gamma
Inverse (dequantize):
1. Unpack 2-bit indices → lookup centroid → multiply by scale
2. Walsh-Hadamard butterfly transform (5 stages)
3. Normalize by 1/√32 and undo sign flips
4. Output reconstructed values
CPU WHT + codebook quantize/dequantize + validation
CPU Backend (5 files)
File
Change
ggml/src/ggml-cpu/quants.h
quantize_row_tq3_0 declaration
ggml/src/ggml-cpu/quants.c
quantize_row_tq3_0 wrapper
ggml/src/ggml-cpu/ggml-cpu.c
type_traits_cpu entry for TQ3_0
ggml/src/ggml-cpu/ggml-cpu.cpp
NULL vec_dot guards for MUL_MAT/FA
ggml/src/ggml-cpu/ops.cpp
7 switch-case fallthroughs
GPU Backend (CUDA/HIP)
File
Change
ggml/src/ggml-cuda/common.cuh
ggml_cuda_type_traits<TQ3_0>
ggml/src/ggml-cuda/cpy-utils.cuh
Device WHT32 + quantize_f32_tq3_0_block
ggml/src/ggml-cuda/convert.cu
dequantize_block_tq3_0 with cooperative inverse WHT (shared mem)
ggml/src/ggml-cuda/set-rows.cu
SET_ROWS dispatch for TQ3_0
ggml/src/ggml-cuda/ggml-cuda.cu
MUL_MAT + SET_ROWS support, MMVQ exclusion
ggml/src/ggml-cuda/vecdotq.cuh
vec_dot_tq3_0_q8_1 fused kernel (reserved)
ggml/src/ggml-cuda/mmvq.cu
MMVQ dispatch registration
CLI & Inference Integration (3 files)
File
Change
common/arg.cpp
tq3_0 in the KV cache type allowlist
src/llama-context.cpp
Auto-disable FlashAttention for TQ3_0 K-cache
tools/llama-bench/llama-bench.cpp
tq3_0 in the benchmark type parser
Design Decisions
Per-block WHT (not full-d rotation) — The paper uses a d×d random orthogonal matrix. We approximate this with a per-block 32×32 Walsh-Hadamard Transform. This avoids modifying the attention computation graph while still achieving good quality (+5.8% PPL). A full head_dim rotation would require graph-level query rotation.
Fused MMVQ with WHT on query — Since WHT is orthogonal, dot(q, k) = dot(WHT(q), WHT(k)). Rather than dequantizing K back to original space, we apply WHT to the Q8_1 query values inside the fused vec_dot_tq3_0_q8_1 kernel (int32 butterfly transform), then compute the dot product directly in rotated space. This avoids the dequant+MUL_MAT path, achieving speed parity with Q4_0.
No Flash Attention kernel — FA uses warp-level MMA instructions with tile-based cooperative loading. Integrating WHT inverse into the tiled inner loop would require significant custom kernel code. The standard attention path with fused MMVQ is sufficient (<2% speed difference).
Fixed sign pattern — The ±1 sign flips before WHT use a fixed pattern (not runtime-random). This ensures deterministic behavior without needing to store or communicate random state.
Running the Diff
To see exactly what changed from upstream llama.cpp:
This project is built on top of llama.cpp by Georgi Gerganov and the ggml-org community. All the original llama.cpp documentation, features, and model support remain fully functional. See the upstream README for the full llama.cpp documentation.
llama-turboquant
This is a fork of llama.cpp that adds TQ3_0 — a 3-bit KV cache quantization type implementing Google Research’s TurboQuant pipeline (ICLR 2026), which combines PolarQuant random rotation with QJL residual correction. It achieves near-lossless quality (4.6% PPL degradation) at 4.6× K-cache compression by combining a Walsh-Hadamard rotation with optimal scalar codebook quantization.
What Is TQ3_0?
TQ3_0 (TurboQuant 3-bit, revision 0) implements the TurboQuant pipeline for KV cache compression:
{-1.51, -0.453, +0.453, +1.51}tuned for Gaussian distributionsEach block of 32 values is stored as:
qs[8]qr[4]gammaWhy It Works
The key insight from TurboQuant: a random orthogonal rotation makes any input distribution approximately Gaussian (by the Central Limit Theorem). Once Gaussianized, a fixed 4-level codebook quantizer achieves near-optimal MSE without any data-dependent calibration.
We use a per-block Walsh-Hadamard Transform (WHT32) as the rotation:
Comparison to Other Formats
TQ3_0 is more compact than Q4_0 (14 vs 18 bytes per block) while using a fundamentally different encoding that preserves inner product structure — critical for attention computation.
Quality Benchmarks
All benchmarks on AMD Radeon 8060S (Strix Halo APU, 128GB UMA), ROCm 7.2.
Perplexity (wikitext-2)
Throughput — GPU (Radeon 8060S, ROCm 7.2, ngl=99)
Perplexity (Qwen3.5-0.8B-Q5_K_M, wikitext-2, 10 chunks)
Where TQ3_0 Matters Most
TQ3_0’s primary benefit is memory savings, enabling longer contexts and more concurrent sessions:
Building
TQ3_0 is compiled into both the CUDA (NVIDIA) and HIP (AMD) backends automatically. No extra flags needed.
AMD GPU (ROCm / HIP)
Common
AMDGPU_TARGETSvalues: | GPU Family | Target | |———–|——–| | Strix Halo (8060S) |gfx1151| | RDNA 3 (RX 7900 XTX) |gfx1100| | RDNA 2 (RX 6900 XT) |gfx1030| | MI300X |gfx942|NVIDIA GPU (CUDA)
CPU-only (fallback)
TQ3_0 works on CPU via dequantization to F32 (slower than GPU but functional).
Usage
Basic Inference
Server Mode
Benchmarking
Important Notes
--flash-attn autodefault works correctly — it detects TQ3_0 and disables FA.Technical Implementation
Algorithm
The TQ3_0 quantization pipeline implements the TurboQuant paper’s core algorithm at the per-block level:
Files Modified
Core Type System (5 files)
ggml/include/ggml.hGGML_TYPE_TQ3_0 = 41in the enumggml/src/ggml-common.hblock_tq3_0struct (14 bytes: qs[8] + qr[4] + gamma)ggml/src/ggml.cggml/src/ggml-quants.hggml/src/ggml-quants.cCPU Backend (5 files)
ggml/src/ggml-cpu/quants.hquantize_row_tq3_0declarationggml/src/ggml-cpu/quants.cquantize_row_tq3_0wrapperggml/src/ggml-cpu/ggml-cpu.ctype_traits_cpuentry for TQ3_0ggml/src/ggml-cpu/ggml-cpu.cppvec_dotguards for MUL_MAT/FAggml/src/ggml-cpu/ops.cppGPU Backend (CUDA/HIP)
ggml/src/ggml-cuda/common.cuhggml_cuda_type_traits<TQ3_0>ggml/src/ggml-cuda/cpy-utils.cuhquantize_f32_tq3_0_blockggml/src/ggml-cuda/convert.cudequantize_block_tq3_0with cooperative inverse WHT (shared mem)ggml/src/ggml-cuda/set-rows.cuggml/src/ggml-cuda/ggml-cuda.cuggml/src/ggml-cuda/vecdotq.cuhvec_dot_tq3_0_q8_1fused kernel (reserved)ggml/src/ggml-cuda/mmvq.cuCLI & Inference Integration (3 files)
common/arg.cpptq3_0in the KV cache type allowlistsrc/llama-context.cpptools/llama-bench/llama-bench.cpptq3_0in the benchmark type parserDesign Decisions
Per-block WHT (not full-d rotation) — The paper uses a d×d random orthogonal matrix. We approximate this with a per-block 32×32 Walsh-Hadamard Transform. This avoids modifying the attention computation graph while still achieving good quality (+5.8% PPL). A full head_dim rotation would require graph-level query rotation.
Fused MMVQ with WHT on query — Since WHT is orthogonal,
dot(q, k) = dot(WHT(q), WHT(k)). Rather than dequantizing K back to original space, we apply WHT to the Q8_1 query values inside the fusedvec_dot_tq3_0_q8_1kernel (int32 butterfly transform), then compute the dot product directly in rotated space. This avoids the dequant+MUL_MAT path, achieving speed parity with Q4_0.No Flash Attention kernel — FA uses warp-level MMA instructions with tile-based cooperative loading. Integrating WHT inverse into the tiled inner loop would require significant custom kernel code. The standard attention path with fused MMVQ is sufficient (<2% speed difference).
Fixed sign pattern — The ±1 sign flips before WHT use a fixed pattern (not runtime-random). This ensures deterministic behavior without needing to store or communicate random state.
Running the Diff
To see exactly what changed from upstream llama.cpp:
References
Credits
This project is built on top of llama.cpp by Georgi Gerganov and the ggml-org community. All the original llama.cpp documentation, features, and model support remain fully functional. See the upstream README for the full llama.cpp documentation.
License: MIT (same as upstream llama.cpp)