0x522 Jax
Jax is essentially a stack of compilers; it is built around the transformation and compilation of pure functional programs.
Jax's execution path is roughly
1. API
1.1. lax
lax is a library of primitive operations underlying jnp; it is a thin wrapper over XLA whose operations JAX knows how to transform.
See this section about lax control flow
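As a quick illustration of that strictness (a small hedged example, not from the linked section): jnp follows NumPy-style implicit dtype promotion, while lax maps more directly onto XLA and requires matching dtypes.
import jax.numpy as jnp
from jax import lax

jnp.add(1, 1.0)    # ok: jnp promotes mixed dtypes implicitly
lax.add(1.0, 1.0)  # ok: dtypes already match
# lax.add(1, 1.0)  # raises a TypeError: lax requires explicit dtype promotion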
1.2. primitive
we can define new primitives and implement their evaluation and transformation methods
import jax
from jax import core

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

def multiply_add_prim(x, y, z):
    return multiply_add_p.bind(x, y, z)
1.2.1. eval
The primal (concrete) evaluation must be implemented
import numpy as np

def multiply_add_impl(x, y, z):
    return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
To enable tracing, an abstract evaluation must be implemented to tell JAX how to infer the output shape and dtype
def multiply_add_abstract_eval(xs, ys, zs):
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

# with abstract evaluation in place, the primitive can be traced correctly:
# { lambda ; a:f32[] b:f32[] c:f32[]. let d:f32[] = multiply_add a b c in (d,) }
jax.make_jaxpr(multiply_add_prim)(2.0, 3.0, 10.)
1.2.2. differentiation
To support differentiation, define a JVP rule and a transpose rule for the primitive.
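A sketch of both rules for multiply_add_p, following the pattern of the official "how primitives work" tutorial; the helper names here are ours, and symbolic-zero tangents/cotangents are not handled for brevity.
from jax.interpreters import ad

def multiply_add_jvp(arg_values, arg_tangents):
    x, y, z = arg_values
    xt, yt, zt = arg_tangents
    primal_out = multiply_add_prim(x, y, z)
    # d(x*y + z) = xt*y + x*yt + zt
    tangent_out = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
    return primal_out, tangent_out

ad.primitive_jvps[multiply_add_p] = multiply_add_jvp

# the transpose rule is needed for reverse mode (grad): given the output cotangent ct,
# return cotangents for the linear inputs (None for inputs that are known constants)
def multiply_add_transpose(ct, x, y, z):
    if ad.is_undefined_primal(x):
        # x is the linear input, y is a known constant: ct_x = ct * y
        return multiply_add_prim(ct, y, np.zeros_like(y)), None, ct
    else:
        # y is the linear input, x is a known constant: ct_y = x * ct
        return None, multiply_add_prim(x, ct, np.zeros_like(x)), ct

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose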
1.3. jnp
jnp mirrors the NumPy syntax.
1.3.1. async dispatch
jnp operations return a jax.Array, which acts as a future: it is returned without waiting for the computation to complete. This makes naive benchmarks misleading:
%time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
%time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
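To time the computation rather than just the dispatch, block on the result (the pattern recommended by the async dispatch doc):
%time jnp.dot(x, x).block_until_ready()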
See this doc for more details
1.3.2. memory management
There are a few flags controlling HBM allocation; see this doc.
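For example (values are illustrative; these environment variables come from the memory allocation doc and must be set before jax initializes its backend):
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable upfront preallocation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"   # cap preallocation at 50% of device memory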
To delete array buffers stored in HBM, use the following snippet
for x in jax.live_arrays():
    x.delete()
2. Trace
Different transformations require different levels of abstraction, and they may run into issues depending on control flow; see the doc here
The target function's arguments and return values should be arrays, scalars, or nested containers thereof; other values (e.g. str) will raise an error during tracing.
Note that tracing might be triggered just by using jnp, since some jnp operations are implemented with jit. For example,
x = jnp.zeros((3,3))
# WARNING:jax._src.dispatch:Finished tracing + transforming broadcast_in_dim for pjit in 0.003103494644165039 sec
# WARNING:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
# WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.005272388458251953 sec
# WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.016413211822509766 sec
2.1. signature
The traced signature can be found via the in_avals of a lowered function. The signature must be pytree compatible.
def f(x, y):
    return jnp.sin(x['a'] + y)

# (({'a': ShapedArray(int32[], weak_type=True)}, ShapedArray(float32[2,3])), {})
jax.jit(f).lower({'a': 4}, jnp.zeros((2,3))).in_avals
By default, tracing is done at the ShapedArray abstraction level, which might be problematic when hitting control flow.
Using static_argnums and static_argnames forces the tracer to trace with concrete values instead of ShapedArray: those arguments are treated as compile-time constants and are constant-folded during tracing.
Changing those statics triggers recompilation because it changes the signature. Whether a static has changed is determined by its __hash__ and __eq__ (statics must be hashable).
import functools

@functools.partial(jax.jit, static_argnums=1)
def f(x, a):
    # during tracing, this prints the concrete value 1 since a is static
    print(a)
    return jnp.sin(x)

f(jnp.zeros(1), 1)  # this will compile
f(jnp.zeros(1), 2)  # this will compile again: a is traced with its actual value, so the signature changes (from 1 to 2)

@jax.jit
def g(x, a):
    # during tracing, this prints Traced<ShapedArray(int32[], weak_type=True)>
    print(a)
    return jnp.sin(x)

g(jnp.zeros(1), 1)  # this will compile
g(jnp.zeros(1), 2)  # this will NOT recompile: a is traced as ShapedArray(int32[]), so the signature does not change
2.2. jaxpr
A jaxpr is a sequence of primitive equations obtained by tracing; the tracing result can be inspected using make_jaxpr
def examine_jaxpr(closed_jaxpr):
    jaxpr = closed_jaxpr.jaxpr
    print("invars:", jaxpr.invars)
    print("outvars:", jaxpr.outvars)
    print("constvars:", jaxpr.constvars)
    for eqn in jaxpr.eqns:
        print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    print()
    print("jaxpr:", jaxpr)

def foo(x):
    return x + 1

# invars: [a]
# outvars: [b]
# constvars: []
# equation: [a, 1] add [b] {}
# jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
examine_jaxpr(jax.make_jaxpr(foo)(5))

# Another approach to get the jaxpr is to use the Traced stage
jax.jit(foo).trace(5).jaxpr
2.3. shape inference
A simple related utility is shape inference without performing any actual FLOPs
import jax
import jax.numpy as jnp

f = lambda A, x: jnp.tanh(jnp.dot(A, x))
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
out = jax.eval_shape(f, A, x)  # no FLOPs performed
print(out.shape)  # (2000, 1000)
print(out.dtype)  # float32
3. Transform
See this doc to understand how to add a new transformation
Most JAX transformations are implemented on top of jaxprs: a transformation (mostly) rewrites one jaxpr into another and works as an interpreter.
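For intuition, a jaxpr can be evaluated directly with the built-in interpreter; a custom transformation follows the same walk-the-equations structure (a minimal sketch; the jax.core namespace may move in newer versions):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

closed_jaxpr = jax.make_jaxpr(f)(1.0)
# evaluate the jaxpr equation by equation, as any interpreter/transformation would
jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, 1.0)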
3.1. grad
Here is a great blog comparing static/dynamic autodifferentiation libraries.
jax uses jvp to obtain a forward (linearized) jaxpr, then applies a transpose to each equation of that jaxpr to obtain the gradient.
Consider the function \(f(x,y) = xy+y\) evaluated at \(x = 2, y = 4\): during grad, jvp produces the following expressions, where xt, yt are traced abstractly
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
transposition then interprets these expressions backwards and produces the following:
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
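As a sanity check, evaluating the cotangents at \(x = 2, y = 4\) gives \(\partial f/\partial x = y = 4\) and \(\partial f/\partial y = x + 1 = 3\), which matches jax.grad:
f = lambda x, y: x * y + y
# (Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
jax.grad(f, argnums=(0, 1))(2.0, 4.0)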
4. Compile
4.1. lower
The Lowered stage comes after the Traced stage: the traced jaxpr is lowered into StableHLO (an MLIR-based HLO dialect). Notice this is only a translation, not optimization.
@jax.jit
def one_plus_one():
    a = jnp.ones(1)
    b = jnp.ones(1)
    return a + b

print(one_plus_one.lower().as_text())
example output is
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %4 = stablehlo.add %1, %3 : tensor<1xf32>
    return %4 : tensor<1xf32>
  }
}
Lowering is achieved by implementing a lowering rule for each jaxpr primitive; this conversion is implemented here
def multiply_add_lowering(ctx, xc, yc, zc):
    """The compilation to XLA of the primitive.

    Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
    the results of the function. Does not need to be a JAX-traceable function.
    """
    return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
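The rule only takes effect once registered (a sketch following the custom-primitive tutorial; the hlo dialect import path is internal and may vary across JAX versions):
from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo  # internal import, as used in the tutorial

mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform="cpu")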
4.2. compile
The Lowered stage can be further compiled into the Compiled stage, which is backed by an XlaExecutable. Most of the optimization happens at this stage.
compiled = jax.jit(lambda x: x + 2).lower(3).compile()
# execute this
compiled(2)
# Array(4, dtype=int32, weak_type=True)
# some analysis
compiled.cost_analysis()
compiled.memory_analysis()
Unfortunately, compiled.as_text() (shown below) does not show the executable binary directly, but the optimized HLO before native binary compilation. Notice that the constant 1 + 1 has been folded directly into 2.
HloModule jit_f, is_scheduled=true, entry_computation_layout={()->f32[1]{0:T(256)}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.2 () -> f32[1] {
  %constant.1 = f32[1]{0:T(256)} constant({2})
  ROOT %copy.1 = f32[1]{0:T(256)} copy(f32[1]{0:T(256)} %constant.1)
}
Another similar API is jax.xla_computation, which achieves similar results.
To serialize the actual compiled executable, use the following
from jax.experimental import serialize_executable
serialize_executable.serialize(compiled)
Internally, it uses xla_client to compile.
4.3. cache
The in-memory compilation cache can be cleared with
jax.clear_caches()
5. Distribution
5.1. sharding
See the jax distribution doc and the jax.sharding doc; also check this Kaggle blog.
5.1.1. Mesh and PartitionSpec
There are a few Sharding subclasses, typically an XLACompatibleSharding. A sharding only specifies how data gets sharded; it does not represent the sharded data itself.
SingleDeviceSharding is a sharding that places its data on a single device (the default sharding)
# a tensor created without explicit placement is assigned this default sharding
# and is placed on one of the devices
x = jnp.zeros((8,16))
# SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
x.sharding
A NamedSharding is a pair of a Mesh and a PartitionSpec; it can be created manually as follows:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.experimental import mesh_utils

# create a 2x4 mesh using either of the following
devices = np.array(jax.devices()).reshape(2, 4)
devices = mesh_utils.create_device_mesh((2, 4))
# Declare a 2D mesh with axes `x` (size 2) and `y` (size 4).
mesh = Mesh(devices, ('x', 'y'))
# a PartitionSpec describes how each dimension of the input is sharded over the mesh axes,
# e.g. PartitionSpec('x', 'y') says that the first dimension of the data is sharded across the
# x axis of the mesh and the second dimension is sharded across the y axis of the mesh.
# The spec should have at most as many entries as the tensor has dimensions; dims mapped to None (or omitted) are replicated.
spec = PartitionSpec('x', 'y')
# combine mesh and partitionspec together into a namedsharding
named_sharding = jax.sharding.NamedSharding(mesh, spec)
# shard the data and visualize
data = jnp.arange(8).reshape(2,4)
sharded_data = jax.device_put(data, named_sharding)
# this sharding info can be found in the array's attribute
print(sharded_data.sharding)
jax.debug.visualize_array_sharding(sharded_data)
┌──────────┬──────────┬──────────┬──────────┐
│ │ │ │ │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │
│ │ │ │ │
│ │ │ │ │
├──────────┼──────────┼──────────┼──────────┤
│ │ │ │ │
│ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
│ │ │ │ │
│ │ │ │ │
└──────────┴──────────┴──────────┴──────────┘
# another sharding: shard only the first dimension, replicate the second
named_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x', None))
sharded_data = jax.device_put(data, named_sharding)
jax.debug.visualize_array_sharding(sharded_data)
┌────────────────────────────────────────────────┐
│ │
│ TPU 0,1,2,3 │
│ │
│ │
├────────────────────────────────────────────────┤
│ │
│ TPU 4,5,6,7 │
│ │
│ │
└────────────────────────────────────────────────┘
5.1.2. Placement
To control the device placement,
For explicit placement of variables, use jax.device_put
to actually shard data with a given Sharding
For shardings in function, use
in_shardings
andout_shardings
to control explicitly inputs and outputsjax.lax.with_sharding_constraint
to force the placement of intermediate result within jit decorated function
If in_shardings is unspecified and the inputs carry no sharding, it will be computed automatically by allowing SPMD sharding to propagate back to the input parameters. This behavior is controlled by allow_spmd_sharding_propagation_to_parameters.
from jax.sharding import PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
spec = P('x', 'y')
named_sharding = jax.sharding.NamedSharding(mesh, spec)
data = jnp.array([[1,2],[3,4]])
# Array([[1, 2],[3, 4]], dtype=int32)
sharded_data = jax.device_put(data, named_sharding)
# NamedSharding(mesh=Mesh('x': 2, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=tpu_hbm)
print(sharded_data.sharding)
# resharding using another spec
newspec = P('y', 'x')
new_sharding = jax.sharding.NamedSharding(mesh, newspec)
new_sharded_data = jax.device_put(sharded_data, new_sharding)
# resharding inside a jit-decorated function
@jax.jit
def f(x):
    x = x + 1
    # force the placement of the intermediate result
    y = jax.lax.with_sharding_constraint(x, new_sharding)
    return y

# use in_shardings and out_shardings
jit_f = jax.jit(f, in_shardings=(named_sharding,), out_shardings=named_sharding).lower(data).compile()
5.1.3. transfer guard
jax may transfer data between hosts and devices via
- explicit transfers: jax.device_put(), jax.device_get()
- implicit transfers: e.g. printing a DeviceArray
A transfer guard can be configured as documented here
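A minimal sketch of the context-manager form (there is also a jax_transfer_guard config option); under "disallow", explicitly requested transfers still work while implicit ones raise:
import jax
import jax.numpy as jnp

x = jnp.arange(4)
with jax.transfer_guard("disallow"):
    host_x = jax.device_get(x)  # explicit device-to-host transfer: allowed
    # print(x)                  # implicit device-to-host transfer: would raise an error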
5.2. vmap and pmap
vmap is a vectorizing map; for example
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)       # ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = vmap(mv, (None, 1), 1)       # ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
pmap is user-driven (manual) SPMD programming: applying pmap to a function compiles it with XLA and replicates it, executing each replica on its own XLA device in parallel.
# the splitting dim should not be larger than n_devices
n_devices = jax.local_device_count()
# create a sharded array
a = jax.pmap(lambda x: x**2)(jnp.arange(n_devices))
# sharding spec
# PmapSharding(sharding_spec=ShardingSpec((Unstacked(2),), (ShardedAxis(axis=0),)), device_ids=[0, 1], device_platform=TPU, device_shape=(2,))
print(a.sharding)
# get info of each shard
jax.debug.visualize_array_sharding(a)
for i, shard in enumerate(a.global_shards):
    print(f"\nShard no: {i:>5}")
    print(f"Device: {str(shard.device):>32}")
    print(f"Data shape: {str(shard.data.shape):>8}")
    print(f"Data slices: {str(shard.index):>22}\n")
    print("=" * 75)
    print("")
5.3. jit (pjit)
jit is automatic SPMD programming: internal shardings are automatically propagated based on the input/output shardings.
The syntax is as follows:
pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue,...)
- in_shardings: matches the argument structure; optional, since jax will infer the sharding from the inputs. When inference cannot be done, it defaults to replication (i.e. None).
- out_shardings: also optional; when not specified, the sharding is inferred from GSPMD's sharding propagation.
mesh = Mesh(devices, ('x', 'y'))
named_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x', 'y'))

def f(x):
    return x + 1

jit_f = jax.jit(f, in_shardings=[named_sharding])
out = jit_f(jnp.arange(8).reshape((2, 4)))
# the sharding gets propagated to:
# NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x', 'y'))
out.sharding
5.4. examples
See the following snippet for distributed computing examples
def mul(x, a):
    x = jnp.dot(x, a)
    return x
# single TPU core inference
x = jax.random.normal(jax.random.key(0), (8192, 8192))
a = jax.random.normal(jax.random.key(1), (8192, 8192))
# both x, a are placed on TPU 0
# SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
x.sharding
# ┌───────────────────────┐
# │ │
# │ │
# │ │
# │ │
# │ TPU 0 │
# │ │
# │ │
# │ │
# │ │
# └───────────────────────┘
jax.debug.visualize_array_sharding(a)
# compile before benchmark
mul_f = jax.jit(mul).lower(x, a).compile()
# 5 loops, best of 5: 4.77 ms per loop
%timeit -n 5 -r 5 mul_f(x, a).block_until_ready()
# shard x wrt 1 dim
mesh = Mesh(np.array(jax.devices()), ('x',))
named_sharding = NamedSharding(mesh, PartitionSpec('x',))
# data comm takes time, do this ahead of actual computing
rx = jax.device_put(x, named_sharding)
# ┌───────────────────────┐
# │ TPU 0 │
# ├───────────────────────┤
# │ TPU 1 │
# ├───────────────────────┤
# │ TPU 2 │
# ├───────────────────────┤
# │ TPU 3 │
# └───────────────────────┘
jax.debug.visualize_array_sharding(rx)
# replicate a onto all devices
replicated_sharding = NamedSharding(mesh, PartitionSpec(None, None))
ra = jax.device_put(a, replicated_sharding)
# ┌───────────────────────┐
# │ │
# │ │
# │ │
# │ │
# │ TPU 0,1,2,3 │
# │ │
# │ │
# │ │
# │ │
# └───────────────────────┘
jax.debug.visualize_array_sharding(ra)
sharded_mul_f = jax.jit(mul).lower(rx, ra).compile()
# 5 loops, best of 5: 1.29 ms per loop
%timeit -n 5 -r 5 sharded_mul_f(rx, ra).block_until_ready()
6. Serialization
6.1. StableHLO
See the tutorial here on how to serialize Jax into StableHLO
# serialize
import jax
from jax.experimental import export
import jax.numpy as jnp
import numpy as np
def plus(x, y):
    return jnp.add(x, y)

# Create abstract input shapes:
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
# deserialize (mlir_bytes below is the serialized portable artifact, obtained elsewhere)
from jax.lib import xla_client
mlir_module = xla_client._xla.mlir.deserialize_portable_artifact(mlir_bytes)
6.1.1. jax2tf
See jax2tf doc
from jax.experimental import jax2tf
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
def f_jax(x):
    return jnp.sin(jnp.cos(x))
# jax2tf.convert is a higher-order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
f_tf = jax2tf.convert(f_jax)
# For example you execute f_tf eagerly with valid TensorFlow inputs:
f_tf(np.random.random(...))
# Additionally you can use tools like `tf.function` to improve the execution
# time of your function, or to stage it out to a SavedModel:
f_tf_graph = tf.function(f_tf, autograph=False)
# the following call is the actual trigger that initializes tracing and lowering:
# it builds StableHLO and encapsulates it into tf.XlaCallModule ops. All StableHLO information is serialized into a string stored as the 'module' attribute of the XlaCallModule node.
f_concrete = f_tf_graph.get_concrete_function(tf.constant(1.0))
# read XlaCallModule ops
f_concrete.graph.as_graph_def().node[2]
# decode the bytes into a readable format for debugging
mlir_bytes = f_concrete.graph.as_graph_def().node[6].attr['module'].s
print(jax.lib.xla_client._xla.mlir.deserialize_portable_artifact(mlir_bytes))
# depending on the sharding, this graph might also have a few XlaSharding ops, which can be decoded back as follows
# (xla_data_pb2 comes from TensorFlow's XLA protos)
sharding_bytes = f_concrete.graph.as_graph_def().node[4].attr['_XlaSharding'].s
xla_data_pb2.OpSharding.FromString(sharding_bytes)
Note that shardings passed via jax.jit(fun, in_shardings=...) get translated into tf.XlaSharding ops.
6.2. Checkpoint
Flax has a serialization library to save params as follows:
from flax import serialization
# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
# load, note that a template is needed to reconstruct from bytes
serialization.from_bytes(params, bytes_output)
7. flax
Since jax transformations are functional, linen is designed to perform a stateless functional transform of the form \(y, v' = \mathrm{apply}(v, x)\), where \(v\) are variable collections and \(x, y\) are the input and output data.
7.1. linen
flax.linen.Module is a dataclass; it only has hyperparameter members (i.e. parameters are not stored inside the module)
from flax import linen as nn
from typing import Tuple

class Module(nn.Module):
    # this declared hyperparameter is an instance variable and is available through self.features immediately
    features: Tuple[int, ...] = (16, 4)

    # setup is not called at construction time; it is called to bind the scope EVERY TIME init or apply is called
    def setup(self):
        # these submodules can only be accessed inside apply/init, not directly on the object
        self.dense1 = nn.Dense(self.features[0])
        self.dense2 = nn.Dense(self.features[1])

    # __call__ is invoked after binding via setup
    def __call__(self, x):
        return self.dense2(nn.relu(self.dense1(x)))
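A small usage sketch for the Module above (the input shape is arbitrary here):
import jax
import jax.numpy as jnp

m = Module(features=(16, 4))
x = jnp.ones((1, 8))
variables = m.init(jax.random.PRNGKey(0), x)  # {'params': {'dense1': ..., 'dense2': ...}}
y = m.apply(variables, x)                     # shape (1, 4)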
variables is a container of "variable collections", e.g.
{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ...}
  }
}
param is a special class of variable. It is immutable. See this post
# this will retrieve the value (if it exists) or initialize and return it (if it does not)
p = self.param('param_name', init_fn, shape, dtype)  # shape, dtype are the initializer's args (together with the PRNG key)
# it is a convenient shorthand for:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value
init/apply
init initializes a model lazily using only the shapes of the provided arguments, avoiding a forward pass with actual values; it returns a frozen dict of params.
See the lifecycle doc for more details
from jax import random

model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # Dummy input
# eager init
params = model.init(key2, x)
# eager apply
y = model.apply(params, x)
# jit apply
outs = jax.jit(model.apply)(params, x)
# apply with mutable state
# when mutable args exist, apply returns a tuple of (out, mutated_variables)
y, some_updated_state = model.apply(params, x, mutable=[some_state_name])
7.2. Rematerialization
The rematerialization/checkpointing feature is supported in jax; see this doc
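A minimal sketch using jax.checkpoint (a.k.a. jax.remat): intermediates of g are recomputed during the backward pass instead of being stored:
import jax
import jax.numpy as jnp

@jax.checkpoint
def g(x):
    return jnp.sin(jnp.sin(x))

def f(x):
    return jnp.sum(g(x) ** 2)

jax.grad(f)(jnp.ones(3))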
8. optax
Optimizer usage example:
import optax

tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)