Triton-CPU
A long-lived development branch to build an experimental CPU backend for Triton.
This repository clones the main Triton repository, but we intend to minimize
divergences in the core (and ideally upstream anything that needs to change and
isn’t too CPU-specific). Most of the CPU work should be in a backend
subdirectory (similar to how GPU vendors are supported today). We’re starting
with a clone to give ourselves maximum development flexibility as this project
gets off the ground!
How to use it?
Build it like a normal Triton, but pass TRITON_CPU_BACKEND=1 to use the CPU backend instead of a GPU backend, if one is available.
NOTE: It’s still a work in progress.
Upstream README
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
The official documentation contains installation instructions and tutorials. See also these third-party Triton puzzles, which can all be run using the Triton interpreter – no GPU required.
Quick Installation
You can install the latest stable release of Triton from pip:
pip install triton
Binary wheels are available for CPython 3.9-3.13.
Enabling Blackwell Support
The main branch now features support for NVIDIA Blackwell GPUs using 5th-generation
tensor cores. To enable this, you will need two additional steps:
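Build a pre-release PyTorch from source with CUDA 12.8. This requires CUDA 12.8 to be installed locally; if it is not, follow the CUDA installation instructions for your platform.
# Clone PyTorch and check out the 2.6 release candidate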
git clone https://github.com/pytorch/pytorch
cd pytorch
git checkout v2.6.0-rc9
git submodule sync
git submodule update --init --recursive -j 8
# Install build dependencies (assumes you already have a system compiler)
pip install -r requirements.txt
pip install mkl-static mkl-include wheel
# Build PyTorch (will take a long time)
export CUDA_HOME=/usr/local/cuda-12.8
export CUDA_PATH=$CUDA_HOME
export TORCH_CUDA_ARCH_LIST=Blackwell
python setup.py develop
# Optional, package build into a wheel to install on other machines.
python setup.py bdist_wheel
ls dist # Wheel should be output in this directory
Note that if you use the domain libraries (torchvision, torchtext,
torchaudio, etc.) these will need to be built from source as well, otherwise
their custom PyTorch extensions will not work.
Finally, follow the instructions below to install Triton from source.
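Before starting the PyTorch build, it can help to confirm that CUDA_HOME really points at a CUDA 12.8 toolkit. The helper below is a hypothetical sketch (check_cuda_release is not part of Triton or PyTorch), and it assumes that nvcc --version prints a line of the form "Cuda compilation tools, release 12.8, V12.8.x".

```shell
# Hypothetical helper (not part of Triton or PyTorch): verify that the nvcc
# under $CUDA_HOME reports the expected CUDA release (default 12.8).
check_cuda_release() {
  local want=${1:-12.8}
  local nvcc="${CUDA_HOME:-/usr/local/cuda-12.8}/bin/nvcc"
  local got
  if [ ! -x "$nvcc" ]; then
    echo "nvcc not found at $nvcc" >&2
    return 1
  fi
  # Assumption: nvcc --version contains "release X.Y," on one line.
  got=$("$nvcc" --version | sed -n 's/.*release \([0-9][0-9]*\.[0-9][0-9]*\),.*/\1/p')
  if [ "$got" != "$want" ]; then
    echo "expected CUDA $want, found ${got:-unknown} at $nvcc" >&2
    return 1
  fi
  echo "CUDA $got found at $nvcc"
}
```

For example, run CUDA_HOME=/usr/local/cuda-12.8 check_cuda_release after the two export lines for CUDA_HOME and CUDA_PATH; a non-zero exit status means the build would pick up a different toolkit.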
Building with a custom LLVM
Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build
downloads a prebuilt LLVM, but you can also build LLVM from source and use that.
LLVM does not have a stable API, so the Triton build will not work at an
arbitrary LLVM version.
Find the version of LLVM that Triton builds against. Check
cmake/llvm-hash.txt to see the current version. For example, if it says:
49af6502c6dcb4a7f7520178bd14df396f78240c
This means that the version of Triton you have builds against
LLVM 49af6502.
Check out LLVM at this revision with git checkout. Optionally, make additional
modifications to LLVM.
Build LLVM. For example, you might run:
$ cd $HOME/llvm-project # your clone of LLVM.
$ mkdir build
$ cd build
$ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm;lld" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
$ ninja
Grab a snack, this will take a while.
Build Triton as above, but set the following environment variables.
# Modify as appropriate to point to your LLVM build.
$ export LLVM_BUILD_DIR=$HOME/llvm-project/build
$ cd <triton install>
$ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
LLVM_SYSPATH=$LLVM_BUILD_DIR \
pip install -e python
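The revision lookup and the environment variables above can be wrapped in two small shell functions. This is a hedged sketch: llvm_short_rev and use_custom_llvm are hypothetical names, and the 8-character abbreviation simply mirrors the 49af6502 example above.

```shell
# Hypothetical helpers for the custom-LLVM steps above.

# Print the abbreviated LLVM revision pinned in cmake/llvm-hash.txt
# (or in the file given as $1), e.g. 49af6502 for the example hash.
llvm_short_rev() {
  local hash_file=${1:-cmake/llvm-hash.txt}
  local full
  # Take the first line and drop whitespace, including the trailing newline.
  full=$(head -n 1 "$hash_file" | tr -d '[:space:]')
  if [ -z "$full" ]; then
    echo "no LLVM hash found in $hash_file" >&2
    return 1
  fi
  printf '%s\n' "$full" | cut -c1-8
}

# Export the three variables the "Build Triton" step sets, for a given
# LLVM build directory (e.g. $HOME/llvm-project/build).
use_custom_llvm() {
  local build_dir=$1
  export LLVM_INCLUDE_DIRS="$build_dir/include"
  export LLVM_LIBRARY_DIR="$build_dir/lib"
  export LLVM_SYSPATH="$build_dir"
}
```

From the Triton checkout, usage might look like git -C "$HOME/llvm-project" checkout "$(cat cmake/llvm-hash.txt)", then, once LLVM is built, use_custom_llvm "$HOME/llvm-project/build" followed by pip install -e python. Exporting the variables (rather than prefixing them to a single command as above) keeps them set for later rebuilds in the same shell; call the function in a subshell if you do not want that.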
Tips for building
Set TRITON_BUILD_WITH_CLANG_LLD=true as an environment variable to use clang
and lld. lld in particular results in faster builds.
Set TRITON_BUILD_WITH_CCACHE=true to build with ccache.
Set TRITON_HOME=/some/path to change the location of the .triton
directory where Triton’s cache is located and downloads are stored
during the build. By default, this is the user’s home directory. It
can be changed anytime.
If you’re running out of memory when building Triton, specify the MAX_JOBS
environment variable (to the pip install -e python command) to limit the
number of jobs.
Pass --no-build-isolation to pip install to make no-op builds faster.
Without this, every invocation of pip install uses a different symlink to
cmake, and this forces ninja to rebuild most of the .a files.
VS Code IntelliSense has some difficulty figuring out how to build Triton’s C++
(probably because, in our build, users don’t invoke cmake directly, but
instead use setup.py). Teach VS Code how to compile Triton as follows.
Do a local build: run pip install -e python
Get the full path to the compile_commands.json file produced by the build:
find python/build -name 'compile_commands.json' | xargs readlink -f
You might get a full path similar to /Users/{username}/triton/python/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json
In VS Code, install the C/C++ extension, then open the command palette (Shift + Command + P on Mac, or Shift + Ctrl + P on Windows/Linux) and open C/C++: Edit Configurations (UI).
Open “Advanced Settings” and paste the full path to
compile_commands.json into the “Compile Commands” textbox.
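The path lookup and the “Compile Commands” setting can also be scripted. The sketch below is hypothetical (write_vscode_config is not a Triton tool): it reuses the find/readlink lookup above and writes .vscode/c_cpp_properties.json directly, assuming the C/C++ extension's compileCommands field and schema version 4. It overwrites any existing file of that name.

```shell
# Hypothetical helper: point the VS Code C/C++ extension at Triton's
# compile_commands.json. Run from (or pass) the Triton repository root.
# Assumption: c_cpp_properties.json uses "compileCommands" and "version": 4.
write_vscode_config() {
  local repo_root=${1:-.}
  local db
  # Same lookup as the find | readlink -f pipeline above; keep the first match.
  db=$(find "$repo_root/python/build" -name 'compile_commands.json' -exec readlink -f {} \; | head -n 1)
  if [ -z "$db" ]; then
    echo "compile_commands.json not found; run 'pip install -e python' first" >&2
    return 1
  fi
  mkdir -p "$repo_root/.vscode"
  # Overwrites any existing c_cpp_properties.json.
  cat > "$repo_root/.vscode/c_cpp_properties.json" <<EOF
{
  "configurations": [
    {
      "name": "triton",
      "compileCommands": "$db"
    }
  ],
  "version": 4
}
EOF
  printf '%s\n' "$db"
}
```

If you build for several Python versions, more than one cmake.* directory may exist; the function only picks the first match, so check the printed path.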
Running tests
There currently isn’t a turnkey way to run all the Triton tests, but you can
use the following recipe.
# One-time setup. Note this will reinstall local Triton because torch
# overwrites it with the public version.
$ make dev-install
# To run all tests (requires a GPU)
$ make test
# Or, to run tests without a GPU
$ make test-nogpu
Tips for hacking
For detailed instructions on how to debug Triton’s frontend, please refer to this tutorial. The following includes additional tips for hacking on Triton’s backend.
Helpful environment variables
MLIR_ENABLE_DUMP=1 dumps the IR before every MLIR pass Triton runs, for all
kernels. Use MLIR_ENABLE_DUMP=kernelName to dump for a specific kernel only.
The Triton cache can interfere with the dump. If MLIR_ENABLE_DUMP=1 does not produce any output, try clearing your Triton cache: rm -r ~/.triton/cache/* (or $TRITON_HOME/.triton/cache/* if you have set TRITON_HOME).
MLIR_DUMP_PATH specifies where MLIR_ENABLE_DUMP writes its output. If it is unset, the dump goes to stderr.
LLVM_IR_ENABLE_DUMP=1 dumps the IR before every pass run over the LLVM IR.
TRITON_REPRODUCER_PATH=<reproducer_path> will generate an MLIR reproducer file
at <reproducer_path> before each MLIR compiler stage. If any of the stages fail,
<reproducer_path> will be a local MLIR reproducer captured right before the failing pass.
TRITON_INTERPRET=1 uses the Triton interpreter instead of running on the
GPU. You can insert Python breakpoints in your kernel code!
TRITON_ENABLE_LLVM_DEBUG=1 passes -debug to LLVM, printing a lot of
debugging information to stdout. If this is too noisy, run with just
TRITON_LLVM_DEBUG_ONLY instead to limit the output.
An alternative way to reduce output noise is to run with
LLVM_IR_ENABLE_DUMP=1, extract the IR before the LLVM pass of interest, and
then run LLVM’s opt standalone, perhaps passing -debug-only=foo on the
command line.
TRITON_LLVM_DEBUG_ONLY=<comma-separated> is the equivalent of LLVM’s
-debug-only command-line option. This limits the LLVM debug output to
specific pass or component names (which are specified using #define DEBUG_TYPE throughout LLVM and Triton) in order to allow the debug output to
be less noisy. TRITON_LLVM_DEBUG_ONLY accepts one or more comma-separated
values (e.g.,
TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions" or
TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc").
TRITON_ENABLE_ASAN=1 invokes the LLVM address sanitizer for
memory leak and out-of-bounds access detection. Currently only supported on the AMD
backend. This must be run using the ASAN libraries documented here.
When enabling the address sanitizer it is recommended to disable various memory caching strategies
both within the ROCm stack and PyTorch. This will give the address sanitizer the best chance at finding the
memory fault where it originates. See this test for more details.
USE_IR_LOC={ttir,ttgir} reparses the IR so that the location information
is the line number in the IR file with that extension, instead of the line
number in the Python file. This provides a direct mapping from the IR to
llir/ptx. When used with performance tools, it can provide a breakdown by IR
instruction.
TRITON_PRINT_AUTOTUNING=1 prints out the best autotuning config and total time
spent for each kernel after autotuning is complete.
DISABLE_LLVM_OPT disables LLVM optimizations for make_llir and make_ptx if its
value parses as a true Boolean. Otherwise, it is parsed as a list of flags that
disable specific LLVM optimizations. One use case is
DISABLE_LLVM_OPT="disable-lsr": loop strength reduction is known to cause
performance changes of up to 10% for certain kernels with register pressure.
TRITON_ALWAYS_COMPILE=1 forces kernels to be compiled regardless of cache hits.
MLIR_ENABLE_TIMING dumps the timing information for each MLIR pass.
LLVM_ENABLE_TIMING dumps the timing information for each LLVM pass.
TRITON_DEFAULT_FP_FUSION overrides the default behavior of allowing fp fusion (mul+add->fma).
MLIR_ENABLE_DIAGNOSTICS=<comma-separated> controls diagnostic emission in MLIR.
Options are: warnings, remarks, stacktraces, operations.
Use comma-separated values to customize output. For example,
MLIR_ENABLE_DIAGNOSTICS=remarks,operations enables remarks and IR operations,
while MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces enables warnings with
stacktraces. By default, only errors are shown. Setting warnings includes
errors and warnings; remarks includes errors, warnings, and remarks.
MLIR_ENABLE_REMARK is deprecated. Please use MLIR_ENABLE_DIAGNOSTICS=remarks.
TRITON_KERNEL_DUMP enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
TRITON_DUMP_DIR specifies the directory to save the dumped IR and ptx/amdgcn when TRITON_KERNEL_DUMP is set to 1.
TRITON_KERNEL_OVERRIDE enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
TRITON_OVERRIDE_DIR specifies the directory from which to load the IR/ptx/amdgcn files when TRITON_KERNEL_OVERRIDE is set to 1.
TRITON_F32_DEFAULT sets the default input precision of tl.dot when using 32-bit floats, which can be either ieee, tf32, or tf32x3.
TRITON_FRONT_END_DEBUGGING=1 disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
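Clearing the cache, as suggested for MLIR_ENABLE_DUMP above, can be wrapped so that it also honors TRITON_HOME. This is a hypothetical helper (clear_triton_cache is not a Triton command) that assumes the cache lives in $TRITON_HOME/.triton/cache, with TRITON_HOME defaulting to your home directory as described under “Tips for building”.

```shell
# Hypothetical helper: empty Triton's on-disk cache so the next run recompiles
# (and therefore re-dumps) every kernel.
# Assumption: the cache is at ${TRITON_HOME:-$HOME}/.triton/cache.
clear_triton_cache() {
  local cache_dir="${TRITON_HOME:-$HOME}/.triton/cache"
  if [ -d "$cache_dir" ]; then
    # Delete the contents but keep the cache directory itself.
    find "$cache_dir" -mindepth 1 -delete
  fi
  # Report which directory was cleared.
  printf '%s\n' "$cache_dir"
}
```

Alternatively, TRITON_ALWAYS_COMPILE=1 forces recompilation without deleting anything, for example TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=add_kernel python my_script.py 2> add_kernel.mlir, where add_kernel and my_script.py are placeholders and stderr is redirected because MLIR_DUMP_PATH is unset.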
Kernel Override Steps
export TRITON_ALWAYS_COMPILE=1
export TRITON_KERNEL_DUMP=1
export TRITON_DUMP_DIR=<dump_dir>
export TRITON_KERNEL_OVERRIDE=1
export TRITON_OVERRIDE_DIR=<override_dir>
# Step 1: Run the kernel once to dump kernel's IRs and ptx/amdgcn in $TRITON_DUMP_DIR
# Step 2: Copy $TRITON_DUMP_DIR/<kernel_hash> to $TRITON_OVERRIDE_DIR
# Step 3: Delete the stages that you do not want to override and modify the stage you do want to override
# Step 4: Run the kernel again to see the overridden result
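Steps 2 and 3 can be partly automated. The helper below is a hypothetical sketch, not a Triton tool, and the directory layout is an assumption: each kernel's dump is taken to be the directory $TRITON_DUMP_DIR/<kernel_hash>, with stage files distinguished by extension (.ttir, .ttgir, .llir, .ptx, .amdgcn).

```shell
# Hypothetical helper for Kernel Override steps 2 and 3.
# Usage: prepare_override <kernel_hash> <stage_extension_to_keep>
prepare_override() {
  local kernel_hash=$1 keep_ext=$2
  local src="$TRITON_DUMP_DIR/$kernel_hash"
  local dst="$TRITON_OVERRIDE_DIR/$kernel_hash"
  local ext
  if [ ! -d "$src" ]; then
    echo "no dump found at $src (run the kernel once first)" >&2
    return 1
  fi
  if [ -e "$dst" ]; then
    echo "$dst already exists; refusing to overwrite it" >&2
    return 1
  fi
  # Step 2: copy the kernel's whole dump directory into the override dir.
  mkdir -p "$TRITON_OVERRIDE_DIR"
  cp -R "$src" "$dst"
  # Step 3, first half: delete every stage except the one being overridden.
  # Files with other extensions (e.g. metadata) are left untouched.
  for ext in ttir ttgir llir ptx amdgcn; do
    if [ "$ext" != "$keep_ext" ]; then
      find "$dst" -type f -name "*.$ext" -delete
    fi
  done
  # List the remaining stage file(s) to edit by hand (rest of step 3).
  find "$dst" -type f -name "*.$keep_ext"
}
```

For example, prepare_override <kernel_hash> ttgir keeps only the TTGIR stage; edit that file, then run the kernel again (step 4). The helper refuses to overwrite an existing override so that earlier manual edits are not lost.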
Changelog
Version 2.0 is out! New features include:
Many, many bug fixes
Performance improvements
Backend rewritten to use MLIR
Support for kernels that contain back-to-back matmuls (e.g., flash attention)
Contributing
Community contributions are more than welcome, whether to fix bugs or to add new features, on GitHub. For more detailed instructions, please visit our contributor’s guide.
The foundations of this project are described in the following MAPL2019 publication: Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. Please consider citing this work if you use Triton!
Compatibility
Supported Platforms:
Supported Hardware:
Development Container (Dev Container)
Dev Containers for the Triton project are available from the triton-dev-containers repository
Key Benefits:
How to Use the Dev Container:
For detailed instructions on how to use the dev containers, please see the dev container user guide.