@inproceedings{zhang2025sageattention,
title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration},
author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
booktitle={International Conference on Learning Representations (ICLR)},
year={2025}
}
@inproceedings{zhang2024sageattention2,
title={Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization},
author={Zhang, Jintao and Huang, Haofeng and Zhang, Pengle and Wei, Jia and Zhu, Jun and Chen, Jianfei},
booktitle={International Conference on Machine Learning (ICML)},
year={2025}
}
@article{zhang2025sageattention2++,
title={Sageattention2++: A more efficient implementation of sageattention2},
author={Zhang, Jintao and Xu, Xiaoming and Wei, Jia and Huang, Haofeng and Zhang, Pengle and Xiang, Chendong and Zhu, Jun and Chen, Jianfei},
journal={arXiv preprint arXiv:2505.21136},
year={2025}
}
许可证
本项目基于 Apache License 2.0 许可。
关于
[ICLR2025, ICML2025, NeurIPS2025 Spotlight] Quantized Attention achieves speedup of 2-5x compared to FlashAttention, without losing end-to-end metrics across language, image, and video models.
SageAttention — MetaX MXMACA 平台移植版 (v3.0.0)
本项目是 SageAttention 从 NVIDIA CUDA → MetaX MXMACA 的完整平台移植,针对 MetaX C500/C550/C600 GPU 深度优化。
目录
项目背景与上游关系
上游 SageAttention 是一系列高效的量化注意力算子,包含:
本移植版基于 SageAttention2++ 架构(约 3,600 行 kernel + 调度层代码),覆盖 INT8 QK 量化 × FP16 PV 累加的全部 6 条计算路径,以及 per-thread / per-block 两种量化粒度。
不可移植的上游功能(受沐曦硬件限制)
mma.sync.m16n16k64.s4.s4.s32mma.sync.m16n16k32.f8.f8.f32(e4m3)wgmma.mma_async、TMA bulk copy、mbarriere4m3x4转换指令CUDA → MXMACA 功能对比矩阵
核心注意力管线对比
__nv_bfloat16)maca_bfloat16, 含软件转换路径)torch.library.custom_op+ fake tensor)底层技术与 ISA 对比
mma.syncPTX (直接)wmma::fragment+wmma::load_matrix_sync+wmma::mma_syncAPIldmatrix.sync.aligned.x4PTXwmma::load_matrix_sync()cp.async.ca.shared.globalPTX →commit_group/wait_group__builtin_mxc_load_global_async128→__builtin_mxc_arrive_gvmcnt(N)wgmma.mma_async+CUtensorMap(Hopper)dim3(32, num_warps)(32-thread warp)dim3(64, num_warps/2)(64-lane wave, 两个 MMA 子组)__shfl_xor_sync(0xffffffff, val, off, 32)__shfl_xor_sync(0xFFFFFFFFFFFFFFFFULL, val, off, 64)__reduce_add_sync/__reduce_max_sync(MACA §2.21 原生指令)__stcg缓存写__stcg,普通写)__exp2f/__log2f/__tanhf/hfma2__expf/__log2f(float),h2exp2/hexp2(half), tanh 为 float roundtrip-arch=sm_80/89/90)--offload-arch=xcore1000/xcore1500,-x maca)cudaSetDevice/cudaStream_t/cudaEvent_tmcSetDevice/mcStream_t/mcEvent_tcudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...)mcFuncSetAttribute(..., mcFuncAttributeMaxDynamicSharedMemorySize, ...)at::kCUDA已移植功能详解
6 条计算路径
所有路径共享
int8 QK × fp16 SV基本架构,在累加器类型和平滑选项上区分:accum_f32accum_f32_fuse_v_meanaccum_f16_fuse_v_mean_inst_bufaccum_fp32(per_thread)accum_fp32(per_block)accum_fp16(InstBuf, per_thread)注意力 kernel (
csrc/qattn/qk_int_sv_f16_maca.cu, 1,343 行)计算流水线 (双缓冲异步 + 在线 softmax):
__builtin_mxc_load_global_async128将 INT8 Q 从 global → smem;若num_tiles_qk_inner=1则将全部 Q tile 预加载至寄存器wmma::load_matrix_sync从 XOR swizzled smem 加载 →wmma::mma_sync(m16n16k32, s8s8s32)→ INT32 中间结果sm_scale * log2(e)缩放 → 64-lane warp reduce 求 max →update_mdo(指数缩放 + 累加)wmma::load_matrix_sync加载 V →wmma::mma_sync(m16n16k16, f16f16f32)(或f16f16f16for InstBuf)RO_inst_buf[half]→ 逐元素 float 累加至 RO → 提升吞吐不损失精度d→normalize_d(用__frcp_rn近似倒数)half2→float2→bfloat162), V mean 加回, LSE 计算, swizzled smem→global 写回启动配置:
dim3(64, (CTA_Q/WARP_Q) × (CTA_K/WARP_K)) = dim3(64, 4×1) = 256 threads(div_ceil(qo_len, 128), num_qo_heads, batch_size)max(CTA_Q×HD×sizeof(int8) + 2×CTA_K×HD×sizeof(int8+half), CTA_Q×HD×sizeof(half))量化 kernel (
csrc/fused/fused_maca.cu, 932 行)QuantInt8Kernelfloat_to_int8_rn(x * 127/amax), 支持 sm_scale 预乘和 sub_mean 融合SubMeanKernel__hsub2(half2) 或 float roundtrip (bf16)TransposePadPermuteKernel(0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15)mcTriton 后端 (8 个 kernel)
全部 8 个 Triton attention/quantization kernel 已适配 mcTriton (Triton 3.0.0 + MXMACA backend):
适配要点:
pipeline="cpasync",scenario="flashattn-fwd"自动调优指令,pipeline="basic"for quant kernels,tl.interleave替代tl.cat(Triton 3.0 兼容性)。Python 调度层 (
sageattention/core.py, 644 行)MXMACA 架构优化
本移植版针对 MXMACA 硬件特点做了以下原生优化:
dim3(64, N/2), 每个 64-lane wave 并行驱动两个 32-lane WMMA 子组 — 相比朴素的dim3(32, N)映射提升 2× 吞吐csrc/utils.cuh:10-26wmma::fragment<...int8_t...>+wmma::mma_sync()(M16N16K32)csrc/mma_maca.cuhwmma::fragment<...half...>+wmma::mma_sync()(M16N16K16), m16n16k16 由两个 m16n8k16 拼接 (N 维拆分)csrc/mma_maca.cuh:139-166wmma::mma_sync()with__halfaccumulator — MACA §2.24.4 原生支持csrc/mma_maca.cuh:169-182__builtin_mxc_load_global_async128()128-bit 异步加载 →__builtin_mxc_arrive_gvmcnt(N)事务到达计数 → 双缓冲与 MMA 重叠csrc/cp_async_maca.cuhcsrc/permuted_smem_maca.cuhsm_scale *= log2(e), 使用__expf(float) /h2exp2/hexp2(half) 硬件指令csrc/math_maca.cuh__reduce_add_sync/__reduce_max_sync(§2.21) 替代 butterfly shufflecsrc/reduction_utils_maca.cuh:46-51迁移技术难点与解决方案
以下记录了从 NVIDIA CUDA 到 MXMACA 的完整迁移过程中遇到的关键技术挑战、分析过程和最终方案。
1. WMMA 接口层: PTX
mma.sync→ APIwmma::mma_sync挑战: 上游 NVIDIA kernel 使用 PTX 级
mma.sync.aligned.m16n16k32.row.col.s8.s8.s32等指令直接操作寄存器片段 (uint32_t[4]/uint32_t[8])。MXMACA 不暴露等价的 PTX 级 MMA 指令。分析过程: 查阅 MXMACA C++ 编程指南 §2.24,确认 MXMACA 提供
wmma::namespace API 作为标准 MMA 接口。需要将 PTX 的原始 uint32 寄存器片段映射到 WMMA fragment 抽象。解决方案 (在
csrc/mma_maca.cuh中实现):wmma::fragment<>对象extract_fragment_data(frag, uint32_array)进行 fragment → 原始寄存器的字节拷贝,保持与上游 kernel 的寄存器布局兼容mma_sync_m16n16k16_row_col_f16f16f32通过两次m16n8k16WMMA 调用拼接 N 维(与 NVIDIAmma.cuh中的实现方式相同)mma_sync_m16n8k16_f16f16f16路径经 MXMACA §2.24.4 确认原生支持mma.sync.m16n16k16.bf16.bf16.f32) 通过wmma::fragment+maca_bfloat16类型实现2. 矩阵加载: PTX
ldmatrix→wmma::load_matrix_sync挑战: NVIDIA kernel 大量使用
ldmatrix.sync.aligned.x4.trans.shared.b16PTX 指令从共享内存加载 16×16 矩阵片段。MXMACA 没有等价 PTX 指令。分析过程: 对比 MXMACA C++ 编程指南 §2.24 中的 WMMA API,
wmma::load_matrix_sync()是标准矩阵加载接口。但该 API 的正确使用前提是共享内存地址需满足 swizzle 对齐要求。解决方案:
wmma::load_matrix_sync(fragment, smem_ptr, ldm)替换所有ldmatrix调用swizzle_mode == k128B(static_assert保护 atmma_maca.cuh:55)ldmatrix_m8n8x4vsldmatrix_m8n8x4_trans显式转换需求load_matrix_sync接受ldm(leading dimension) 参数,不再需要手算 stride — 简化了共享内存访问逻辑3. 异步数据传输: PTX
cp.async→__builtin_mxc_load_global_async128挑战: NVIDIA kernel 的双缓冲异步流水线依赖
cp.async.ca.shared.globalPTX +commit_group/wait_group机制。MXMACA 提供不同的异步加载原语。分析过程: 研究了 MXMACA Runtime API 编程指南和头文件,发现 MXMACA 提供:
__builtin_mxc_load_global_async128— 128-bit 异步全局→寄存器加载__builtin_mxc_arrive_gvmcnt(N)— global memory transaction count 到达同步__builtin_mxc_arrive_bsmcnt(0)— shared memory (barrier) transaction count 到达同步需要注意 MXMACA 会自动追踪所有异步加载事务(无需显式 commit),这与 CUDA 的
cp.async.commit_group不同。解决方案 (在
csrc/cp_async_maca.cuh中实现):commit_group()→ 空操作 (MXMACA 硬件自动追踪)wait_group<n>()→__builtin_mxc_arrive_gvmcnt(n)(等待 ≤ n 个事务未完成)load_128b(smem, gmem)→__builtin_mxc_load_global_async128(gmem, smem)→ 写入uint4*pred_load_128b(smem, gmem, pred)→ 带零填充的谓词异步加载重要设计决策 (2026-06-24): 保留
__builtin_mxc_*底层原语而非迁移到cooperative_groups::memcpy_async,因为:wait(group)在计算前序列化所有加载 — 破坏双缓冲异步流水线namespace async已提供清晰抽象层4. 64-lane Wave vs 32-lane Warp
挑战: MXMACA 硬件 wave 为 64 lanes (
warpSize = 64, Runtime API §2.2),而 NVIDIA warp 为 32。所有依赖warpSize的代码(shuffle、reduction、lane 标识)需要重评估。分析过程: MXMACA 的 WMMA 在内部将 64-lane wave 拆分为两个 32-lane MMA 子组。因此 MMA 级别的
mma_lane_id可以保持 32-lane 语义,但 warp 级 reduction 必须扩展至 64 lanes。解决方案:
dim3(64, num_warps/2)替代dim3(32, num_warps)— 64 个 threadIdx.x 对应 64 lanes__shfl_xor_sync(kMacaFullMask=0xFFFFFFFFFFFFFFFFULL, val, offset, 64)替代warpSize=32__reduce_add_sync/__reduce_max_sync(§2.21) 用于 int32/uint32, 比手工 butterfly 更高效kMmaLaneGroupSize = 32, kMmaLaneMask = 0xFFFFFFFF用于 MMA 子组操作5. 共享内存 Bank Conflict: 软件 Swizzle 实现
挑战: NVIDIA 版本通过
permuted_smem.cuh中的 XOR swizzle 消除共享内存 bank conflict。MXMACA 需要等效的软件实现,且避免某些特殊指令。分析过程: XOR swizzle 的核心逻辑是
offset = row * stride + xor(col, row_mod_pattern)。跨架构可移植。主要风险是 swizzle 粒度选择 (k32B/k64B/k128B) 和 MXMACA 共享内存的 bank 结构匹配。解决方案 (在
csrc/permuted_smem_maca.cuh中实现):k32B(直接映射,stride=2),k64B(XOR col with(row/2)%4, stride=4),k128B(XOR col withrow%8, stride≥8)advance_offset_by_column<step>/advance_offset_by_row<step>正确处理 XOR 反转 (k128B 上步长 1/2/4 需要特殊 XOR 推进逻辑)advance_offset_*,load_128b_*,store_128b_*)__stcg(store cache-global): 在 MACA 上不可用 →store_128b_cg在__MACA__下回退为 plain store6. 数学函数适配
挑战: NVIDIA kernel 使用大量 half2/half 原生 PTX 数学指令。MXMACA 的函数库覆盖范围不同。
分析: 逐函数验证 MXMACA 支持情况。
__expf/__log2f(float)h2exp2/hexp2(half)__frcp_rn(float)__frsqrt_rn(float)__hadd2/__hmul2(half2)__hfma2(half2 FMA)__float2half2_rn__float22half2_rntanh(half/half2)tanhf()__exp2f(float)__exp2fintrinsic__expf(x * 0.6931471805599453f)等价替代7. BF16 支持:
maca_bfloat16类型系统挑战: NVIDIA 使用
__nv_bfloat16/__nv_bfloat162, MXMACA 使用maca_bfloat16/maca_bfloat162。转换语义和 WMMA 加载方式不同。解决方案:
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16宏映射at::kBFloat16→maca_bfloat16(而非__nv_bfloat16)half2→float2(通过unpack_half2_from_uint32_to_float) →bfloat162(通过__float2bfloat16_rn)test_mxmac_api_validation.py中包含mma.sync.m16n16k16.bf16.bf16.f32正确性测试8. mxcc 编译器适配
挑战: nvcc → mxcc 的编译选项体系差异显著,需要全新的构建系统配置。
解决方案 (在
setup.py中实现):_detect_mxcc_arch()— 通过SAGEATTN_OFFLOAD_ARCH环境变量 /mxcc --list-offload-arch/maca_version.h解析三层回退机制,映射C500→xcore1000,C600→xcore1500-x maca(指定 MXMACA 语言模式),--offload-arch=xcore1000,-O2(从-O1升级于 2025-06-24),--maxrregcount=96,-use-fast-math$MACA_PATH/include,$MACA_PATH/mxgpu_llvm/include,$MACA_PATH/tools/cu-bridge/include-fgpu-rdc启用可重定位设备码(对应 nvcc 的-rdc=true)9. 量化粒度降级: per_warp → per_block
挑战: MACA WMMA 的
mma.syncAPI 不支持 warp 级子 tile 量化粒度 (在单个 warp 内进一步分块)。上游per_warp_int8依赖此特性。分析: 查阅
mma_maca.cuh中的 WMMA fragment 定义和 MXMACA WMMA 指令限制,确认 WMMA 的最小操作单元为 16×16 tile,在 warp 级 (32 MMA lanes) 无法再细分。解决方案 (在
sageattention/core.py:265中实现):qk_quant_gran="per_warp"时,Python 调度层自动降级为per_block_int8_cuda()per_warp_int8CUDA kernel 保留在fused_maca.cu中但未被调度器调用 — 作为未来硬件的参考实现10. Varlen 逐批次循环的性能限制
当前实现:
_sageattn_varlen_maca()对每个 batch 元素逐一切片、重排、调用单序列注意力 kernel。每个 batch 产生一次独立 kernel launch。已知性能瓶颈: Python 层 per-batch 排列开销 + per-batch kernel launch 延迟在大量小序列时累积显著。
记录为 TODO:
sageattention/core.py:514-522— 计划使用 MACA streams 实现 batch 间并行调度,或下沉为 batched varlen C++ kernel。沐曦硬件局限性分析
以下基于 MXMACA SDK 3.7.0.x 和 MXC500/C600 硬件规格进行的系统评估:
当前硬件限制
mma.sync.m16n16k64.s4.s4.s32mma.sync.m16n16k32.f8.f8.f32+ per_channel_fp8 量化wgmma.mma_asyncCUtensorMap+cp.async.bulk__stcg__stcgPTX__tanhfhalf2 intrinsictanhf()(精度无损, 延迟略增)性能影响评估
后续规划
P0 — 近期高优先级
fuse_v_mean路径。补充 InstBuf + 其他配置组合,使最高吞吐路径覆盖所有使用场景P1 — 性能调优
num_tiles_qk_inner == 1时 Q 已预加载至寄存器 — 探索对其他 head_dim/WARP 配置启用此路径P2 — 生态集成
FIXME(DefTruth)标注)example/目录)反馈沐曦
wmma::load_matrix_sync在多个 kernel 路径上引入了 fragment 构建/解构开销__stcg等价语义开发反馈
沐曦 MXMACA 生态体验总结
优势:
编程模型高度兼容: MXMACA 与 CUDA 的 Kernel 启动语法 (
<<<grid, block, shmem, stream>>>), Cooperative Groups, 内存模型几乎完全一致。kernel 层代码迁移的主要工作是指令替换而非算法重写,迁移效率显著。WMMA API 覆盖充分: MXMACA WMMA API 覆盖了 INT8 (
s8s8s32), FP16 (f16f16f32/f16f16f16), BF16 (bf16bf16f32) 三种常用 Tensor Core 类型。wmma::fragment+wmma::load_matrix_sync+wmma::mma_sync三层 API 设计清晰。mxcc 编译器成熟度: 多代硬件架构的自动检测 (
--offload-arch) 降低适配成本。--maxrregcount,-use-fast-math,-fgpu-rdc等优化选项与 nvcc 对应选项语义一致。maca_bfloat16 类型完备:
__bfloat162float,__float2bfloat16_rn,__bfloat162bfloat162等转换函数与 CUDA 等价函数一一对应。mcTriton 后端: Triton 3.0.0 + MXMACA 后端的
pipeline="cpasync"/scenario="flashattn-fwd"自动调优指令与 CUDA Triton 对齐,为 Python 层提供了有效的跨平台兼容路径。待改进:
INT4 WMMA 和 FP8 是补齐量化推理的关键瓶颈: 这两个数据类型支持是沐曦硬件在 AI 推理领域缩小与 NVIDIA 差距的最重要工程项。缺少它们意味着 SageAttention 的 INT4 QK 路径 (SageAttention2 核心特性) 和 FP8 PV 路径 (SageAttention2++ 最高吞吐) 完全不可用。
cp.async 流水线深度控制: 与 NVIDIA TMA 相比,
__builtin_mxc_load_global_async128的流水线深度控制粒度更粗 (只有 gvmcnt 计数,无可编程 commit group),影响了复杂流水线的微调能力。cucc 模板限制: cucc (MXMACA C++ 编译器) 无法解析模板成员函数调用,导致
permuted_smem_maca.cuh中所有smem_t的模板方法必须改为 free-standing template 函数。这引入了额外的代码重复和样板。共享内存 bank 信息: 缺少 MXMACA 共享内存 bank 结构 (bank 数、bank 宽度、bank conflict 检测工具) 的官方文档,导致 swizzle 粒度选择依赖经验性试错而非精确建模。
技术迁移统计
csrc/qattn/qk_int_sv_f16_maca.cu)csrc/fused/fused_maca.cu)core.py+maca_compile.py+metax_utils.py+quant.py)安装
前置依赖
mxcc编译器 (包含在 MXMACA SDK 中)环境配置
从源码编译
构建系统通过
mxcc --list-offload-arch自动检测 GPU 架构:xcore1000xcore1002xcore1500环境变量
MACA_PATH/opt/macaMXCCmxccSAGEATTN_OFFLOAD_ARCHSAGEATTN_SKIP_BUILD01跳过 C++ 扩展编译CXX_APPEND_FLAGS使用方法
基础调用
q, k, v为 FP16/BF16 张量。tensor_layout="HND"→ 形状(batch_size, head_num, seq_len, head_dim)。tensor_layout="NHD"→ 形状(batch_size, seq_len, head_num, head_dim)。可用 API
sageattnsageattn_qk_int8_pv_fp16_macasageattn_qk_int8_pv_fp16_tritonsageattn_varlen高级配置
6 条计算路径的参数组合:
pv_accum_dtypesmooth_v"fp32"False"fp32"True"fp16+fp16"True"fp16+fp16"False测试
完整测试套件 (45+ 用例)
性能基准
测试覆盖矩阵
性能说明
fp16+fp16(InstBuf) PV 累加路径提供最高吞吐量: FP16 Tensor Core 中间结果 → FP32 最终输出, 兼顾速度与精度。smooth_k) 通过消除 QK 离群值提升量化精度,V 均值融合 (smooth_v) 在核函数内完成均值的减和加,零额外 kernel launch 开销。__expf/h2exp2硬件指令避免 base-e 的expf调用,数值精度等价。dim3(64, N/2)) 相比朴素映射将 WMMA 利用率从 50% 提升至 ~100%。文档导航
torch.library.custom_op注册 + fake tensor (torch.compile 支持)引用
如果您使用了本代码或认为我们的工作有价值,请引用:
许可证
本项目基于 Apache License 2.0 许可。