JAX is a Python library for accelerator-oriented array computation and program transformation,
designed for high-performance numerical computing and large-scale machine learning.
JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via jax.grad as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.
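As a minimal sketch of composing the two modes, the snippet below differentiates a reverse-mode gradient in forward mode to obtain a Hessian-vector product and a full Hessian. The test function `f` and the direction `v` are made up for illustration; `jax.jvp` and `jax.jacfwd` are the forward-mode entry points assumed here.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar-valued function: sum of cubes.
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])

# Reverse mode: gradient of f, which is 3 * x**2 for this f.
grad_f = jax.grad(f)

# Forward mode applied to the reverse-mode gradient: a Jacobian-vector
# product of grad_f along v, i.e. a Hessian-vector product of f.
v = jnp.array([1.0, 0.0, 0.0])
_, hvp = jax.jvp(grad_f, (x,), (v,))

# Forward-over-reverse: the full Hessian, which is diag(6 * x) for this f.
hessian_f = jax.jacfwd(jax.grad(f))
H = hessian_f(x)
```

Forward-over-reverse is only one possible ordering; `jax.jacrev(jax.jacfwd(f))` composes the same transformations the other way around.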
JAX uses XLA
to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators.
You can compile your own pure functions with jax.jit.
Compilation and automatic differentiation can be composed arbitrarily.
This is a research project, not an official Google product. Expect
sharp edges.
Please help by trying it out, reporting bugs,
and letting us know what you think!
import jax
import jax.numpy as jnp

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # compiled gradient evaluation function
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
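To see what these two functions return, here is a hedged usage sketch. The layer sizes, the random initialization, and the all-ones inputs are assumptions chosen only for illustration; the definitions are repeated so the block runs on its own.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_loss = jax.jit(jax.grad(loss))
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

# Hypothetical network: 3 inputs -> 4 hidden units -> 2 outputs.
layer_sizes = [3, 4, 2]
keys = jax.random.split(jax.random.PRNGKey(0), len(layer_sizes) - 1)
params = [
    (jax.random.normal(k, (n_in, n_out)), jnp.zeros(n_out))
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:])
]
inputs = jnp.ones((5, 3))    # batch of 5 examples
targets = jnp.zeros((5, 2))

# Same nested structure as params: a list of (dW, db) pairs.
grads = grad_loss(params, inputs, targets)

# Same structure again, but every leaf gains a leading batch axis of size 5.
per_example = perex_grads(params, inputs, targets)
print(grads[0][0].shape, per_example[0][0].shape)
```

Because the loss is a sum over examples, summing the per-example gradients over the leading axis should reproduce the batch gradient up to floating-point error.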
At its core, JAX is an extensible system for transforming numerical functions.
Here are three: jax.grad, jax.jit, and jax.vmap.
Automatic differentiation with grad
Use jax.grad
to efficiently compute reverse-mode gradients:
import jax
import jax.numpy as jnp
def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)
grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743
You can differentiate to any order with grad:
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673
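One way to sanity-check nested `grad` calls is to compare them with closed-form derivatives. The sketch below does this for the first two derivatives of `tanh`, using the identities tanh'(x) = 1 - tanh(x)^2 and tanh''(x) = -2 tanh(x) (1 - tanh(x)^2). The variable names are illustrative.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

d1 = jax.grad(tanh)   # first derivative, itself an ordinary function
d2 = jax.grad(d1)     # second derivative, obtained by applying grad again

x = 1.0
t = jnp.tanh(x)
analytic_d1 = 1.0 - t ** 2
analytic_d2 = -2.0 * t * (1.0 - t ** 2)

print(d1(x), analytic_d1)
print(d2(x), analytic_d2)
```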
You’re free to use differentiation with Python control flow:
def abs_val(x):
    if x > 0:
        return x
    else:
        return -x
abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
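The Python `if` above works under `jax.grad` because the function is evaluated with a concrete value on each call. Under `jax.jit` the argument is an abstract tracer, and branching on it with a Python `if` raises an error. One possible rewrite, sketched here with the hypothetical name `abs_val_where`, replaces the branch with `jnp.where` so that the gradient function can also be compiled.

```python
import jax
import jax.numpy as jnp

def abs_val_where(x):
    # Both candidate results are computed and jnp.where selects between them,
    # so no Python bool is ever taken of a traced value.
    return jnp.where(x > 0, x, -x)

abs_val_where_grad = jax.jit(jax.grad(abs_val_where))

g_pos = abs_val_where_grad(1.0)
g_neg = abs_val_where_grad(-1.0)
print(g_pos, g_neg)
```

This trades the re-evaluation of Python control flow for a single compiled function; `jax.lax.cond` is another option when only one branch should execute.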
Use XLA to compile your functions end-to-end with
jit,
used either as an @jit decorator or as a higher-order function.
import jax
import jax.numpy as jnp
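def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
# %timeit is an IPython/Jupyter magic, not plain Python.
# block_until_ready() waits for JAX's asynchronous dispatch to finish,
# so the timing covers the computation itself.
%timeit -n10 -r3 fast_f(x).block_until_ready()
%timeit -n10 -r3 slow_f(x).block_until_ready()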
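The decorator form and the higher-order form of `jit` can be compared side by side. The sketch below also times the compiled function with the standard-library `timeit` module, as one possible stand-in for `%timeit` outside IPython. The name `fast_f_decorated`, the smaller array size, and the repeat count are assumptions for illustration.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def fast_f_decorated(x):
    # Decorator form: the function is wrapped by jit at definition time.
    return x * x + x * 2.0

def slow_f(x):
    return x * x + x * 2.0

# Higher-order form: jit returns a compiled wrapper around slow_f.
fast_f = jax.jit(slow_f)

x = jnp.ones((1000, 1000))
same = jnp.allclose(fast_f_decorated(x), fast_f(x))

# The first call triggers compilation, so warm up before timing.
fast_f(x).block_until_ready()
elapsed = timeit.timeit(lambda: fast_f(x).block_until_ready(), number=10)
print(bool(same), elapsed)
```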
vmap maps
a function along array axes.
But instead of just looping over function applications, it pushes the loop down
onto the function’s primitive operations, e.g. turning matrix-vector multiplies into
matrix-matrix multiplies for better performance.
Using vmap can save you from having to carry around batch dimensions in your
code.
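As a small sketch of this idea, the function below is written for a single vector and never mentions a batch axis; `jax.vmap` then maps it over a batch. The matrix `W`, the batch `xs`, and the name `apply_matrix` are made up for illustration.

```python
import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)   # one (2, 3) matrix
xs = jnp.ones((4, 3))               # a batch of four length-3 vectors

def apply_matrix(x):
    # Written for a single vector of shape (3,): a matrix-vector multiply.
    return jnp.dot(W, x)

# vmap maps apply_matrix over the leading axis of xs.
batched_apply = jax.vmap(apply_matrix)
ys = batched_apply(xs)

# The batched result matches a single matrix-matrix multiply.
ys_manual = xs @ W.T
print(ys.shape)
```

The default `in_axes=0` maps over the first axis of every argument; other axes, or `None` for arguments that should not be mapped, can be given explicitly.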
To scale your computations across thousands of devices, you can use any
composition of these:
Compiler-based automatic parallelization
where you program as if using a single global machine, and the compiler chooses
how to shard data and partition computation (with some user-provided constraints);
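A minimal sketch of the compiler-based mode, assuming the `jax.sharding` API (`Mesh`, `NamedSharding`, `PartitionSpec`): the input array is placed with a sharding over all available devices, and the jitted function is written as if for a single machine. The mesh axis name `"data"` and the function `normalize` are hypothetical. On a machine with one device the mesh simply has size one.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading axis of the array across the "data" mesh axis.
sharding = NamedSharding(mesh, P("data"))

# Keep the length divisible by the device count so it splits evenly.
n = 8 * len(devices)
x = jax.device_put(jnp.arange(n, dtype=jnp.float32), sharding)

@jax.jit
def normalize(v):
    # Written with a global view of v; partitioning is left to the compiler.
    return (v - v.mean()) / v.std()

y = normalize(x)
print(y.shape, x.sharding)
```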
See the documentation
for information on alternative installation strategies. These include compiling
from source, installing with Docker, using other versions of CUDA, a
community-supported conda build, and answers to some frequently-asked questions.
Citing JAX
To cite this repository:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
In the above bibtex entry, names are in alphabetical order, the version number
is intended to be that from jax/version.py, and
the year corresponds to the project’s open-source release.
A nascent version of JAX, supporting only automatic differentiation and
compilation to XLA, was described in a paper that appeared at SysML
2018. We’re currently working on
covering JAX’s ideas and capabilities in a more comprehensive and up-to-date
paper.
Transformable numerical computing at scale
What is JAX?
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via jax.grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
JAX uses XLA to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. You can compile your own pure functions with jax.jit. Compilation and automatic differentiation can be composed arbitrarily.
Dig a little deeper, and you’ll see that JAX is really an extensible system for composable function transformations at scale.
This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
Transformations
At its core, JAX is an extensible system for transforming numerical functions. Here are three: jax.grad, jax.jit, and jax.vmap.
Automatic differentiation with grad
Use jax.grad to efficiently compute reverse-mode gradients. You can differentiate to any order with grad, and you’re free to use differentiation with Python control flow.
See the JAX Autodiff Cookbook and the reference docs on automatic differentiation for more.
Compilation with jit
Use XLA to compile your functions end-to-end with jit, used either as an @jit decorator or as a higher-order function.
Using jax.jit constrains the kind of Python control flow the function can use; see the tutorial on Control Flow and Logical Operators with JIT for more.
Auto-vectorization with vmap
vmap maps a function along array axes. But instead of just looping over function applications, it pushes the loop down onto the function’s primitive operations, e.g. turning matrix-vector multiplies into matrix-matrix multiplies for better performance.
Using vmap can save you from having to carry around batch dimensions in your code.
By composing jax.vmap with jax.grad and jax.jit, we can get efficient Jacobian matrices, or per-example gradients.
Scaling
To scale your computations across thousands of devices, you can use any composition of these:
jax.typeof;
See the tutorial and advanced guides for more.
Gotchas and sharp bits
See the Gotchas Notebook.
Installation
Supported platforms
Instructions
pip install -U jax
pip install -U "jax[cuda13]"
pip install -U "jax[tpu]"
See the documentation for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions.
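After installing, a short check along the following lines can confirm which version and backend were picked up. This is a sketch rather than an official verification step; the variable names are illustrative, and the reported backend depends on which of the packages above was installed.

```python
import jax
import jax.numpy as jnp

version = jax.__version__          # installed JAX version string
backend = jax.default_backend()    # e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()            # devices visible to the default backend

print(f"jax {version} on {backend}: {devices}")

# A tiny computation to confirm that arrays can be created and reduced.
result = jnp.arange(3.0).sum()
print(result)
```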
Reference documentation
For details about the JAX API, see the reference documentation.
For getting started as a JAX developer, see the developer documentation.