
0x522 Jax

JAX is essentially a stack of compilers built around the transformation and compilation of pure functional programs.

JAX's execution path is roughly:

\[\text{python} \xrightarrow{\text{trace}} \text{jaxpr} \xrightarrow{\text{lower}} \text{StableHLO} \xrightarrow{\text{xla compile}} \text{native}\]

1. API

1.1. lax

lax is the library of primitive operations underlying jnp. It is a thin wrapper over XLA operations, and JAX knows how to transform each of its primitives.

See this section on lax control flow.
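A minimal sketch of lax control flow under jit: `lax.fori_loop` runs a loop whose upper bound is a traced value, and `lax.cond` branches on a traced predicate. The function names are illustrative only.

```python
import jax
from jax import lax

@jax.jit
def sum_below(n):
    # fori_loop(lower, upper, body_fun, init_val), with body_fun(i, carry) -> carry
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

@jax.jit
def safe_reciprocal(x):
    # both branches are traced; the predicate is only evaluated at run time
    return lax.cond(x != 0, lambda v: 1.0 / v, lambda v: 0.0 * v, x)

print(sum_below(5))          # sums 0 + 1 + 2 + 3 + 4
print(safe_reciprocal(4.0))
print(safe_reciprocal(0.0))
```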

1.2. primitive

We can define new primitives and implement their evaluation and transformation rules:

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

def multiply_add_prim(x, y, z):
  return multiply_add_p.bind(x, y, z)

1.2.1. eval

The primal evaluation rule must be implemented:

import numpy as np

def multiply_add_impl(x, y, z):
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)

To enable tracing, an abstract evaluation rule must also be implemented; it tells JAX how to infer the output shape and dtype from the input shapes and dtypes:

def multiply_add_abstract_eval(xs, ys, zs):
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

# with abstract evaluation registered, the primitive can be traced:
# { lambda ; a:f32[] b:f32[] c:f32[]. let d:f32[] = multiply_add a b c in (d,) }
jax.make_jaxpr(multiply_add_prim)(2.0, 3.0, 10.)

1.2.2. differentiation

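For differentiation, define a JVP rule (forward mode). Reverse mode additionally needs a transpose rule for every primitive that appears in the linear (tangent) part of the computation.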
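A sketch of a JVP rule for the multiply_add primitive above, using `ad.defjvp` from `jax.interpreters.ad` (one rule per argument, each mapping that argument's tangent to its contribution to the output tangent). It re-declares the primitive so the block is self-contained. The tangent is written with ordinary lax multiplies and adds, which JAX already knows how to transpose, so `jax.grad` works here without a transpose rule for multiply_add itself. The `jax.extend.core` import location depends on the JAX version, hence the fallback.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.interpreters import ad

try:  # newer JAX releases expose Primitive under jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# self-contained copy of the multiply_add primitive (impl + abstract eval)
multiply_add_p = Primitive("multiply_add")
multiply_add_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))
multiply_add_p.def_abstract_eval(
    lambda x, y, z: jax.core.ShapedArray(x.shape, x.dtype))

def multiply_add_prim(x, y, z):
    return multiply_add_p.bind(x, y, z)

# d(x*y + z) = xt*y + x*yt + zt: one rule per argument, g is that argument's tangent
ad.defjvp(multiply_add_p,
          lambda g, x, y, z: g * y,
          lambda g, x, y, z: x * g,
          lambda g, x, y, z: g)

x, y, z = jnp.float32(2.0), jnp.float32(3.0), jnp.float32(10.0)
one = jnp.float32(1.0)
primal, tangent = jax.jvp(multiply_add_prim, (x, y, z), (one, one, one))
grads = jax.grad(multiply_add_prim, argnums=(0, 1, 2))(x, y, z)
print(primal, tangent, grads)
```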

1.3. jnp

jnp mirrors the NumPy API closely, so most NumPy-style code carries over directly.
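A short comparison, as a sketch: the same expression works in both libraries, but JAX arrays are immutable, so in-place assignment is replaced by the functional `.at[...].set(...)` update, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(6.0).reshape(2, 3)
x = jnp.arange(6.0).reshape(2, 3)

# the same expression in both libraries
same = np.allclose(np.sum(x_np ** 2, axis=1), jnp.sum(x ** 2, axis=1))

# JAX arrays are immutable: `x[0, 0] = 10.0` raises a TypeError,
# so updates are written functionally and produce a new array
y = x.at[0, 0].set(10.0)
print(same, x[0, 0], y[0, 0])
```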

1.3.1. async dispatch

jnp operations return a jax.Array immediately, without waiting for the computation to finish (asynchronous dispatch), so the array behaves like a future. This makes a difference when benchmarking:

%time jnp.dot(x, x)  
CPU times: user 267 µs, sys: 93 µs, total: 360 µs

%time np.asarray(jnp.dot(x, x))  
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms

See this doc for more details
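A sketch of timing with and without `block_until_ready()` (the matrix size is arbitrary). Only the second measurement includes the actual computation; the first mostly measures dispatch.

```python
import time
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (1000, 1000))
jnp.dot(x, x).block_until_ready()  # warm-up, so compilation is not timed

start = time.perf_counter()
y = jnp.dot(x, x)                   # returns once the work is dispatched
dispatch_s = time.perf_counter() - start

start = time.perf_counter()
y = jnp.dot(x, x).block_until_ready()  # waits for the result
total_s = time.perf_counter() - start

print(f"dispatch only: {dispatch_s * 1e3:.2f} ms, "
      f"with block_until_ready: {total_s * 1e3:.2f} ms")
```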

1.3.2. memory management

There are a few flags controlling device memory (HBM) allocation; see this doc.
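For example, JAX's GPU memory-allocation doc describes environment variables like the ones below. This sketch assumes a GPU backend (on CPU they have no visible effect) and sets them before JAX initializes its backend, since they are read at startup.

```python
import os

# must be set before JAX initializes its backend (safest: before `import jax`)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't preallocate most of GPU memory
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"  # alternatively: preallocate 50%
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" # allocate/free exactly on demand (slow)

import jax
print(jax.devices())
```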

To delete all array buffers currently stored in HBM, use the following snippet:

import jax

for x in jax.live_arrays():
  x.delete()

2. Trace

Different transformations require different levels of abstraction, so some of them can run into issues with Python control flow; see the doc here.

Arguments and return values of the target function must be arrays, scalars, or (nested) containers of them. Other values (e.g. a str) raise an error during tracing.
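A quick sketch: a dict of arrays and scalars is a valid pytree argument, while a plain string raises a TypeError when the jitted function is called.

```python
import jax
import jax.numpy as jnp

@jax.jit
def affine(params, x):
    # nested containers (dicts, lists, tuples) of arrays/scalars are valid pytrees
    return {"y": params["w"] * x + params["b"]}

out = affine({"w": 2.0, "b": 1.0}, jnp.arange(3.0))

str_rejected = False
try:
    jax.jit(lambda s: s)("not an array")  # a str is not a valid JAX type
except TypeError:
    str_rejected = True
print(out["y"], str_rejected)
```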

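Note that tracing and compilation can happen just by calling jnp functions, because some of them are implemented with jit. For example, with compile logging enabled:

jax.config.update("jax_log_compiles", True)
x = jnp.zeros((3,3))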
# WARNING:jax._src.dispatch:Finished tracing + transforming broadcast_in_dim for pjit in 0.003103494644165039 sec
# WARNING:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
#WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.005272388458251953 sec
#WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.016413211822509766 sec

2.1. signature

The traced signature can be found in the in_avals attribute of the lowered function. The signature must be pytree-compatible.

def f(x, y):
  return jnp.sin(x['a']+y)

# (({'a': ShapedArray(int32[], weak_type=True)}, ShapedArray(float32[2,3])), {})
jax.jit(f).lower({'a': 4}, jnp.zeros((2,3))).in_avals

By default, tracing is done at the ShapedArray abstraction level, which can be a problem when Python control flow depends on traced values.
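A sketch of the typical failure and one fix: a Python `if` on a traced value raises a ConcretizationTypeError under jit, while `jnp.where` expresses the same branch in a traceable way.

```python
import jax
import jax.numpy as jnp

def relu_py(x):
    # a Python `if` needs a concrete bool, but under jit x is an abstract tracer
    if x > 0:
        return x
    return 0.0 * x

failed = False
try:
    jax.jit(relu_py)(3.0)
except jax.errors.ConcretizationTypeError:
    failed = True

# fix: express the branch with a traceable op instead of Python control flow
relu = jax.jit(lambda x: jnp.where(x > 0, x, 0.0))
print(failed, relu(3.0), relu(-2.0))
```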

Arguments marked with static_argnums or static_argnames are traced with their concrete values instead of ShapedArray. They are treated as compile-time constants and constant-folded during tracing.

Changing a static argument triggers recompilation because it changes the signature. Whether a static argument has changed is decided by its __hash__ and __eq__, so static arguments must be hashable.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=1)
def f(x, a):

  # during tracing, this shows 1 as it is static
  print(a) 
  return jnp.sin(x)

f(jnp.zeros(1), 1) # this will compile
f(jnp.zeros(1), 2) # compiles again: a is traced with its concrete value, so the signature changes (1 -> 2)

@jax.jit
def g(x, a):

  # during tracing, this prints a tracer such as Traced<ShapedArray(int32[], weak_type=True)>
  print(a)
  return jnp.sin(x)

g(jnp.zeros(1), 1) # this will compile
g(jnp.zeros(1), 2) # does not recompile: a is traced as ShapedArray(int32[]), so the signature is unchanged
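To see the recompilation behavior directly, this sketch counts traces with a Python side effect (the counter only increments while JAX traces) and then passes an unhashable static argument. The exact exception type for the unhashable case is not pinned down here, so two are caught.

```python
import functools
import jax
import jax.numpy as jnp

trace_count = {"static": 0, "traced": 0}

@functools.partial(jax.jit, static_argnums=1)
def f_static(x, a):
    trace_count["static"] += 1  # runs only while tracing
    return jnp.sin(x) * a

@jax.jit
def g_traced(x, a):
    trace_count["traced"] += 1
    return jnp.sin(x) * a

for a in (1, 1, 2, 3):
    f_static(jnp.zeros(1), a)  # retraces for every new static value
    g_traced(jnp.zeros(1), a)  # traced once: a is always ShapedArray(int32[])

print(trace_count)

unhashable_rejected = False
try:
    f_static(jnp.zeros(1), [1, 2])  # a list is not hashable
except (TypeError, ValueError):
    unhashable_rejected = True
print(unhashable_rejected)
```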

2.2. jaxpr

A jaxpr is a sequence of primitive equations obtained by tracing; the result of tracing can be inspected using make_jaxpr.

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1

#invars: [a]
#outvars: [b]
#constvars: []
#equation: [a, 1] add [b] {}
#jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
examine_jaxpr(jax.make_jaxpr(foo)(5))


# Another way to get the jaxpr is through the Traced stage
jax.jit(foo).trace(5).jaxpr

2.3. shape inference

A related utility is shape inference without doing any actual FLOPs:

import jax
import jax.numpy as jnp

f = lambda A, x: jnp.tanh(jnp.dot(A, x))
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
out = jax.eval_shape(f, A, x)  # no FLOPs performed
print(out.shape)  # (2000, 1000)
print(out.dtype)  # float32

3. Transform

See this doc to understand how to add a new transformation

Many JAX transformations are implemented on top of jaxprs: a transformation (mostly) turns one jaxpr into another and works as an interpreter over its equations.
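To make the interpreter view concrete, here is a minimal jaxpr evaluator in the spirit of JAX's custom-interpreter tutorial: it walks the equations and re-binds each primitive. A real transformation would rewrite equations instead of only evaluating them. The `Literal` import location varies across JAX versions, hence the fallback.

```python
import jax
import jax.numpy as jnp
import numpy as np

try:  # newer JAX releases expose Literal under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older releases
    from jax.core import Literal

def eval_jaxpr(jaxpr, consts, *args):
    """Minimal interpreter: walk the equations and re-bind each primitive."""
    env = {}

    def read(v):
        # literals (e.g. the constant 2.0) carry their value inline
        return v.val if isinstance(v, Literal) else env[v]

    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # get_bind_params handles primitives whose params hold sub-jaxprs
        subfuns, params = eqn.primitive.get_bind_params(eqn.params)
        outvals = eqn.primitive.bind(*subfuns, *invals, **params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val
    return [read(v) for v in jaxpr.outvars]

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

closed = jax.make_jaxpr(f)(0.5)
result = eval_jaxpr(closed.jaxpr, closed.consts, 0.5)[0]
print(result, np.isclose(result, f(0.5)))
```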

3.1. grad

Here is a great blog comparing static/dynamic autodifferentiation libraries.

JAX computes grad by first using jvp to obtain a forward (linearized) jaxpr, then transposing each equation of that jaxpr in reverse order.

Consider the function \(f(x,y) = xy+y\) at \((x, y) = (2, 4)\). During grad, jvp produces the following expressions, where the tangents xt, yt are traced abstractly:

a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt

Transposition then processes these expressions in reverse order and produces the following:

# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
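The resulting cotangents, xct = 4 and yct = 1 + 2 = 3, can be cross-checked against JAX at the point (x, y) = (2, 4) read off from the jvp expressions (a = xt * 4 means y = 4, and b = 2. * yt means x = 2):

```python
import jax

def f(x, y):
    return x * y + y

# reverse mode at (x, y) = (2, 4): expect (df/dx, df/dy) = (y, x + 1) = (4, 3)
grads = jax.grad(f, argnums=(0, 1))(2.0, 4.0)

# the same pass through vjp, seeding the output cotangent fct = 1
_, f_vjp = jax.vjp(f, 2.0, 4.0)
cotangents = f_vjp(1.0)
print(grads, cotangents)
```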

4. Compile

4.1. lower

The Lowered stage comes after the Traced stage: the traced jaxpr is lowered into StableHLO, expressed as an MLIR module (mlir-HLO). Note that this is only a translation, not an optimization.

@jax.jit
def one_plus_one():
  a = jnp.ones(1)
  b = jnp.ones(1)
  return a + b

print(one_plus_one.lower().as_text())

Example output:

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %4 = stablehlo.add %1, %3 : tensor<1xf32>
    return %4 : tensor<1xf32>
  }
}

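Lowering is achieved by implementing a lowering rule for each jaxpr primitive; this conversion is implemented here. For the multiply_add primitive (note that hlo is imported from a private jax._src module, so its path may change between versions):

from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo

def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# register the lowering rule for the primitive
mlir.register_lowering(multiply_add_p, multiply_add_lowering)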

4.2. compile

The Lowered stage can be further compiled into the Compiled stage, which is implemented by XlaExecutable.

Most of the optimization happens at this stage.

compiled = jax.jit(lambda x: x + 2).lower(3).compile()


# execute this
compiled(2)
# Array(4, dtype=int32, weak_type=True)

# some analysis
compiled.cost_analysis()
compiled.memory_analysis()
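A small sketch of two properties of the Compiled stage: it is specialized to the shapes and dtypes it was lowered with, and `as_text()` returns optimized HLO rather than machine code. The exact exception type on a shape mismatch is not pinned down here, so two are caught.

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x + 2).lower(jnp.arange(3.0)).compile()
out = compiled(jnp.ones(3))  # same shape/dtype as at lowering time

mismatch_rejected = False
try:
    compiled(jnp.ones(4))  # a Compiled object only accepts its lowered input types
except (TypeError, ValueError):
    mismatch_rejected = True

hlo_text = compiled.as_text()  # optimized HLO text, not machine code
print(out, mismatch_rejected, hlo_text.splitlines()[0])
```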

Unfortunately, compiled.as_text() does not show the executable binary itself but the optimized HLO right before native code generation. In the output below (for the 1 + 1 example from the lowering section), the constant expression 1 + 1 has been folded directly into 2:

HloModule jit_f, is_scheduled=true, entry_computation_layout={()->f32[1]{0:T(256)}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.2 () -> f32[1] {
  %constant.1 = f32[1]{0:T(256)} constant({2})
  ROOT %copy.1 = f32[1]{0:T(256)} copy(f32[1]{0:T(256)} %constant.1)
}

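A similar, older API is jax.xla_computation, which produced similar compile results; it has been deprecated in recent JAX releases in favor of the jax.jit(...).lower(...) APIs above.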

To get the actual executable (in serialized form), use the following:

from jax.experimental import serialize_executable
serialize_executable.serialize(compiled)

Internally, compilation goes through xla_client.

4.3. cache

The in-memory compilation cache can be cleared with:

jax.clear_caches()
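A sketch showing the effect, again counting traces with a Python side effect: after `jax.clear_caches()`, the same call is traced (and compiled) again.

```python
import jax
import jax.numpy as jnp

n_traces = {"count": 0}

@jax.jit
def add_one(x):
    n_traces["count"] += 1  # runs only when JAX traces add_one
    return x + 1

add_one(jnp.ones(3))
add_one(jnp.ones(3))  # cache hit: no retrace
jax.clear_caches()
add_one(jnp.ones(3))  # caches were cleared: traced and compiled again
print(n_traces["count"])
```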

5. Distribution

5.1. sharding

See the JAX distribution doc, the jax.sharding doc, and also this Kaggle blog.

5.1.1. Mesh and PartitionSpec

There are several Sharding classes, typically an XLACompatibleSharding. A Sharding only specifies how data gets sharded; it does not represent the sharded data itself.

SingleDeviceSharding places its data on a single device; it is the default sharding.

# an array created without explicit placement gets this default sharding
# and is placed on a single device
x = jnp.zeros((8,16))

# SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
x.sharding

NamedSharding combines a Mesh and a PartitionSpec; it can be created manually as follows:

import numpy as np
import jax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.experimental import mesh_utils

# create a 2x4 mesh (8 devices) using either of the following
devices = np.array(jax.devices()).reshape(2, 4)
devices = mesh_utils.create_device_mesh((2, 4))

# Declare a 2D mesh with axes `x` and `y`.
mesh = Mesh(devices, ('x', 'y'))

# a PartitionSpec describes how each array dimension is sharded over mesh axes
# e.g. PartitionSpec('x', 'y') says that the first dimension of the data is sharded across the x axis of the mesh, and the second dimension across the y axis
# it may have at most as many entries as the array has dimensions (unlisted trailing dimensions are replicated)
spec = PartitionSpec('x', 'y')

# combine mesh and partitionspec together into a NamedSharding
named_sharding = jax.sharding.NamedSharding(mesh, spec)

# shard the data and visualize
data = jnp.arange(8).reshape(2,4)
sharded_data = jax.device_put(data, named_sharding)

# this sharding info can be found in the array's attribute
print(sharded_data.sharding)
jax.debug.visualize_array_sharding(sharded_data)

┌──────────┬──────────┬──────────┬──────────┐
│          │          │          │          │
│  TPU 0   │  TPU 1   │  TPU 2   │  TPU 3   │
│          │          │          │          │
│          │          │          │          │
├──────────┼──────────┼──────────┼──────────┤
│          │          │          │          │
│  TPU 4   │  TPU 5   │  TPU 6   │  TPU 7   │
│          │          │          │          │
│          │          │          │          │
└──────────┴──────────┴──────────┴──────────┘

# another sharding: shard rows over x, replicate columns
named_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x', None))
sharded_data = jax.device_put(data, named_sharding)
jax.debug.visualize_array_sharding(sharded_data)

┌────────────────────────────────────────────────┐
│                                                │
│                  TPU 0,1,2,3                   │
│                                                │
│                                                │
├────────────────────────────────────────────────┤
│                                                │
│                  TPU 4,5,6,7                   │
│                                                │
│                                                │
└────────────────────────────────────────────────┘
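The mesh above assumes 8 TPU devices. The following sketch builds a 1-D mesh over whatever devices are available (even a single CPU) and inspects the per-device shards.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("x",))  # 1-D mesh over all devices

# shard rows across "x", replicate columns
sharding = NamedSharding(mesh, PartitionSpec("x", None))
data = jnp.arange(2 * n * 4).reshape(2 * n, 4)  # row count divisible by n

print(sharding.shard_shape(data.shape))  # per-device block shape
sharded = jax.device_put(data, sharding)
for shard in sharded.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```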

5.1.2. Placement

To control device placement:

For explicit placement of arrays, use jax.device_put to actually shard data with a given Sharding.

For shardings inside a jitted function, use

  • in_shardings and out_shardings to explicitly control the shardings of inputs and outputs
  • jax.lax.with_sharding_constraint to force the sharding of an intermediate result within a jit-decorated function

If neither in_shardings nor the inputs themselves specify a sharding, the input shardings can be computed automatically by letting SPMD sharding propagation flow back to the input parameters. This behavior is controlled by allow_spmd_sharding_propagation_to_parameters.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# a 2x1 mesh over the first 2 devices
mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
spec = P('x', 'y')
named_sharding = jax.sharding.NamedSharding(mesh, spec)

data = jnp.array([[1,2],[3,4]])

# Array([[1, 2],[3, 4]], dtype=int32)
sharded_data = jax.device_put(data, named_sharding)

# NamedSharding(mesh=Mesh('x': 2, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=tpu_hbm)
print(sharded_data.sharding)

# resharding using another spec
newspec = P('y', 'x')
new_sharding = jax.sharding.NamedSharding(mesh, newspec)
new_sharded_data = jax.device_put(sharded_data, new_sharding)

# resharding inside a jit-decorated function
@jax.jit
def f(x):
  x = x + 1
  # force the intermediate result to use new_sharding
  y = jax.lax.with_sharding_constraint(x, new_sharding)
  return y

# use in_shardings and out_shardings
jit_f = jax.jit(f, in_shardings=(named_sharding,), out_shardings=named_sharding).lower(sharded_data).compile()

5.1.3. transfer guard

JAX may transfer data between host and devices either

  • explicitly: jax.device_put(), jax.device_get()
  • implicitly: e.g. printing a device array

Transfers can be logged or disallowed with a transfer guard, as documented here.
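A minimal sketch using the `jax.transfer_guard` context manager: under "disallow", explicit transfers still work, while implicit ones are what the guard rejects (not exercised here, since the exact behavior can differ by backend).

```python
import jax
import numpy as np

host = np.arange(4.0)

with jax.transfer_guard("disallow"):
    # explicit transfers are permitted even under "disallow"
    dev = jax.device_put(host)
    back = jax.device_get(dev)
    # implicit transfers (e.g. mixing `host` into a jnp computation) are
    # what "disallow" is meant to reject; not exercised in this sketch

print(back)
```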

5.2. vmap and pmap

vmap is a vectorizing map; for example:

import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
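Applying these to concrete arrays, as a usage sketch: `mm` should reproduce ordinary matrix multiplication.

```python
import jax.numpy as jnp
import numpy as np
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)       # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)       # ([b,a], [a,c]) -> [b,c]

A = jnp.arange(6.0).reshape(2, 3)   # b=2, a=3
B = jnp.arange(12.0).reshape(3, 4)  # a=3, c=4
C = mm(A, B)
print(C.shape, np.allclose(C, A @ B))
```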

pmap is for user-managed (manual) SPMD programming.

Applying pmap to a function compiles it with XLA, replicates it across devices, and executes each replica on its own XLA device in parallel.

import jax
import jax.numpy as jnp

# the size of the mapped (leading) axis must not exceed the number of local devices
n_devices = jax.local_device_count()

# create a sharded array
a = jax.pmap(lambda x: x**2)(jnp.arange(n_devices))

# shardingspec
# PmapSharding(sharding_spec=ShardingSpec((Unstacked(2),), (ShardedAxis(axis=0),)), device_ids=[0, 1], device_platform=TPU, device_shape=(2,))
print(a.sharding)

# get info of each shard
jax.debug.visualize_array_sharding(a)
for i, shard in enumerate(a.global_shards):
    print(f"\nShard no: {i:>5}")
    print(f"Device: {str(shard.device):>32}")
    print(f"Data shape: {str(shard.data.shape):>8}")
    print(f"Data slices: {str(shard.index):>22}\n")
    print("="*75)
    print("")
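A sketch of a collective inside pmap: `lax.psum` over the named axis sums the per-device values, used here to normalize an array so that it sums to 1 (works with any number of local devices).

```python
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

# each device holds one element; psum sums it across the named axis "i"
normalize = jax.pmap(lambda x: x / lax.psum(x, axis_name="i"), axis_name="i")
out = normalize(jnp.arange(1.0, n_devices + 1.0))
print(out, out.sum())
```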

5.3. jit (pjit)

jit performs automatic SPMD partitioning: shardings of intermediate values are propagated automatically from the input/output shardings.

The syntax is as follows:

pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue,...)

  • in_shardings: matches the argument structure. It is optional, since JAX infers the sharding from the inputs; when inference cannot be done, it defaults to replication (i.e. None)
  • out_shardings: also optional; when not specified, output shardings are inferred by GSPMD's sharding propagation.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# a 2x2 mesh (4 devices)
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, ('x', 'y'))
named_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x', 'y'))

def f(x):
  return x + 1

jit_f = jax.jit(f, in_shardings=(named_sharding,))

out = jit_f(jnp.arange(8).reshape((2,4)))

# the sharding propagates to the output:
# NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x', 'y'))
out.sharding

5.4. examples

See the following snippet for distributed computing examples

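import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def mul(x, a):
  x = jnp.dot(x,a)
  return x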

# single TPU core inference

x = jax.random.normal(jax.random.key(0), (8192, 8192))
a = jax.random.normal(jax.random.key(1), (8192, 8192))

# both x, a are placed on TPU 0
# SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
x.sharding

# ┌───────────────────────┐
# │                       │
# │                       │
# │                       │
# │                       │
# │         TPU 0         │
# │                       │
# │                       │
# │                       │
# │                       │
# └───────────────────────┘
jax.debug.visualize_array_sharding(a)

# compile before benchmark
mul_f = jax.jit(mul).lower(x, a).compile()
# 5 loops, best of 5: 4.77 ms per loop
%timeit -n 5 -r 5 mul_f(x, a).block_until_ready()


# shard x along its first dimension over all devices
mesh = Mesh(np.array(jax.devices()), ('x',))
named_sharding = NamedSharding(mesh, PartitionSpec('x',))
# data comm takes time, do this ahead of actual computing
rx = jax.device_put(x, named_sharding)

# ┌───────────────────────┐
# │         TPU 0         │
# ├───────────────────────┤
# │         TPU 1         │
# ├───────────────────────┤
# │         TPU 2         │
# ├───────────────────────┤
# │         TPU 3         │
# └───────────────────────┘
jax.debug.visualize_array_sharding(rx)

# replicate a onto all devices
replicated_sharding = NamedSharding(mesh, PartitionSpec())
ra = jax.device_put(a, replicated_sharding)

# ┌───────────────────────┐
# │                       │
# │                       │
# │                       │
# │                       │
# │      TPU 0,1,2,3      │
# │                       │
# │                       │
# │                       │
# │                       │
# └───────────────────────┘
jax.debug.visualize_array_sharding(ra)

sharded_mul_f = jax.jit(mul).lower(rx, ra).compile()

# 5 loops, best of 5: 1.29 ms per loop
%timeit -n 5 -r 5 sharded_mul_f(rx, ra).block_until_ready()

6. Serialization

6.1. StableHLO

See the tutorial here on how to serialize JAX functions into StableHLO.

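# serialize
import jax
from jax import export  # formerly jax.experimental.export
import jax.numpy as jnp
import numpy as np

def plus(x,y):
  return jnp.add(x,y)

# Create abstract input shapes:
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]
# recent JAX versions expect a jitted function here
exported = export.export(jax.jit(plus))(*input_shapes)
stablehlo_add = exported.mlir_module()        # StableHLO as text
mlir_bytes = exported.mlir_module_serialized  # serialized StableHLO (bytes)

# deserialize (xla_client._xla is internal, so this may change between versions)
from jax.lib import xla_client
mlir_module = xla_client._xla.mlir.deserialize_portable_artifact(mlir_bytes)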
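A sketch of a full round trip with the public `jax.export` API (available in recent JAX releases): the Exported object itself is serialized, deserialized, and called again.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def plus(x, y):
    return jnp.add(x, y)

shapes = [jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)]
exported = export.export(jax.jit(plus))(*shapes)
stablehlo_text = exported.mlir_module()  # StableHLO text, e.g. for inspection

serialized = exported.serialize()        # bytearray, e.g. to write to disk
rehydrated = export.deserialize(serialized)
result = rehydrated.call(np.int32(2), np.int32(3))
print(result)
```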

6.1.1. jax2tf

See jax2tf doc

import jax
from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# jax2tf.convert is a higher-order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
f_tf = jax2tf.convert(f_jax)

# For example you execute f_tf eagerly with valid TensorFlow inputs:
f_tf(np.random.random(...))

# Additionally you can use tools like `tf.function` to improve the execution
# time of your function, or to stage it out to a SavedModel:
f_tf_graph = tf.function(f_tf, autograph=False)

# the following call is what actually triggers tracing and lowering:
# it builds StableHLO and wraps it in a tf.XlaCallModule op; all StableHLO information is serialized into a string stored as the `module` attribute of the XlaCallModule node.
f_concrete = f_tf_graph.get_concrete_function(tf.constant(1.0))

# read the XlaCallModule op
f_concrete.graph.as_graph_def().node[2]

# decode bytes into a readable format for debugging
mlir_bytes = f_concrete.graph.as_graph_def().node[6].attr['module'].s
print(jax.lib.xla_client._xla.mlir.deserialize_portable_artifact(mlir_bytes))

# depending on the sharding, this graph might also have a few XlaSharding ops, which can be decoded back as follows
# (xla_data_pb2 is XLA's protobuf module; its import path depends on the TF version)
sharding_bytes = f_concrete.graph.as_graph_def().node[4].attr['_XlaSharding'].s
xla_data_pb2.OpSharding.FromString(sharding_bytes)

Note that shardings given via jax.jit(fun, in_shardings=...) get translated into tf.XlaSharding ops.

6.2. Checkpoint

Flax has a serialization library to save params as follows:

from flax import serialization

# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

# load; note that a template (e.g. params with the same structure) is needed to reconstruct from bytes
serialization.from_bytes(params, bytes_output)

7. flax

Since JAX transformations are functional, linen is designed around stateless functional transforms of the following form, where \(v\) denotes the variable collections and \(x, y\) are the input and output data.

\[v_{out}, y = f(v_{in}, x)\]
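To illustrate the \(v_{out}, y = f(v_{in}, x)\) pattern without Flax, here is a hand-written pure-JAX function (hypothetical, not Flax code) that reads read-only params and returns an updated batch_stats collection alongside the output.

```python
import jax
import jax.numpy as jnp

def f(variables, x, momentum=0.9):
    # v_in -> (v_out, y): params are only read, batch_stats are returned updated
    w = variables["params"]["w"]
    y = x @ w
    old_mean = variables["batch_stats"]["mean"]
    new_mean = momentum * old_mean + (1 - momentum) * y.mean(axis=0)
    v_out = {"params": variables["params"], "batch_stats": {"mean": new_mean}}
    return v_out, y

v = {"params": {"w": jnp.ones((3, 2))}, "batch_stats": {"mean": jnp.zeros(2)}}
x = jnp.arange(12.0).reshape(4, 3)

step = jax.jit(f)
v, y = step(v, x)  # state is threaded explicitly through inputs and outputs
print(v["batch_stats"]["mean"])
```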

7.1. linen

flax.linen.Module is a dataclass that only holds hyperparameter fields (i.e. parameters are not stored inside the module).

from flax import linen as nn
from typing import Tuple

class Module(nn.Module):
  # this declared hyperparameter is an instance variable and available through self.features immediately
  features: Tuple[int, ...] = (16, 4)

  # setup is not called when the module is constructed; it is called lazily to bind the scope EVERY time init or apply is called
  def setup(self):
    # submodules defined here are only accessible inside init/apply, not directly from the unbound object
    self.dense1 = nn.Dense(self.features[0])
    self.dense2 = nn.Dense(self.features[1])

  # called after the module is bound and setup has run
  def __call__(self, x):
    return self.dense2(nn.relu(self.dense1(x)))

The variables dict is a container of "variable collections":

{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ...}
  }
}

A param is a special kind of variable, stored in the "params" collection. It is immutable. See this post.

# this retrieves the value if it exists, or initializes and returns it if it does not
p = self.param('param_name', init_fn, shape, dtype) # shape, dtype are the initializer's args (passed after the PRNG key)

# is a convenient shorthand for this:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value

init/apply

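init initializes the model's variables by running the module on example inputs and returns them as a variables dict (a FrozenDict in older Flax versions). Only the shapes and dtypes of the inputs matter for the resulting variables; Module.lazy_init initializes them without computing on actual values.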

See the lifecycle doc for more details

import jax
from jax import random
from flax import linen as nn

model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input

# eager init
params = model.init(key2, x)

# eager apply
y = model.apply(params, x)

# jit apply
outs = jax.jit(model.apply)(params, x)

# apply with mutable state: returns a tuple of (out, mutated_variables)
y, some_updated_state = model.apply(params, x, mutable=[some_state_name])

7.2. Rematerialization

Rematerialization (checkpointing) is supported in JAX via jax.checkpoint (also available as jax.remat); see this doc.
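A minimal sketch: wrapping a block with `jax.checkpoint` leaves the gradient unchanged, but the block's intermediates are recomputed during the backward pass instead of being stored. The block itself is an arbitrary stand-in.

```python
import jax
import jax.numpy as jnp
import numpy as np

def block(x):
    # arbitrary stand-in for an expensive layer
    return jnp.sin(jnp.cos(x) * 3.0)

def loss_plain(x):
    return jnp.sum(block(block(x)))

def loss_remat(x):
    # don't save block's intermediates; recompute them on the backward pass
    remat_block = jax.checkpoint(block)
    return jnp.sum(remat_block(remat_block(x)))

x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(loss_plain)(x)
g_remat = jax.grad(loss_remat)(x)
print(np.allclose(g_plain, g_remat))
```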

8. optax

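optax provides optimizers (gradient transformations) for JAX.

Usage example (assuming params, mse, learning_rate, x_samples and y_samples are already defined):

import jax
import optax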
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
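For plain SGD (no momentum), `tx.update` is expected to return `-learning_rate * grads`, and `optax.apply_updates` adds the updates to the params. This sketch writes the same loop in pure JAX on a hypothetical toy regression problem; `mse`, the data, and the learning rate are made up for illustration.

```python
import jax
import jax.numpy as jnp

# hypothetical toy problem: fit y = 3x + 1 with mean squared error
def mse(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

x_samples = jnp.linspace(-1.0, 1.0, 32)
y_samples = 3.0 * x_samples + 1.0
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
learning_rate = 0.1

loss_grad_fn = jax.jit(jax.value_and_grad(mse))
losses = []
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    losses.append(float(loss_val))
    # plain SGD: updates = -lr * grads, then params + updates
    updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(losses[0], losses[-1], params)
```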