目录
SageAttention_MXMACA

SageAttention — MetaX MXMACA 平台移植版 (v3.0.0)

本项目是 SageAttentionNVIDIA CUDA → MetaX MXMACA 的完整平台移植,针对 MetaX C500/C550/C600 GPU 深度优化。

移植范围: SageAttention2++ 核心注意力管线
目标硬件: MXC500 系列 (xcore1000) / MXC600 系列 (xcore1500)
参考论文: SageAttention2++ (arXiv 2505.21136)


目录

  1. 项目背景与上游关系
  2. CUDA → MXMACA 功能对比矩阵
  3. 已移植功能详解
  4. MXMACA 架构优化
  5. 迁移技术难点与解决方案
  6. 沐曦硬件局限性分析
  7. 后续规划
  8. 开发反馈
  9. 安装
  10. 使用方法
  11. 测试
  12. 文档导航
  13. 引用

项目背景与上游关系

上游 SageAttention 是一系列高效的量化注意力算子,包含:

版本 量化方案 目标平台 论文
SageAttention (v1) INT8 QK × FP16 PV Ampere+ GPUs ICLR 2025
SageAttention2 INT4/INT8 QK × FP16 PV Ampere+ GPUs ICML 2025
SageAttention2++ INT4/INT8 QK × FP8/FP16 PV Ada/Hopper GPUs arXiv 2025
SageAttention3 NVFP4 QK × FP8 PV Blackwell GPUs NeurIPS 2025 Spotlight

本移植版基于 SageAttention2++ 架构(约 3,600 行 kernel + 调度层代码),覆盖 INT8 QK 量化 × FP16 PV 累加的全部 6 条计算路径,以及 per-thread / per-block 两种量化粒度。

不可移植的上游功能(受沐曦硬件限制)

上游功能 不可移植原因 依赖的 NVIDIA 特性
SageAttention2 INT4 QK 量化 MXMACA WMMA 不支持 INT4 MMA mma.sync.m16n16k64.s4.s4.s32
SageAttention2++ FP8 PV 量化 MXMACA 无 FP8 数据类型 mma.sync.m16n16k32.f8.f8.f32 (e4m3)
SageAttention2++ sm90 Hopper 路径 WGMMA/TMA 不可用 wgmma.mma_async、TMA bulk copy、mbarrier
SageAttention3 NVFP4 注意力 无 NVFP4 张量核 Blackwell 专用硬件特性
SM89 Ada per_channel_fp8 量化 无 FP8 类型 e4m3x4 转换指令

CUDA → MXMACA 功能对比矩阵

核心注意力管线对比

功能模块 NVIDIA CUDA (SM80/SM89/SM90) MXMACA (本移植版) 移植状态
QK 量化类型 INT8 / INT4 (SageAttention2) INT8 only ✅ INT8 完整
QK 量化粒度 per_warp / per_thread / per_block per_thread (mcTriton) / per_block (MACA kernel) / per_warp → per_block 降级 ✅ 两种粒度可用
PV 数据类型 FP16 / FP8 (e4m3) FP16 only ✅ FP16 完整
PV 累加器 FP32 / FP16 / FP16+FP32 (InstBuf) / FP32+FP32 (sm89) FP32 / FP16 / FP16+FP16 (InstBuf) ✅ 三种累加路径
V 均值融合 (smooth_v) ✅ (含 InstBuf + fuse_v_mean 组合) ✅ 完整
K 平滑 (smooth_k) ✅ (含 LSE 修正) ✅ 完整
GQA ✅ 任意比例 ✅ (已测 1:1 ~ 64:1) ✅ 完整
因果掩码 ✅ (后两个 CTA_K 迭代块应用) ✅ 完整
HND/NHD 布局 ✅ 完整
LSE 返回 ✅ 完整
head_dim=32 填充 ✅ (自动填充至 64) ✅ 完整
BF16 输入 ✅ (__nv_bfloat16) ✅ (maca_bfloat16, 含软件转换路径) ✅ 完整
torch.compile ✅ (torch.library.custom_op + fake tensor) ✅ 完整
变长序列 (varlen) Triton backend only ✅ (逐批次循环 + mcTriton fallback) ✅ 可用
SageAttention3 ✅ NVFP4 QK × FP8 PV ❌ 完全不可用 ❌ 硬件限制

底层技术与 ISA 对比

技术层 NVIDIA CUDA MXMACA (本移植版采用的替换方案)
MMA 指令 mma.sync PTX (直接) wmma::fragment + wmma::load_matrix_sync + wmma::mma_sync API
矩阵加载 ldmatrix.sync.aligned.x4 PTX wmma::load_matrix_sync()
异步拷贝 cp.async.ca.shared.global PTX → commit_group / wait_group __builtin_mxc_load_global_async128__builtin_mxc_arrive_gvmcnt(N)
WGMMA/TMA wgmma.mma_async + CUtensorMap (Hopper) 无等效 — 使用标准 WMMA m16n16
Wave/Warp 大小 32 lanes 64 lanes (MXMACA Runtime API §2.2)
Block 配置 dim3(32, num_warps) (32-thread warp) dim3(64, num_warps/2) (64-lane wave, 两个 MMA 子组)
Shuffle __shfl_xor_sync(0xffffffff, val, off, 32) __shfl_xor_sync(0xFFFFFFFFFFFFFFFFULL, val, off, 64)
整数 Reduction Butterfly shuffle 纯软件 __reduce_add_sync / __reduce_max_sync (MACA §2.21 原生指令)
共享内存 Swizzle XOR 模式 + __stcg 缓存写 XOR 模式软件实现(无 __stcg,普通写)
数学函数 __exp2f / __log2f / __tanhf / hfma2 __expf/__log2f(float), h2exp2/hexp2(half), tanh 为 float roundtrip
编译器 nvcc (-arch=sm_80/89/90) mxcc (--offload-arch=xcore1000/xcore1500, -x maca)
运行时 API cudaSetDevice / cudaStream_t / cudaEvent_t mcSetDevice / mcStream_t / mcEvent_t
Kernel 属性 cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) mcFuncSetAttribute(..., mcFuncAttributeMaxDynamicSharedMemorySize, ...)
设备检测 at::kCUDA `at::kCUDA

已移植功能详解

6 条计算路径

所有路径共享 int8 QK × fp16 SV 基本架构,在累加器类型和平滑选项上区分:

# 路径名称 QK 类型 SV 累加器 V 均值融合 InstBuf 典型精度
1 accum_f32 INT8 FP32 高精度基线
2 accum_f32_fuse_v_mean INT8 FP32 高精度 + 平滑
3 accum_f16_fuse_v_mean_inst_buf INT8 FP16→FP32 SageAttention2++ 最高吞吐
4 accum_fp32 (per_thread) INT8 FP32 高精度 + 细粒度量化
5 accum_fp32 (per_block) INT8 FP32 高精度 + 块量化
6 accum_fp16 (InstBuf, per_thread) INT8 FP16→FP32 吞吐优先

注意力 kernel (csrc/qattn/qk_int_sv_f16_maca.cu, 1,343 行)

// 核心模板签名 (简化)
template<
    uint32_t CTA_Q=128, CTA_K=64, WARP_Q=32, WARP_K=64, head_dim,
    DataType DTypeQK=kInt8,
    QuantGranularity Q_GRAN, K_GRAN,
    DTypeSVAccum,           // float 或 half
    bool use_inst_buffer,   // FP16+FP16 双级累加
    DTypeOut,               // half 或 maca_bfloat16
    ComputeUnit DU,         // kTensorCore
    MaskMode,               // kNone 或 kCausal
    bool return_lse,
    bool fuse_v_mean>

计算流水线 (双缓冲异步 + 在线 softmax):

  1. Q 预加载: __builtin_mxc_load_global_async128 将 INT8 Q 从 global → smem;若 num_tiles_qk_inner=1 则将全部 Q tile 预加载至寄存器
  2. K/V 双缓冲: 两个 smem buffer 交替加载 K 和 V,commit → arrive → sync 与 MMA 计算重叠
  3. QK INT8 MMA: wmma::load_matrix_sync 从 XOR swizzled smem 加载 → wmma::mma_sync(m16n16k32, s8s8s32) → INT32 中间结果
  4. 在线 Softmax (Base-2): INT32→float 转换 → sm_scale * log2(e) 缩放 → 64-lane warp reduce 求 max → update_mdo (指数缩放 + 累加)
  5. SV FP16 MMA: wmma::load_matrix_sync 加载 V → wmma::mma_sync(m16n16k16, f16f16f32) (或 f16f16f16 for InstBuf)
  6. InstBuf 双级累加: FP16 MMA 中间结果存 RO_inst_buf[half]→ 逐元素 float 累加至 RO → 提升吞吐不损失精度
  7. 分母归一化: blockReduce 收集 dnormalize_d (用 __frcp_rn 近似倒数)
  8. 可选出: BF16 转换 (half2→float2→bfloat162), V mean 加回, LSE 计算, swizzled smem→global 写回

启动配置:

  • Block: dim3(64, (CTA_Q/WARP_Q) × (CTA_K/WARP_K)) = dim3(64, 4×1) = 256 threads
  • Grid: (div_ceil(qo_len, 128), num_qo_heads, batch_size)
  • 动态共享内存: max(CTA_Q×HD×sizeof(int8) + 2×CTA_K×HD×sizeof(int8+half), CTA_Q×HD×sizeof(half))

量化 kernel (csrc/fused/fused_maca.cu, 932 行)

Kernel 功能 备注
QuantInt8Kernel INT8 量化: absmax → float_to_int8_rn(x * 127/amax), 支持 sm_scale 预乘和 sub_mean 融合 block size 64/128
SubMeanKernel V 均值减法: __hsub2 (half2) 或 float roundtrip (bf16) 用于 smooth_v
TransposePadPermuteKernel 转置 + 填充 + 序列维度重排 (0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15) 为 FP8 MMA 设计,当前保留供未来扩展

mcTriton 后端 (8 个 kernel)

全部 8 个 Triton attention/quantization kernel 已适配 mcTriton (Triton 3.0.0 + MXMACA backend):

  • Q/K per-block INT8 量化 (定长 + 变长)
  • Q/K per-thread INT8 量化
  • 注意力前向 (非因果 + 因果)
  • 变长注意力 (非因果 + 因果)

适配要点: pipeline="cpasync", scenario="flashattn-fwd" 自动调优指令,pipeline="basic" for quant kernels,tl.interleave 替代 tl.cat (Triton 3.0 兼容性)。

Python 调度层 (sageattention/core.py, 644 行)

sageattn(q, k, v, ...)
├── is_metax_device() && is_maca_kernel_available()  → sageattn_qk_int8_pv_fp16_maca()
│   ├── quant: per_thread_int8_triton / per_block_int8_cuda
│   └── attn:  maca_compile → _qattn_maca C++ kernel (3 个入口点)
├── is_metax_device() && _TRITON_AVAILABLE           → sageattn_qk_int8_pv_fp16_triton()
│   └── quant + attn: mcTriton kernel
├── device.type == 'cuda'                            → 上游 CUDA 路径
└── else: RuntimeError

MXMACA 架构优化

本移植版针对 MXMACA 硬件特点做了以下原生优化:

优化项 技术细节 来源
64-lane Wave 映射 Block dim3(64, N/2), 每个 64-lane wave 并行驱动两个 32-lane WMMA 子组 — 相比朴素的 dim3(32, N) 映射提升 2× 吞吐 csrc/utils.cuh:10-26
INT8 WMMA wmma::fragment<...int8_t...> + wmma::mma_sync() (M16N16K32) csrc/mma_maca.cuh
FP16 WMMA wmma::fragment<...half...> + wmma::mma_sync() (M16N16K16), m16n16k16 由两个 m16n8k16 拼接 (N 维拆分) csrc/mma_maca.cuh:139-166
FP16→FP16 WMMA wmma::mma_sync() with __half accumulator — MACA §2.24.4 原生支持 csrc/mma_maca.cuh:169-182
async copy 流水线 __builtin_mxc_load_global_async128() 128-bit 异步加载 → __builtin_mxc_arrive_gvmcnt(N) 事务到达计数 → 双缓冲与 MMA 重叠 csrc/cp_async_maca.cuh
共享内存 Swizzle XOR 软件 swizzle (k32B/k64B/k128B), 消除 bank conflict csrc/permuted_smem_maca.cuh
Base-2 Softmax sm_scale *= log2(e), 使用 __expf(float) / h2exp2/hexp2(half) 硬件指令 csrc/math_maca.cuh
原生 Reduction 整数类型使用 MACA __reduce_add_sync / __reduce_max_sync (§2.21) 替代 butterfly shuffle csrc/reduction_utils_maca.cuh:46-51

迁移技术难点与解决方案

以下记录了从 NVIDIA CUDA 到 MXMACA 的完整迁移过程中遇到的关键技术挑战、分析过程和最终方案。

1. WMMA 接口层: PTX mma.sync → API wmma::mma_sync

挑战: 上游 NVIDIA kernel 使用 PTX 级 mma.sync.aligned.m16n16k32.row.col.s8.s8.s32 等指令直接操作寄存器片段 (uint32_t[4] / uint32_t[8])。MXMACA 不暴露等价的 PTX 级 MMA 指令。

分析过程: 查阅 MXMACA C++ 编程指南 §2.24,确认 MXMACA 提供 wmma:: namespace API 作为标准 MMA 接口。需要将 PTX 的原始 uint32 寄存器片段映射到 WMMA fragment 抽象。

解决方案 (在 csrc/mma_maca.cuh 中实现):

  • 针对每种数据类型和形状构造 wmma::fragment<> 对象
  • 提供 extract_fragment_data(frag, uint32_array) 进行 fragment → 原始寄存器的字节拷贝,保持与上游 kernel 的寄存器布局兼容
  • mma_sync_m16n16k16_row_col_f16f16f32 通过两次 m16n8k16 WMMA 调用拼接 N 维(与 NVIDIA mma.cuh 中的实现方式相同)
  • FP16 累加器的 mma_sync_m16n8k16_f16f16f16 路径经 MXMACA §2.24.4 确认原生支持
  • BF16 WMMA (mma.sync.m16n16k16.bf16.bf16.f32) 通过 wmma::fragment + maca_bfloat16 类型实现

2. 矩阵加载: PTX ldmatrixwmma::load_matrix_sync

挑战: NVIDIA kernel 大量使用 ldmatrix.sync.aligned.x4.trans.shared.b16 PTX 指令从共享内存加载 16×16 矩阵片段。MXMACA 没有等价 PTX 指令。

分析过程: 对比 MXMACA C++ 编程指南 §2.24 中的 WMMA API,wmma::load_matrix_sync() 是标准矩阵加载接口。但该 API 的正确使用前提是共享内存地址需满足 swizzle 对齐要求。

解决方案:

  • 统一使用 wmma::load_matrix_sync(fragment, smem_ptr, ldm) 替换所有 ldmatrix 调用
  • INT8 WMMA 加载要求 swizzle_mode == k128B (static_assert 保护 at mma_maca.cuh:55)
  • WMMA 自动处理行列转换,消除了 NVIDIA 版本的 ldmatrix_m8n8x4 vs ldmatrix_m8n8x4_trans 显式转换需求
  • 由于 load_matrix_sync 接受 ldm (leading dimension) 参数,不再需要手算 stride — 简化了共享内存访问逻辑

3. 异步数据传输: PTX cp.async__builtin_mxc_load_global_async128

挑战: NVIDIA kernel 的双缓冲异步流水线依赖 cp.async.ca.shared.global PTX + commit_group / wait_group 机制。MXMACA 提供不同的异步加载原语。

分析过程: 研究了 MXMACA Runtime API 编程指南和头文件,发现 MXMACA 提供:

  • __builtin_mxc_load_global_async128 — 128-bit 异步全局→寄存器加载
  • __builtin_mxc_arrive_gvmcnt(N) — global memory transaction count 到达同步
  • __builtin_mxc_arrive_bsmcnt(0) — shared memory (barrier) transaction count 到达同步

需要注意 MXMACA 会自动追踪所有异步加载事务(无需显式 commit),这与 CUDA 的 cp.async.commit_group 不同。

解决方案 (在 csrc/cp_async_maca.cuh 中实现):

  • commit_group()空操作 (MXMACA 硬件自动追踪)
  • wait_group<n>()__builtin_mxc_arrive_gvmcnt(n) (等待 ≤ n 个事务未完成)
  • load_128b(smem, gmem)__builtin_mxc_load_global_async128(gmem, smem) → 写入 uint4*
  • pred_load_128b(smem, gmem, pred) → 带零填充的谓词异步加载

重要设计决策 (2026-06-24): 保留 __builtin_mxc_* 底层原语而非迁移到 cooperative_groups::memcpy_async,因为:

  1. CG 提供批量拷贝 (collective) 而非 per-thread 128-bit 加载 — 与 kernel 的 per-lane 寻址模式不匹配
  2. CG 的 wait(group) 在计算前序列化所有加载 — 破坏双缓冲异步流水线
  3. CG 无谓词零填充等价功能
  4. 当前 namespace async 已提供清晰抽象层

4. 64-lane Wave vs 32-lane Warp

挑战: MXMACA 硬件 wave 为 64 lanes (warpSize = 64, Runtime API §2.2),而 NVIDIA warp 为 32。所有依赖 warpSize 的代码(shuffle、reduction、lane 标识)需要重评估。

分析过程: MXMACA 的 WMMA 在内部将 64-lane wave 拆分为两个 32-lane MMA 子组。因此 MMA 级别的 mma_lane_id 可以保持 32-lane 语义,但 warp 级 reduction 必须扩展至 64 lanes。

解决方案:

  • Block 配置: dim3(64, num_warps/2) 替代 dim3(32, num_warps) — 64 个 threadIdx.x 对应 64 lanes
  • Shuffle 掩码: __shfl_xor_sync(kMacaFullMask=0xFFFFFFFFFFFFFFFFULL, val, offset, 64) 替代 warpSize=32
  • Reduction butterfly: 需要 6 步 (64→32→16→8→4→2→1) 替代 5 步 (32→16→8→4→2→1)
  • 整数 reduction 优势: MACA 提供原生 __reduce_add_sync / __reduce_max_sync (§2.21) 用于 int32/uint32, 比手工 butterfly 更高效
  • MMA lane group 常量: kMmaLaneGroupSize = 32, kMmaLaneMask = 0xFFFFFFFF 用于 MMA 子组操作

5. 共享内存 Bank Conflict: 软件 Swizzle 实现

挑战: NVIDIA 版本通过 permuted_smem.cuh 中的 XOR swizzle 消除共享内存 bank conflict。MXMACA 需要等效的软件实现,且避免某些特殊指令。

分析过程: XOR swizzle 的核心逻辑是 offset = row * stride + xor(col, row_mod_pattern)。跨架构可移植。主要风险是 swizzle 粒度选择 (k32B/k64B/k128B) 和 MXMACA 共享内存的 bank 结构匹配。

解决方案 (在 csrc/permuted_smem_maca.cuh 中实现):

  • 三种粒度: k32B (直接映射,stride=2), k64B (XOR col with (row/2)%4, stride=4), k128B (XOR col with row%8, stride≥8)
  • advance_offset_by_column<step> / advance_offset_by_row<step> 正确处理 XOR 反转 (k128B 上步长 1/2/4 需要特殊 XOR 推进逻辑)
  • 编译兼容: cucc 无法解析模板成员函数 → 改为 free-standing template 函数 (所有 advance_offset_*, load_128b_*, store_128b_*)
  • __stcg (store cache-global): 在 MACA 上不可用 → store_128b_cg__MACA__ 下回退为 plain store

6. 数学函数适配

挑战: NVIDIA kernel 使用大量 half2/half 原生 PTX 数学指令。MXMACA 的函数库覆盖范围不同。

分析: 逐函数验证 MXMACA 支持情况。

函数 NVIDIA MXMACA 适配方式
__expf / __log2f (float) 直接映射
h2exp2 / hexp2 (half) 直接映射
__frcp_rn (float) 直接映射
__frsqrt_rn (float) 直接映射
__hadd2 / __hmul2 (half2) 直接映射
__hfma2 (half2 FMA) 直接映射
__float2half2_rn 直接映射
__float22half2_rn 直接映射
tanh (half/half2) ✅ (fast intrinsic) ❌ v3.7.0.x 无 half tanh 快速指令 float roundtrip via tanhf()
__exp2f (float) ✅ (PTX) ❌ 无独立 __exp2f intrinsic __expf(x * 0.6931471805599453f) 等价替代

7. BF16 支持: maca_bfloat16 类型系统

挑战: NVIDIA 使用 __nv_bfloat16 / __nv_bfloat162, MXMACA 使用 maca_bfloat16 / maca_bfloat162。转换语义和 WMMA 加载方式不同。

解决方案:

  • DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16 宏映射 at::kBFloat16maca_bfloat16 (而非 __nv_bfloat16)
  • kernel 输出路径: half2float2 (通过 unpack_half2_from_uint32_to_float) → bfloat162 (通过 __float2bfloat16_rn)
  • Python 调度层: 若输入 V 为 BF16,在进入 kernel 前转为 FP16 (因当前 PV WMMA 仅使用 FP16 Tensor Core, BF16 WMMA 存但未在 attention 路径上启用)
  • BF16 WMMA 验证: test_mxmac_api_validation.py 中包含 mma.sync.m16n16k16.bf16.bf16.f32 正确性测试

8. mxcc 编译器适配

挑战: nvcc → mxcc 的编译选项体系差异显著,需要全新的构建系统配置。

解决方案 (在 setup.py 中实现):

  • 架构检测: _detect_mxcc_arch() — 通过 SAGEATTN_OFFLOAD_ARCH 环境变量 / mxcc --list-offload-arch / maca_version.h 解析三层回退机制,映射 C500→xcore1000, C600→xcore1500
  • 编译标志: -x maca (指定 MXMACA 语言模式), --offload-arch=xcore1000, -O2 (从 -O1 升级于 2025-06-24), --maxrregcount=96, -use-fast-math
  • Include 路径: $MACA_PATH/include, $MACA_PATH/mxgpu_llvm/include, $MACA_PATH/tools/cu-bridge/include
  • 设备链接: -fgpu-rdc 启用可重定位设备码(对应 nvcc 的 -rdc=true

9. 量化粒度降级: per_warp → per_block

挑战: MACA WMMA 的 mma.sync API 不支持 warp 级子 tile 量化粒度 (在单个 warp 内进一步分块)。上游 per_warp_int8 依赖此特性。

分析: 查阅 mma_maca.cuh 中的 WMMA fragment 定义和 MXMACA WMMA 指令限制,确认 WMMA 的最小操作单元为 16×16 tile,在 warp 级 (32 MMA lanes) 无法再细分。

解决方案 (在 sageattention/core.py:265 中实现):

  • qk_quant_gran="per_warp" 时,Python 调度层自动降级为 per_block_int8_cuda()
  • per_warp_int8 CUDA kernel 保留在 fused_maca.cu 中但未被调度器调用 — 作为未来硬件的参考实现

10. Varlen 逐批次循环的性能限制

当前实现: _sageattn_varlen_maca() 对每个 batch 元素逐一切片、重排、调用单序列注意力 kernel。每个 batch 产生一次独立 kernel launch。

已知性能瓶颈: Python 层 per-batch 排列开销 + per-batch kernel launch 延迟在大量小序列时累积显著。

记录为 TODO: sageattention/core.py:514-522 — 计划使用 MACA streams 实现 batch 间并行调度,或下沉为 batched varlen C++ kernel。


沐曦硬件局限性分析

以下基于 MXMACA SDK 3.7.0.x 和 MXC500/C600 硬件规格进行的系统评估:

当前硬件限制

限制项 影响的 SageAttention 功能 上游实现 MACA 替代方案
无 INT4 WMMA SageAttention2 INT4 QK mma.sync.m16n16k64.s4.s4.s32 INT8 QK 量化 (4× 更大内存 vs INT4)
无 FP8 数据类型 SageAttention2++ FP8 PV mma.sync.m16n16k32.f8.f8.f32 + per_channel_fp8 量化 FP16 PV (2× 更大内存 vs FP8)
无 WGMMA Hopper 大 tile MMA (m64n128k32) wgmma.mma_async 标准 WMMA m16n16 (更小 tile, 更多迭代)
无 TMA Hopper 硬件批量异步拷贝 + 自动 swizzle CUtensorMap + cp.async.bulk per-thread 128-bit async load (更高指令开销)
__stcg Cache-global store (L2 bypass hint) __stcg PTX 普通 global store (可能 L2 污染)
无 half tanh Softmax 的近似 tanh 快速路径 __tanhf half2 intrinsic float roundtrip via tanhf() (精度无损, 延迟略增)
SageAttention3 全部 NVFP4 QK × FP8 PV Blackwell NVFP4 TC + FP8 WGMMA 完全不可用

性能影响评估

  • INT8 QK vs INT4 QK: QK 矩阵内存占用为 NVIDIA INT4 路径的 2× (每个元素 1B vs 0.5B)。在极长序列场景下,内存带宽瓶颈更早到达。
  • FP16 PV vs FP8 PV: PV 矩阵内存为 NVIDIA FP8 路径的 2×。FP16 WMMA (M16N16K16) 的 K 维度是 FP8 WMMA (M16N16K32) 的一半,需更多 smem 迭代。
  • WMMA m16n16 vs WGMMA m64n128: tile 小 32× → 迭代次数多 32× → 寄存器压力和指令发射压力更高,但前端延迟更小。
  • per-thread 128b load vs TMA: 无 TMA 的硬件 swizzle 和 descriptor 开销节省,但每个线程的 load 指令数增加。

后续规划

P0 — 近期高优先级

项目 说明
Varlen batch 并行化 当前逐 batch 循环为 Python 层开销 + kernel launch 延迟瓶颈。方案: (1) MACA streams batch 间并行, (2) 实现真正的 batched varlen C++ kernel
InstBuf 全路径覆盖 当前 InstBuf 仅覆盖 fuse_v_mean 路径。补充 InstBuf + 其他配置组合,使最高吞吐路径覆盖所有使用场景

P1 — 性能调优

项目 说明
QK INT8 WMMA 数据复用 num_tiles_qk_inner == 1 时 Q 已预加载至寄存器 — 探索对其他 head_dim/WARP 配置启用此路径
mcTriton 后端性能调优 Profile mcTriton 后端延迟,优化 block size 和流水线策略。部分配置下 mcTriton 路径存在衰退
smem 大小调优 当前 smem 按最坏情况预算 — 针对 head_dim=64 场景可减少 50% smem 使用,提升 occupancy

P2 — 生态集成

项目 说明
MXMACA SDK 4.x 适配 适配新版 SDK 的 API 变更(mxcc 接口、runtime API),保持对 C500/C550/C600 三代硬件的兼容
分布式推理支持 在 MXMACA 平台上实现多卡推理 (对应上游 FIXME(DefTruth) 标注)
CI/CD 自动化测试 搭建 MetaX 设备上的 CI 流水线,自动化运行完整 45+ 项测试套件
性能基线发布 发布 MetaX C500/C600 上的完整性能基准数据(延迟 / 带宽 / TFLOPS),与 NVIDIA A100/H100/H20 对比
推理框架集成 提供 CogVideoX、HunyuanVideo 等热门视频生成模型的 MXMACA 推理示例 (对应上游 example/ 目录)

反馈沐曦

项目 说明
INT4 WMMA 需求 INT4 量化可将 QK 内存占用在 INT8 基础上再减半 — 对长序列推理和 MoE 模型有重大意义
FP8 支持需求 FP8 PV 路径可实现与 NVIDIA Hopper FP8 路径对等的性能水平,是缩小性能差距的关键一步
ldmatrix 等价指令 当前 wmma::load_matrix_sync 在多个 kernel 路径上引入了 fragment 构建/解构开销
__stcg 等价语义 Cache-global store hint 可减少共享内存→全局内存写回路径的 L2 污染

开发反馈

沐曦 MXMACA 生态体验总结

优势:

  1. 编程模型高度兼容: MXMACA 与 CUDA 的 Kernel 启动语法 (<<<grid, block, shmem, stream>>>), Cooperative Groups, 内存模型几乎完全一致。kernel 层代码迁移的主要工作是指令替换而非算法重写,迁移效率显著。

  2. WMMA API 覆盖充分: MXMACA WMMA API 覆盖了 INT8 (s8s8s32), FP16 (f16f16f32/f16f16f16), BF16 (bf16bf16f32) 三种常用 Tensor Core 类型。wmma::fragment + wmma::load_matrix_sync + wmma::mma_sync 三层 API 设计清晰。

  3. mxcc 编译器成熟度: 多代硬件架构的自动检测 (--offload-arch) 降低适配成本。--maxrregcount, -use-fast-math, -fgpu-rdc 等优化选项与 nvcc 对应选项语义一致。

  4. maca_bfloat16 类型完备: __bfloat162float, __float2bfloat16_rn, __bfloat162bfloat162 等转换函数与 CUDA 等价函数一一对应。

  5. mcTriton 后端: Triton 3.0.0 + MXMACA 后端的 pipeline="cpasync" / scenario="flashattn-fwd" 自动调优指令与 CUDA Triton 对齐,为 Python 层提供了有效的跨平台兼容路径。

待改进:

  1. INT4 WMMA 和 FP8 是补齐量化推理的关键瓶颈: 这两个数据类型支持是沐曦硬件在 AI 推理领域缩小与 NVIDIA 差距的最重要工程项。缺少它们意味着 SageAttention 的 INT4 QK 路径 (SageAttention2 核心特性) 和 FP8 PV 路径 (SageAttention2++ 最高吞吐) 完全不可用。

  2. cp.async 流水线深度控制: 与 NVIDIA TMA 相比,__builtin_mxc_load_global_async128 的流水线深度控制粒度更粗 (只有 gvmcnt 计数,无可编程 commit group),影响了复杂流水线的微调能力。

  3. cucc 模板限制: cucc (MXMACA C++ 编译器) 无法解析模板成员函数调用,导致 permuted_smem_maca.cuh 中所有 smem_t 的模板方法必须改为 free-standing template 函数。这引入了额外的代码重复和样板。

  4. 共享内存 bank 信息: 缺少 MXMACA 共享内存 bank 结构 (bank 数、bank 宽度、bank conflict 检测工具) 的官方文档,导致 swizzle 粒度选择依赖经验性试错而非精确建模。

技术迁移统计

指标 数值
迁移总代码量 ~3,600 行 (kernel + 调度)
注意力 kernel 1,343 行 (csrc/qattn/qk_int_sv_f16_maca.cu)
量化/util kernel 932 行 (csrc/fused/fused_maca.cu)
WMMA & 工具头文件 ~1,000 行 (6 个 .cuh 文件)
Python 调度层 1,119 行 (core.py + maca_compile.py + metax_utils.py + quant.py)
mcTriton kernel ~1,300 行 (8 个 Triton kernel)
测试代码 ~3,600 行 (45+ 测试用例)
NVIDIA ISA → MXMACA ISA 指令替换点 核心 6 大类: MMA、矩阵加载、async copy、shuffle/reduction、数学函数、编译选项
数值精度 cos >= 0.99 (所有 24 种路径 × head_dim × 布局组合)

安装

前置依赖

  • Python >= 3.9
  • PyTorch >= 2.3.0 (MXMACA-PyTorch)
  • MXMACA SDK >= 3.7.x
  • mxcc 编译器 (包含在 MXMACA SDK 中)

环境配置

export MACA_PATH=/opt/maca
export LD_LIBRARY_PATH=${MACA_PATH}/lib:${LD_LIBRARY_PATH}
export PATH=${MACA_PATH}/mxgpu_llvm/bin:${PATH}

从源码编译

git clone https://gitlink.org.cn/sunmy/SageAttention_MXMACA.git
cd SageAttention_MXMACA
python setup.py install

构建系统通过 mxcc --list-offload-arch 自动检测 GPU 架构:

硬件 offload-arch Compute Capability
C500 xcore1000 major=10, minor=00
C550 xcore1002 major=10, minor=02
C600 xcore1500 major=15 (预测)

环境变量

变量 默认值 说明
MACA_PATH /opt/maca MXMACA SDK 安装目录
MXCC mxcc MXMACA 编译器路径
SAGEATTN_OFFLOAD_ARCH (自动检测) 手动覆盖目标架构
SAGEATTN_SKIP_BUILD 0 设为 1 跳过 C++ 扩展编译
CXX_APPEND_FLAGS (空) 追加到 MXCC_FLAGS 的额外编译器参数

使用方法

基础调用

from sageattention import sageattn

# 自动检测 MetaX 设备,调度到最优路径
output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
  • q, k, vFP16/BF16 张量。
  • tensor_layout="HND" → 形状 (batch_size, head_num, seq_len, head_dim)
  • tensor_layout="NHD" → 形状 (batch_size, seq_len, head_num, head_dim)

可用 API

API 说明
sageattn 自动调度入口 (MetaX 设备 → MACA kernel, CUDA 设备 → Triton fallback)
sageattn_qk_int8_pv_fp16_maca MXMACA 原生 C++ 内核,支持全部可配置选项
sageattn_qk_int8_pv_fp16_triton mcTriton 后端 (非 MetaX GPU 的降级方案)
sageattn_varlen 变长序列注意力 (packed 格式)

高级配置

from sageattention.core import sageattn_qk_int8_pv_fp16_maca

output = sageattn_qk_int8_pv_fp16_maca(
    q, k, v,
    tensor_layout="HND",
    is_causal=False,
    qk_quant_type="int8",       # 仅支持 "int8" (沐曦硬件限制)
    qk_quant_gran="per_thread", # "per_warp" (降级 per_block) 或 "per_thread"
    pv_accum_dtype="fp32",      # "fp32" (高精度) / "fp16" (快) / "fp16+fp16" (InstBuf, 最高吞吐)
    sm_scale=None,              # 默认: 1/sqrt(head_dim), 自动应用 log2(e) 缩放
    smooth_k=True,              # K 均值平滑,消除离群值
    smooth_v=False,             # V 均值融合 (与 InstBuf 可组合)
    return_lse=False,           # 返回 log-sum-exp 用于多头拼接
)

6 条计算路径的参数组合:

pv_accum_dtype smooth_v 路径描述
"fp32" False INT8 QK + FP16 PV (FP32 累加) — 高精度基线
"fp32" True 上述 + 融合 V 均值
"fp16+fp16" True INT8 QK + FP16 PV (InstBuf 双级累加 + 融合 V 均值) — 最高吞吐
"fp16+fp16" False 不支持 (当前 InstBuf 仅实现 fuse_v_mean 路径)

测试

完整测试套件 (45+ 用例)

# 全部测试 (需要 MetaX 设备)
pytest tests/ -v

# 核心正确性测试 (~25 用例)
pytest tests/test_maca_sageattn.py -v

# 变长序列测试 (~7 用例)
pytest tests/test_maca_varlen_sageattn.py -v

# 跨平台精度测试 (24 种配置, cos >= 0.99)
pytest tests/test_cross_platform_precision.py -v

# mcTriton 后端测试
pytest tests/test_triton_backend.py -v

# MXMACA 硬件 API 验证 (WMMA, async copy, shuffle, BF16)
pytest tests/test_mxmac_api_validation.py -v

# 代码完整性检查 (无 INT4 残留, 文件完整)
pytest tests/test_cleanup_validation.py -v

性能基准

# 所有路径 + seq_len 扫描
python tests/bench_maca_sageattn.py --bench-only

# 自定义配置
python tests/bench_maca_sageattn.py --bench-only \
    --batch 2 --q-heads 16 --kv-heads 4 --head-dim 128 \
    --seq-lengths 1024,2048,4096,8192

# Causal 模式
python tests/bench_maca_sageattn.py --bench-only --causal

# 仅正确性验证
python tests/bench_maca_sageattn.py --correctness-only

测试覆盖矩阵

测试维度 内容
计算路径 6 条 (FP32/FP16/FP16+FP16 × fuse_v_mean 组合) × 2 head_dim (64/128) × 2 布局 (HND/NHD) = 24 配置
GQA 比例 1:1, 8:2, 8:1, 32:1, 64:1 (含输出多样性校验)
序列长度 短序列 (1/50/64/97), 中序列 (128/256/512), 长序列 (1024/2048/4096/8192/16384)
边界条件 单 token Q/KV, 块对齐/未对齐, 非 2 的幂, head_dim=32 填充
数据类型 FP16, BF16
掩码 因果/非因果
Varlen 均匀/变化/极端混合长度 (1:8192), varlen + GQA, varlen + causal

性能说明

  • INT8 QK 量化将 QK 矩阵乘法的内存占用降低 4 倍 (vs FP16),在长序列场景下显著提升内存受限吞吐量。
  • fp16+fp16 (InstBuf) PV 累加路径提供最高吞吐量: FP16 Tensor Core 中间结果 → FP32 最终输出, 兼顾速度与精度。
  • K 平滑 (smooth_k) 通过消除 QK 离群值提升量化精度,V 均值融合 (smooth_v) 在核函数内完成均值的减和加,零额外 kernel launch 开销。
  • Base-2 Softmax 利用 __expf / h2exp2 硬件指令避免 base-e 的 expf 调用,数值精度等价。
  • 64-lane wave 映射 (block dim3(64, N/2)) 相比朴素映射将 WMMA 利用率从 50% 提升至 ~100%。

文档导航

文档 说明
sageattention/core.py 主入口, 设备调度, 三种 API 的完整实现
sageattention/maca_compile.py torch.library.custom_op 注册 + fake tensor (torch.compile 支持)
sageattention/metax_utils.py MetaX 设备检测, wave size 查询, kernel 可用性检查
sageattention/quant.py per_block / per_warp INT8 量化, V 均值减法
csrc/qattn/qk_int_sv_f16_maca.cu 核心注意力 kernel (1,343 行) — 完整注意力流水线
csrc/fused/fused_maca.cu 量化/融合预处理 kernel (932 行)
csrc/mma_maca.cuh MXMACA WMMA 封装层 (INT8/FP16/BF16 MMA + rowsum)
csrc/cp_async_maca.cuh MXMACA 异步拷贝流水线抽象
csrc/permuted_smem_maca.cuh XOR 共享内存 swizzle 实现
setup.py 构建系统: mxcc 架构检测, 编译标志, 双 C++ 扩展
tests/ 完整测试套件 (正确性/边界/跨平台/GQA/varlen/BF16)

引用

如果您使用了本代码或认为我们的工作有价值,请引用:

@inproceedings{zhang2025sageattention,
  title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration}, 
  author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}
@inproceedings{zhang2024sageattention2,
  title={Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization},
  author={Zhang, Jintao and Huang, Haofeng and Zhang, Pengle and Wei, Jia and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}
@article{zhang2025sageattention2++,
  title={Sageattention2++: A more efficient implementation of sageattention2},
  author={Zhang, Jintao and Xu, Xiaoming and Wei, Jia and Huang, Haofeng and Zhang, Pengle and Xiang, Chendong and Zhu, Jun and Chen, Jianfei},
  journal={arXiv preprint arXiv:2505.21136},
  year={2025}
}

许可证

本项目基于 Apache License 2.0 许可。

关于

[ICLR2025, ICML2025, NeurIPS2025 Spotlight] Quantized Attention achieves speedup of 2-5x compared to FlashAttention, without losing end-to-end metrics across language, image, and video models.

53.8 MB
邀请码