0x502 Frontend
- 1. Pytorch
- 1.1. API
- 1.2. Compile
- 1.3. Distribution
- 1.4. Frameworks
- 2. Tensorflow
- 2.1. API
- 2.2. Tracing
- 2.3. Transformation
- 2.4. Compile
- 2.5. Data
- 2.6. Distribution
- 2.7. Serialization
- 3. Jax
- 3.1. API
- 3.2. Tracing
- 3.3. Transformation
- 3.4. Compile
- 3.5. Distribution
- 3.6. flax
- 3.7. optax
Notes on the (high?)-level frontend APIs of several frameworks
1. Pytorch
1.1. API
1.1.1. Tensor
Some tensor manipulation snippets
Initialization
## int (torch.int)
torch.IntTensor([1,2,3])
torch.randint(low, high, (size1, size2, ...))
## long long (torch.long)
torch.LongTensor([1,2,3])
## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)
## double (torch.double)
torch.DoubleTensor([1,2,3])
Operation
# element-wise operation
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)
# (2,3,4)
(a*a).shape
# mm: matrix multiplication without broadcasting
# strictly has the form: (n,m)x(m,p) = (n,p)
# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))
## bmm: batched version of mm
## matmul: lots of cases...
# following examples are from pytorch docs
# when both operands have at most 2 dims, it is the normal matmul
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# matmul
# when either operand has >= 3 dims, the batch dims are broadcast to a common shape and a batched matrix multiplication is applied
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
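The notes list bmm without an example. Below is a minimal sketch; the batch size of 10 and the variable names are chosen only for illustration. Unlike matmul, torch.bmm expects two 3-D operands with the same batch size and does not broadcast.

```python
import torch

# bmm: strictly (b, n, m) x (b, m, p) = (b, n, p), no broadcasting
x = torch.randn(10, 3, 4)
y = torch.randn(10, 4, 5)
out = torch.bmm(x, y)
print(out.shape)

# on two 3-D operands, matmul computes the same batched product
same = torch.allclose(out, torch.matmul(x, y), atol=1e-6)
print(same)
```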

Masking
Masking is important. Without proper masking, attention will not work. Recall my own training failure...
Length Masking
Translation between mask and length
# length to mask
length = torch.LongTensor([3,5,4])
# tensor([[ True, True, True, False, False],
# [ True, True, True, True, True],
# [ True, True, True, True, False]])
mask = torch.arange(5).unsqueeze(0) < length.unsqueeze(1)
# to get reversed mask, use ~mask
# mask to length
length = torch.sum(mask, dim=1)
Apply mask
# in-place update
b = torch.randn(3,5)
mask = torch.tensor([[ True, True, True, False, False],
                     [ True, True, True, True, True],
                     [ True, True, True, True, False]])
b.masked_fill_(~mask, float('-inf'))
#tensor([[ 1.3503, -0.2101, 0.5982, -inf, -inf],
# [ 0.6641, -0.8612, 0.2389, -1.1343, 1.1640],
# [ 1.1004, -1.2243, -1.1901, -0.5063, -inf]])
Triangular Masking
for transformer decoder
>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
[ True, True, False],
[ True, True, True]])
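To connect the triangular mask to attention, here is a small sketch of how it might be applied before softmax. The all-zero `scores` tensor is a hypothetical stand-in for real query-key scores, used so the resulting weights are easy to read.

```python
import torch

length = 3
causal = torch.tril(torch.ones((length, length))).bool()

# hypothetical raw attention scores, shape (query, key)
scores = torch.zeros(length, length)

# future positions get -inf so that softmax assigns them zero weight
masked = scores.masked_fill(~causal, float("-inf"))
weights = torch.softmax(masked, dim=-1)
print(weights)
```

Each row of `weights` sums to 1 and only covers the current and earlier positions.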
Padding
To pad, see the following code
# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]
# tensor([[ 1, 2, 3],
# [ 3, 4, 0]])
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
# for more complicated case, use the following snippet
def pad_list(xs, pad_value, max_len=-1):
    """Perform padding for the list of tensors.
    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).
    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    if max_len == -1:
        max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad
To unpad
# get length (assumes 0 is used only as the padding value)
a = torch.tensor([[ 1, 2, 3],
                  [ 3, 4, 0]])
mask = (a != 0)
length = torch.sum(mask, dim=1)
1.1.2. Modules
C binding implementation is here
Source code is available here
It looks like each module object manages the following members:
- _modules: submodules
- _parameters: tensors with gradients, these are the ones updated by optim
- _buffers: tensors without gradients, they are not updated by optim but are stored in state_dict. Examples are mean and var in batchnorm
state_dict retrieves both parameters and buffers
When we assign a new member in __init__, the registration happens automatically because Module overrides __setattr__:
# a simplified version of setattr from pytorch
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # register a parameter
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        self.register_parameter(name, value)
    else:
        # register a module
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            modules[name] = value
These members are later looked up by __getattr__, which checks _parameters, _buffers and _modules explicitly.
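The registration idea can be reproduced without PyTorch. The following is a toy sketch, not PyTorch's actual implementation: `Parameter`, `Module`, `Linear` and `Net` are hypothetical stand-ins, and buffers are left out for brevity.

```python
class Parameter:
    """Stand-in for torch.nn.Parameter (hypothetical, holds a plain value)."""
    def __init__(self, value):
        self.value = value


class Module:
    """Toy module that mimics the registration idea, not PyTorch's real code."""
    def __init__(self):
        # bypass our own __setattr__ while creating the registries
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, Parameter):
            self._parameters[name] = value
        elif isinstance(value, Module):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # only called when normal lookup fails, so check the registries
        for registry in ("_parameters", "_modules"):
            table = self.__dict__.get(registry, {})
            if name in table:
                return table[name]
        raise AttributeError(name)


class Linear(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter([1.0, 2.0])


class Net(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear()
        self.name = "net"  # plain attribute, not registered


net = Net()
print(sorted(net._modules))
print(sorted(net.fc._parameters))
```

Because registered members never land in the instance `__dict__`, `net.fc` only resolves through `__getattr__`, which is the same reason a real module has to call `super().__init__()` before assigning submodules.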
1.2. Compile
Two components support this feature:
TorchDynamo: captures the graph structure (an fx.Graph) with dynamic Python bytecode transformation
TorchInductor: compiles the graph to machine code by leveraging Triton and OpenMP
1.3. Distribution
1.4. Frameworks
1.4.1. Datasets
Dataset provides the map-style dataset, which supports fast random access by using Apache Arrow as the in-memory columnar format, cached on disk.
from datasets import load_dataset
# Many raw datasets are converted into an Arrow cache during loading
data_files = {"train": ["path/to/data.csv"]}
my_dataset = load_dataset("csv", data_files=data_files, split="train")
# a transformation is applied to all data immediately and builds a new cache
my_dataset = my_dataset.map(process_fn)
IterableDataset is the iterable-style dataset: it loads data from disk lazily (without Arrow conversion) and applies transformations on the fly. See this doc
It can be created with load_dataset's streaming mode or from a generator
from datasets import IterableDataset, load_dataset
# streaming mode
imagenet = load_dataset("imagenet-1k", split="train", streaming=True) # data is only loaded once it is iterated over
for example in imagenet:
    print(example)
    break
# generator
def my_generator(n):
    for i in range(n):
        yield {"col_1": i}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs={"n": 10})
1.4.2. Transformers
2. Tensorflow
The execution path is roughly
2.1. API
2.1.1. Tensor
tf.Tensor is an immutable tensor. A number of specialized tensors are available: see tf.Variable, tf.constant, tf.placeholder (TF1 only, tf.compat.v1.placeholder in TF2) and tf.RaggedTensor
tf.Variable is a mutable tensor: it requires an initial value and maintains state during execution.
v = tf.Variable(1.)
v.assign(2.)
v.assign_add(0.5)
tf.Module is a named container for tf.Variables; it has variables, trainable_variables and submodules.
2.1.2. Operation (Node)
A tf.Operation is a node in a tf.Graph that takes zero or more Tensor objects as input and produces zero or more Tensor objects as output
See the list of raw_ops here
See this doc for control-flow ops implementation
2.2. Tracing
When a function decorated with @tf.function is executed (not when it is defined), it is traced into a tf.Graph. Shape inference also happens during tracing. The traced graph is cached.
Importantly, tf.Graph-level tracing allows unknown dimensions: passing something like tf.TensorSpec([1, None]) as input_signature prevents TF from retracing the graph when inputs with different concrete shapes are fed.
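One way to observe the caching is to record a Python side effect, since Python code in the function body only runs while tracing. This is a small sketch; `trace_log` and `double` are illustrative names, not part of any API.

```python
import tensorflow as tf

trace_log = []  # Python side effects run only while tracing

@tf.function(input_signature=[tf.TensorSpec([1, None], tf.float32)])
def double(t):
    trace_log.append(t.shape)  # records the symbolic shape seen by the tracer
    return t * 2.0

double(tf.constant([[1.0]]))            # first call: traces
double(tf.constant([[1.0, 2.0, 3.0]]))  # different width: reuses the trace
print(len(trace_log))
```

Without the input_signature, the second call would normally be traced again because its concrete shape differs from the first.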
2.2.1. TF1 Graph
A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation
In TensorFlow 1, a tf.Session stores the state of a graph (i.e. the values of the variables).
A session may own resources, such as tf.Variable, tf.queue.QueueBase and tf.compat.v1.ReaderBase
In TF1, the graph is serialized as a GraphDef protobuf (here is the doc) and saved using tf.compat.v1.train.Saver. It is stored inside the meta file
2.2.2. PolymorphicFunction
A tf.Graph is not built by default in eager mode; TF2 builds a graph by tracing a function decorated with tf.function, see tf.function doc and PolymorphicFunction doc
Basically, tf.function creates a PolymorphicFunction (or GenericFunction), which can encapsulate several tf.Graphs; additionally passing jit_compile=True triggers compilation
PolymorphicFunction: tf.function wraps a Python function and returns a PolymorphicFunction object. It manages a set of ConcreteFunctions and automatically picks the right one for your inputs. The interface can be fixed by passing tf.TensorSpec objects as input_signature
# f is a PolymorphicFunction
@tf.function
def f(x):
    return x + 1
# input_signature can be constrained with tf.TensorSpec
@tf.function(input_signature=[tf.TensorSpec([1, None])])
def constrained_foo(t):
    print("tracing...")
    return t
2.2.3. ConcreteFunction
ConcreteFunction: when a GenericFunction is fed a specific input, tracing creates a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.
A ConcreteFunction manages an AtomicFunction and the captured inputs. The AtomicFunction in turn contains the actual tf.Graph (FuncGraph), which is accessible as concrete_function.graph
The signature of a concrete function can be accessed through concrete_function.inputs and concrete_function.outputs. Both contain SymbolicTensors.
# f.get_concrete_function(1) returns a ConcreteFunction
concrete_function = f.get_concrete_function(tf.constant(1, dtype=tf.int32))
# inputs/outputs of concrete_function are SymbolicTensor (i.e. not eager tensor)
assert tf.is_symbolic_tensor(concrete_function.inputs[0])
assert tf.is_symbolic_tensor(concrete_function.outputs[0])
print(concrete_function.function_def)
# signature {
# name: "__inference_f_51110"
# input_arg {
# name: "x"
# type: DT_INT32
# }
# output_arg {
# name: "identity"
# type: DT_INT32
# }
# }
# node_def {
# name: "add/y"
# op: "Const"
# attr {
# key: "dtype"
# value {
# type: DT_INT32
# }
# }
# attr {
# key: "value"
# value {
# tensor {
# dtype: DT_INT32
# tensor_shape {
# }
# int_val: 1
# }
# }
# }
# }...
2.2.3.1. AtomicFunction (FuncGraph)
AtomicFunction wraps a FuncGraph in its cached graph, which is the actual graph of the ConcreteFunction. The inputs to an atomic function are typically the user inputs plus the captured inputs (e.g. variables)
AtomicFunction is callable and can be extracted via inference_fn
# AtomicFunction
concrete_function.inference_fn
# FuncGraph
graph = concrete_function.inference_fn.graph
# or simply
graph = concrete_function.graph
print(graph.as_graph_def())
# node {
# name: "x"
# op: "Placeholder"
# attr {
# key: "_user_specified_name"
# value {
# s: "x"
# }
# }
# attr {
# key: "dtype"
# value {
# type: DT_INT32
# }
# }
# attr {
# key: "shape"
# value {
# shape {
# }
# }
# }
# }
print(graph.operations)
# [<tf.Operation 'x' type=Placeholder>,
# <tf.Operation 'add/y' type=Const>,
# <tf.Operation 'add' type=AddV2>,
# <tf.Operation 'Identity' type=Identity>]
2.2.4. Captured Objects
Many objects are captured during tracing. concrete_function.captured_inputs are implicitly passed to the function as extra arguments on every call after tracing.
See the following example:
a = tf.Variable([3.14, 3.14]) # captured as ResourceHandle (i.e just pointer)
b = tf.constant(1.0) # captured as actual tensor
c = 2. # built into graph's Add node attribute directly
@tf.function
def f(x):
    print('tracing...')
    d = tf.constant(3.) # built into graph's Const node
    return x + a + b + c + d
#<tf.Tensor: shape=(2,), dtype=float32, numpy=array([9.14, 9.14], dtype=float32)>
f(0.) # trigger tracing, 9.14 = 3.14 + 1 + 2 + 3
# [<tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="Variable/14", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [2] ]")>>,
#  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
# a, b are captured as inputs here
f.get_concrete_function(0.).captured_inputs
# rebinding the Python names a, b or c causes neither retracing nor a change in the result, for example
b = tf.constant(2.0)
f(0.) # no retracing, result is still 9.14
f.get_concrete_function(0.).captured_inputs # 2nd captured input is still 1.0
# assigning to the variable changes the result without retracing, because the captured input for a variable is a ResourceHandle (pointer), not the actual value
a.assign([4.14, 4.14])
f(0.) # no retracing, 10.14 = 4.14 + 1 + 2 + 3
2.3. Transformation
2.3.1. Placer
2.3.2. Grappler
2.3.3. Partitioner
The partitioner splits the graph across devices and inserts communication primitives (e.g. send/recv) to transfer data.
The actual communication must implement the rendezvous interface. See this blog
2.4. Compile
2.4.1. StableHLO
A single traced ConcreteFunction (tf.Graph) can be compiled several times, once per concrete input shape: XLA needs static shapes, whereas a ConcreteFunction allows dynamic dimensions (None), for example the batch size.
Note that compilation can happen even without retracing. See the following example:
@tf.function(input_signature=[tf.TensorSpec([1, None])], jit_compile=True)
def f(t):
    print("tracing...")
    return tf.sin(t)
# tracing...
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,1]{1,0})->f32[1,1]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,]]))(stage='hlo')
# does not trigger tracing, but triggers compilation; notice the shape of the args has changed
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,2]{1,0})->f32[1,2]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='hlo')
Compiled results are cached as well
2.4.2. XLA Compile
StableHLO to native code
2.5. Data
2.5.1. Example proto
A single datapoint can be represented with the tf.train.Example proto, which serializes a single \((x,y)\) datapoint into a binary format.
An Example essentially represents the following structure:
Dict[str,
Union[List[bytes],
List[int64],
List[float]]]
It maps a feature name (string key) to a tf.train.Feature value, which can be a list of int64, float or bytes. The proto def can be seen here
message Example {
  Features features = 1;
}
message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
}
// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}
message BytesList {
  repeated bytes value = 1;
}
2.5.1.1. Native Conversion
This section lists a few Example-related conversions
To convert native/numpy types into tf.train.Feature, we can use the following snippet
# native to feature
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
feature = _int64_feature(1)
# feature to native
feature.int64_list.value
feature = {
'feature0': _int64_feature(feature0),
'feature1': _int64_feature(feature1),
'feature2': _bytes_feature(feature2),
'feature3': _float_feature(feature3),
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
2.5.2. tf.io package
Each special byte sequence should use its own encoder/decoder; the tf.io package provides these
Tensor conversion: to convert a tensor into a tf.train.Feature, we can do the following
# convert to bytelist tensor
t = tf.constant(1)
serialized_tensor = tf.io.serialize_tensor(t)
feature_of_bytes = tf.train.Feature(
bytes_list=tf.train.BytesList(value=[serialized_tensor.numpy()]))
# convert back, type needed to be provided and matched
tf.io.parse_tensor(serialized_tensor, tf.int32)
example conversion
tf.io.parse_example
2.5.3. tf.data
tf.data.Dataset is the API for input pipelines
2.5.3.1. Naive Dataset
Use from_tensors for a single datapoint and from_tensor_slices for multiple datapoints
# tf.data.Dataset.from_tensor_slices((X,y)) or tf.data.Dataset.from_tensor_slices(X)
dataset = tf.data.Dataset.from_tensor_slices(([np.array([1,2,3,4]), np.array([0,0,1,0])], np.array([1,0])))
# <_TensorSliceDataset element_spec=(TensorSpec(shape=(4,), dtype=tf.int64, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))>
print(dataset)
# (<tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 2, 3, 4])>, <tf.Tensor: shape=(), dtype=int64, numpy=1>)
#(<tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 0, 1, 0])>, <tf.Tensor: shape=(), dtype=int64, numpy=0>)
for elem in dataset:
    print(elem)
Use from_generator to consume a Python generator
def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes=())
2.5.3.2. Batching
Use batch or padded_batch
def gen_series():
    i = 0
    while True:
        size = np.random.randint(0, 9)  # lengths 0..8, so every sequence fits the padded length below
        yield i, np.random.normal(size=(size,))
        i += 1
ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))
# batch 2 datapoints, padding each sequence to length 8
# padded_shapes mirrors the element structure: () for the scalar index, (8,) for the sequence
# if each sample is a dict, padded_shapes can be a dict mapping each key to its max length, or [] for components without padding
ds_series_batch = ds_series.shuffle(20).padded_batch(2, padded_shapes=((), (8,)))
_, sequence_batch = next(iter(ds_series_batch))
print(sequence_batch.numpy())
# e.g. (values are random)
# [[ 0.0208  0.      0.      0.      0.      0.      0.      0.    ]
#  [-1.2131  0.523   1.083   0.3762 -1.1041 -1.6604 -2.3436  0.    ]]
2.5.3.3. Tokenize
A tokenizer from tensorflow_text can be used
import tensorflow_text as text
tokenizer = text.WhitespaceTokenizer()
dataset = dataset.map(lambda x: tokenizer.tokenize(x))
2.5.3.4. TFRecordDataset
An Example proto can be serialized to and deserialized from a byte string
# convert example to string
string_example = example_proto.SerializeToString()
# convert back
example_proto = tf.train.Example.FromString(string_example)
TFRecord is a file format that stores a sequence of byte strings; it can be used to store serialized Examples.
# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
    for i in range(n_observations):
        example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
        writer.write(example.numpy())
# Read
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
2.5.4. tfds
See doc here
tfds provides ready-to-use datasets
tfds.load is roughly equivalent to the following:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
2.6. Distribution
All tensorflow's visible devices can be retrieved with
# this typically includes CPU as well
logging.info("All TF devices: %s", tf.config.list_logical_devices())
2.7. Serialization
There are two related concepts:
- checkpoint: contains only the values of parameters, no graph is included. It is typically only useful when the source code that will use the saved parameter values is available.
- SavedModel: includes a serialized description of the computation at the tf.Graph level in addition to the parameter values (checkpoint)
2.7.1. Save
In TF2, a graph can be stored in the SavedModel format. In saved_model.pb, the graph is serialized as a MetaGraphDef protobuf (which contains GraphDef as a child). This in general encapsulates tf.Graph-level information (i.e. not the StableHLO level or the compiled machine code level)
SavedModel typically saves a Trackable object (typically a tf.Module). It will save
- signature (alias of ConcreteFunction): tf.functions decorated with an explicit input_signature are saved; otherwise the target signatures should be passed as an argument.
- checkpoint (recursive attributes of trackable objects): the saved tensor variables
class Adder(tf.Module):
    def __init__(self):
        super().__init__()
        self.weight = tf.Variable(3.14, dtype=tf.float32)
    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
    def __call__(self, x):
        return x + self.weight
adder = Adder()
tf.saved_model.save(adder, "/tmp/adder")
This will create the following files:
- saved_model.pb: a MetaGraphDef protobuf file, which contains GraphDef as a child. Typically a small file.
- variables: the checkpoint directory; it contains files such as variables.index and variables.data-00000-of-00001. See the next section
2.7.2. Checkpoint
Checkpoint typically has two files (ref: stackoverflow)
- index file: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which "data" file contains the content of the tensor, the offset into that file, a checksum, some auxiliary data, etc.
- data file: a TensorBundle collection that stores the values of all variables.
To inspect variables (ckpt) from a SavedModel
f = tf.saved_model.load("/tmp/adder")
# trackable attributes are accessible after loading
# this info is stored in MetaGraphDef.object_graph_def
# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>
f.weight
# it is also accessible from signature.variables
signature = f.signatures['serving_default']
# (<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>,)
print(signature.variables)
# variables can also be loaded with load_checkpoint
reader = tf.train.load_checkpoint("/tmp/adder/variables/variables")
# {'_CHECKPOINTABLE_OBJECT_GRAPH': tf.string, 'weight/.ATTRIBUTES/VARIABLE_VALUE': tf.float32}
dtype_from_key = reader.get_variable_to_dtype_map()
shape_from_key = reader.get_variable_to_shape_map()
key = 'weight/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
reader.get_tensor(key)
Note that only Trackable objects and attributes are saved in the checkpoint; non-trackable TF objects (e.g. tf.constant) are saved in concrete_function.captured_inputs, and native Python values are built directly into the graph
It is also possible to save/restore a checkpoint without using SavedModel
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
2.7.3. Signature (ConcreteFunction)
To load a SavedModel and use its concrete functions:
f = tf.saved_model.load("/tmp/adder")
# use its signature (concrete function)
saved_concrete_function = f.signatures['serving_default']
# <ConcreteFunction (*, x: TensorSpec(shape=(), dtype=tf.float32, name='x')) -> Dict[['output_0', TensorSpec(shape=(), dtype=tf.float32, name='output_0')]] at ...>
print(saved_concrete_function)
# {'output_0': <tf.Tensor: shape=(), dtype=float32, numpy=4.1400003>}
saved_concrete_function(tf.constant(1.0))
# error: concrete function cannot run over incompatible input_signature
saved_concrete_function(tf.constant(1))
To inspect nodes of a graph from a signature
graph = f.signatures['serving_default'].graph
for node in graph.as_graph_def().node:
print(node)
Note that the graph before saving and the graph after loading are not identical. Compare the following two graph_defs. The saved graph typically wraps most of the computation in a single raw_ops.PartitionedCall referring to some __inference___call___ function in the function library. See its doc
class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
    def __call__(self, x):
        return x + 3.14
# before saving
adder = Adder()
adder.__call__.get_concrete_function(tf.constant(1.0)).graph.as_graph_def()
# after saving
tf.saved_model.save(adder, "/tmp/adder")
saved_adder = tf.saved_model.load("/tmp/adder")
saved_adder_concrete_function = saved_adder.signatures['serving_default']
saved_adder_concrete_function.graph.as_graph_def()
Marking jit_compile=True produces the same graph before saving, but attaches the _XlaMustCompile=True attribute to PartitionedCall during saving. This will probably trigger JIT compilation after loading
3. Jax
Jax is basically a stack of compilers, it is based on the transformation and compilation of functional pure programs.
Here is a great blog comparing static/dynamic autodifferentiation library
The execution path is roughly
3.1. API
3.1.1. lax
lax is a library of primitive operations underlying jnp; it is a thin wrapper over XLA, and JAX knows how to transform each of its operations
See this section about lax control flow
3.1.2. primitive
We can define new primitives and implement their evaluation and transformation rules
from jax import core
multiply_add_p = core.Primitive("multiply_add") # Create the primitive
def multiply_add_prim(x, y, z):
    return multiply_add_p.bind(x, y, z)
3.1.2.1. eval
The primal evaluation should be implemented
import numpy as np
def multiply_add_impl(x, y, z):
    return np.add(np.multiply(x, y), z)
# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
To enable tracing, the abstract evaluation should be implemented to tell JAX how to infer the output shape and dtype
import jax
def multiply_add_abstract_eval(xs, ys, zs):
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
# with abstract evaluation, we can trace it correctly
# { lambda ; a:f32[] b:f32[] c:f32[]. let d:f32[] = multiply_add a b c in (d,) }
jax.make_jaxpr(multiply_add_prim)(2.0, 3.0, 10.)
3.1.2.2. differentiation
define jvp and transpose
3.1.3. jnp
similar to the numpy syntax
3.1.3.1. packing and segments
Sometimes we want to pack sequences to use the TPU efficiently.
Here is an example illustrating packing from the tensor2tensor library. It concatenates sentences 'horizontally' and tracks each sentence with a segment id.
It can be combined with batching as well (batching concatenates padded sequences 'vertically')
Two input examples get combined to form an output example.
The input examples are:
{"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
{"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
The output example is:
{
"inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
"inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
"inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
"targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
"targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
"targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
}
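The packing above can be reproduced with a few lines of plain Python. This is only a sketch of the idea, not the tensor2tensor implementation: the helper name `pack` and its `packed_len` parameter are made up for illustration, and it assumes the pad id 0 never appears inside a real sequence.

```python
def pack(sequences, packed_len, pad_id=0):
    """Concatenate sequences into one row, tracking segment ids and positions."""
    tokens, segments, positions = [], [], []
    for seg_id, seq in enumerate(sequences, start=1):
        # drop padding tokens (assumes pad_id never occurs inside a sequence)
        seq = [t for t in seq if t != pad_id]
        tokens += seq
        segments += [seg_id] * len(seq)
        positions += list(range(len(seq)))
    n_pad = packed_len - len(tokens)
    if n_pad < 0:
        raise ValueError("sequences do not fit into packed_len")
    return (tokens + [pad_id] * n_pad,
            segments + [0] * n_pad,
            positions + [0] * n_pad)


inputs, inputs_segmentation, inputs_position = pack(
    [[8, 7, 1, 0], [2, 3, 4, 1]], packed_len=10)
targets, targets_segmentation, targets_position = pack(
    [[4, 1, 0], [5, 6, 1]], packed_len=10)
print(inputs, inputs_segmentation, inputs_position)
```

Segment id 0 marks padding, so a mask built from the segmentation can exclude both padding and cross-sentence attention.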
3.1.3.2. paddings and masking
some paddings related code snippets
# create padding from ids
# seq = [1,2,3,0,0]
# padding = [0, 0, 0, 1, 1]
jnp.equal(seq, 0)
# apply paddings
masked_features = features * (1.0-paddings[:, :, jnp.newaxis])
# padding to length
seq_lengths = jnp.sum(1.0 - paddings, axis=-1)
# length to paddings
# DeviceArray([[ True, True, True, True],
# [ True, True, False, False],
# [ True, False, False, False]], dtype=bool)
lengths = jnp.array([4,2,1])
col_ids = jnp.arange(4)[jnp.newaxis, :]
col_ids < lengths[:, jnp.newaxis]
# Look-ahead mask or causal mask
def create_look_ahead_mask(size: int) -> np.ndarray:
    mask = np.ones([size, size])
    mask[np.tril_indices(size)] = 0.0
    return mask # (seq_len, seq_len)
# (array([0, 1, 1]), array([0, 0, 1]))
np.tril_indices(2)
#array([[0., 1., 1.],
# [0., 0., 1.],
# [0., 0., 0.]])
create_look_ahead_mask(3)
def another_look_ahead_mask(size: int):
    # causal mask using col/row id comparison
    col_idx = jnp.tile(jnp.arange(size)[jnp.newaxis, :], [size, 1])
    row_idx = jnp.tile(jnp.arange(size)[:, jnp.newaxis], [1, size])
    mask = (row_idx < col_idx)
    return mask
# packed segment masking
# [B,T,1]
fst_segment_ids = segment_ids[:, :, jnp.newaxis]
# [B,1,T]
snd_segment_ids = segment_ids[:, jnp.newaxis, :]
jnp.not_equal(fst_segment_ids, snd_segment_ids)
3.2. Tracing
Different transformations require different levels of abstraction. They may run into issues depending on the control flow, see doc here
The target function should take and return only arrays, scalars or nested containers of them. Other values (e.g. str) raise an error during tracing
Note that tracing can happen just by using jnp, as some jnp functions are implemented with jit, for example,
x = jnp.zeros((3,3))
# WARNING:jax._src.dispatch:Finished tracing + transforming broadcast_in_dim for pjit in 0.003103494644165039 sec
# WARNING:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
#WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.005272388458251953 sec
#WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.016413211822509766 sec
3.2.1. Tracing Abstraction
By default, tracing is done at the ShapedArray abstraction level, which can be problematic when hitting control flow
Using static_argnums and static_argnames forces the tracer to trace with concrete values instead of ShapedArray;
they are treated as compile-time constants, which are constant-folded during tracing.
Changing these statics triggers recompilation because it changes the signature. Whether statics have changed is decided by __hash__ and __eq__ (statics must be hashable)
import functools
import jax
import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=1)
def f(x, a):
    # during tracing, this shows 1 as it is static
    print(a)
    return jnp.sin(x)
f(jnp.zeros(1), 1) # this will compile
f(jnp.zeros(1), 2) # this will compile again as it is traced with the actual number and the signature changes
@jax.jit
def g(x, a):
    # during tracing, this shows Traced<ShapedArray(int32[], weak_type=True)>
    print(a)
    return jnp.sin(x)
g(jnp.zeros(1), 1) # this will compile
g(jnp.zeros(1), 2) # this will not re-compile as it is traced with ShapedArray(int32[]) and its signature does not change
3.2.2. jaxpr
A jaxpr is a sequence of primitive equations obtained by tracing; the tracing process can be inspected using make_jaxpr
def examine_jaxpr(closed_jaxpr):
    jaxpr = closed_jaxpr.jaxpr
    print("invars:", jaxpr.invars)
    print("outvars:", jaxpr.outvars)
    print("constvars:", jaxpr.constvars)
    for eqn in jaxpr.eqns:
        print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    print()
    print("jaxpr:", jaxpr)
def foo(x):
    return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))
#foo
#=====
#invars: [a]
#outvars: [b]
#constvars: []
#equation: [a, 1] add [b] {}
#jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
3.3. Transformation
See this doc to understand how to add a new transformation
Many JAX transformations are probably implemented on top of jaxpr: each (mostly) transforms a jaxpr into another jaxpr and works as an interpreter.
3.3.1. grad
JAX uses jvp to obtain a forward (tangent) jaxpr, then applies transpose to each equation of the forward jaxpr to obtain the gradient
Consider the function \(f(x,y) = xy+y\) evaluated at \(x=2, y=4\) (the constants 2. and 4. below come from this point). During grad, jvp produces the following expressions, where xt, yt are traced abstractly
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
transpose then interprets them backwards and produces the following expressions:
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
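The cotangent program can be executed by hand to check that it yields the gradient. The sketch below is plain Python rather than JAX; it assumes the linearization point \(x=2, y=4\), which is where the constants 2. and 4. in the tangent expressions come from, and compares the result with central finite differences.

```python
def f(x, y):
    return x * y + y


def transposed_program(fct, x=2.0, y=4.0):
    """Run the cotangent equations above; x, y are the linearization point."""
    xct = yct = act = bct = cct = 0.0
    cct += fct          # from "ft = c + yt"
    yct += fct
    act += cct          # from "c = a + b"
    bct += cct
    yct += x * bct      # from "b = 2. * yt", with x = 2
    xct += act * y      # from "a = xt * 4.", with y = 4
    return xct, yct


xct, yct = transposed_program(1.0)

# central finite differences as an independent check
eps = 1e-6
fd_x = (f(2.0 + eps, 4.0) - f(2.0 - eps, 4.0)) / (2 * eps)
fd_y = (f(2.0, 4.0 + eps) - f(2.0, 4.0 - eps)) / (2 * eps)
print(xct, yct)  # 4.0 3.0
```

This matches the analytic gradient \(\partial f/\partial x = y = 4\) and \(\partial f/\partial y = x + 1 = 3\).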
3.4. Compile
3.4.1. lower
A jaxpr can be lowered into HLO (XLA's IR). Notice that this is only translation, not optimization.
@jax.jit
def one_plus_one():
    a = jnp.ones(1)
    b = jnp.ones(1)
    return a + b
print(one_plus_one.lower().as_text())
# module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
# func.func public @main() -> (tensor<1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
# %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
# %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1xf32>
# %2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
# %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1xf32>
# %4 = stablehlo.add %1, %3 : tensor<1xf32>
# return %4 : tensor<1xf32>
# }
# }
Lowering is achieved by implementing a lowering rule for each primitive
def multiply_add_lowering(ctx, xc, yc, zc):
    """The compilation to XLA of the primitive.
    Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
    the results of the function.
    Does not need to be a JAX-traceable function.
    """
    return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
3.4.2. compile
HLO can be further lowered into an executable by compiling it with XLA. Most of the optimizations happen at this stage.
compiled = jax.jit(lambda x: x + 2).lower(3).compile()
# notice that the constant 1 + 1 is folded into 2 directly
print(one_plus_one.lower().compile().as_text())
# HloModule jit_f, is_scheduled=true, entry_computation_layout={()->f32[1]{0:T(256)}}, allow_spmd_sharding_propagation_to_output={true}
# ENTRY %main.2 () -> f32[1] {
# %constant.1 = f32[1]{0:T(256)} constant({2})
# ROOT %copy.1 = f32[1]{0:T(256)} copy(f32[1]{0:T(256)} %constant.1)
# }
3.4.3. jit
The jit decorator combines all the previous steps: tracing, transforming, lowering and compiling.
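The stages can also be run one at a time. This is a minimal sketch, where f and x are made-up examples, that walks through trace, lower and compile explicitly and compares the result with a plain jit call.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function and input.
def f(x):
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.arange(4.0)

# 1. tracing: obtain the jaxpr
jaxpr = jax.make_jaxpr(f)(x)
# 2. lowering: translate the traced function into XLA's input IR
lowered = jax.jit(f).lower(x)
lowered_text = lowered.as_text()
# 3. compiling: let XLA optimize and build an executable
compiled = lowered.compile()
# the compiled object is callable with arguments of the same shape/dtype
out_staged = compiled(x)

# jit performs all of the steps above on the first call
out_jit = jax.jit(f)(x)
```

The staged path is mostly useful for inspection (printing the IR before and after XLA's optimizations); in regular code, calling the jitted function directly is enough.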
3.4.4. jax2tf
See jax2tf doc
from jax.experimental import jax2tf
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
def f_jax(x):
    return jnp.sin(jnp.cos(x))
# jax2tf.convert is a higher-order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
f_tf = jax2tf.convert(f_jax)
# For example you execute f_tf eagerly with valid TensorFlow inputs:
f_tf(np.random.random(...))
# Additionally you can use tools like `tf.function` to improve the execution
# time of your function, or to stage it out to a SavedModel:
f_tf_graph = tf.function(f_tf, autograph=False)
# this triggers the jaxpr tracing, builds the StableHLO and encapsulates it into tf.XlaCallModule ops. All StableHLO information is serialized into a string stored as a module attribute of the XlaCallModule node.
f_tf_graph.get_concrete_function(tf.constant(1.0)).graph.as_graph_def()
3.5. Distribution
3.5.1. Sharding
See jax distribution doc, jax.sharding doc, also check this kaggle blog
There are a few Sharding classes, each typically an XLACompatibleSharding. A Sharding only specifies how data gets sharded; it does not represent the sharded data itself.
SingleDeviceSharding: a sharding that places its data on a single device (the default sharding)
# an array that has not been explicitly placed gets this default sharding
# it is placed on one of the devices
x = jnp.zeros((8,16))
# SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
x.sharding
NamedSharding is a pair of a Mesh and a PartitionSpec; it can be created manually as follows:
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
P = PartitionSpec  # short alias used in the snippets below
# create a mesh (assumes 8 devices, arranged as 2 x 4 to match the output below)
devices = np.array(jax.devices()).reshape(2, 4)
# Declare a 2D mesh with axes `x` and `y`.
mesh = Mesh(devices, ('x', 'y'))
# a PartitionSpec describes how each input dim is sharded across the mesh axes
# e.g. PartitionSpec('x', 'y') says that the first dimension of data is sharded across x axis of the mesh, and the second dimension is sharded across y axis of the mesh
# it has at most one entry per dimension of the array it shards (None, or an omitted trailing entry, leaves that dimension unsharded)
spec = PartitionSpec('x', 'y')
# combine mesh and partitionspec together into a namedsharding
named_sharding = jax.sharding.NamedSharding(mesh, spec)
# shard the data and visualize
data = jnp.arange(8).reshape(2,4)
sharded_data = jax.device_put(data, named_sharding)
# this sharding info can be found in the array's attribute
print(sharded_data.sharding)
jax.debug.visualize_array_sharding(sharded_data)
┌──────────┬──────────┬──────────┬──────────┐
│ │ │ │ │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │
│ │ │ │ │
│ │ │ │ │
├──────────┼──────────┼──────────┼──────────┤
│ │ │ │ │
│ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
│ │ │ │ │
│ │ │ │ │
└──────────┴──────────┴──────────┴──────────┘
# another sharding
named_sharding = jax.sharding.NamedSharding(mesh, P('x', None))
sharded_data = jax.device_put(data, named_sharding)
jax.debug.visualize_array_sharding(sharded_data)
┌────────────────────────────────────────────────┐
│ │
│ TPU 0,1,2,3 │
│ │
│ │
├────────────────────────────────────────────────┤
│ │
│ TPU 4,5,6,7 │
│ │
│ │
└────────────────────────────────────────────────┘
To control the device placement,
- for input data, use jax.device_put to actually shard data with a given Sharding
- for intermediate results within a jit-decorated function, use jax.lax.with_sharding_constraint to force the placement
mesh = Mesh(np.array(jax.devices()).reshape(2, 1), ('x', 'y'))
spec = P('x', 'y')
named_sharding = jax.sharding.NamedSharding(mesh, spec)
data = jnp.array([[1,2],[3,4]])
# Array([[1, 2],[3, 4]], dtype=int32)
sharded_data = jax.device_put(data, named_sharding)
# NamedSharding(mesh=Mesh('x': 2, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=tpu_hbm)
print(sharded_data.sharding)
# resharding using another spec
newspec = P('y', 'x')
new_sharding = jax.sharding.NamedSharding(mesh, newspec)
new_sharded_data = jax.device_put(sharded_data, new_sharding)
# resharding inside a jit-decorated function
@jax.jit
def f(x):
    x = x + 1
    y = jax.lax.with_sharding_constraint(x, new_sharding)
    return y
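The two placement mechanisms can be put together in one self-contained sketch. It makes a few assumptions so that it runs on any device count (including a single CPU): the mesh is built as (n_devices, 1), the array has 2 rows per device, and the names row_sharding, replicated and add_one are made up for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build an (n, 1) mesh from whatever devices are available.
devices = np.array(jax.devices()).reshape(-1, 1)
n = devices.shape[0]
mesh = Mesh(devices, ('x', 'y'))

# Rows split across mesh axis 'x'; columns left unsharded.
row_sharding = NamedSharding(mesh, P('x', None))
# Fully replicated: no dimension is sharded.
replicated = NamedSharding(mesh, P(None, None))

# Input placement: device_put shards the data according to the Sharding.
data = jnp.arange(n * 2 * 4).reshape(n * 2, 4)
sharded = jax.device_put(data, row_sharding)

# Intermediate placement: constrain a value inside a jitted function.
@jax.jit
def add_one(x):
    x = x + 1
    return jax.lax.with_sharding_constraint(x, replicated)

out = add_one(sharded)
print(sharded.addressable_shards[0].data.shape)
```

Each device holds a (2, 4) slice of the input, while the values of the result are unchanged by the sharding constraint; only the layout across devices differs.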
3.5.2. vmap
vmap is the vectorizing map, for example
import jax.numpy as jnp
from jax import vmap
vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis)
mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
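As a quick check that the nested vmap really builds a matrix product out of a vector dot product, here is a minimal sketch with made-up shapes (b=3, a=4, c=5) that compares it against the @ operator.

```python
import jax
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)        # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)        # ([b,a], [a,c]) -> [b,c]

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (3, 4))
B = jax.random.normal(key_b, (4, 5))

mv_out = mv(A, B[:, 0])  # matrix-vector product, shape (3,)
mm_out = mm(A, B)        # matrix-matrix product, shape (3, 5)
```

The in_axes tuple says which axis of each argument is mapped (None means the argument is passed whole to every call), and out_axes says where the mapped axis appears in the result.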
3.5.3. pmap
pmap is the user-driven form of SPMD programming.
Applying pmap to a function compiles the function with XLA, then replicates it and executes each replica on its own XLA device in parallel.
# the size of the mapped (split) axis should not be larger than n_devices
n_devices = jax.local_device_count()
# create a sharded array
a = jax.pmap(lambda x: x**2)(jnp.arange(n_devices))
# shardingspec
# PmapSharding(sharding_spec=ShardingSpec((Unstacked(2),), (ShardedAxis(axis=0),)), device_ids=[0, 1], device_platform=TPU, device_shape=(2,))
print(a.sharding)
# get info of each shard
jax.debug.visualize_array_sharding(a)
for i, shard in enumerate(a.global_shards):
    print(f"\nShard no: {i:>5}")
    print(f"Device: {str(shard.device):>32}")
    print(f"Data shape: {str(shard.data.shape):>8}")
    print(f"Data slices: {str(shard.index):>22}\n")
    print("="*75)
    print("")
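Because every replica runs the same program, cross-device communication in pmap is written explicitly with collectives. This is a minimal sketch, assuming an axis name 'i' and a made-up normalization example, that uses jax.lax.psum; it also runs on a single device.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One element per device: [1., 2., ..., n_devices]
x = jnp.arange(1.0, n_devices + 1.0)

# Each replica sees one element v; psum adds v over all replicas along axis 'i'.
normalize = jax.pmap(lambda v: v / jax.lax.psum(v, axis_name='i'), axis_name='i')

out = normalize(x)
print(out)
```

Each replica divides its own element by the global sum, so the outputs add up to 1 across devices.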
3.5.4. jit (pjit)
jit is the automatic SPMD programming model, where internal sharding is automatically propagated based on the input/output sharding
The syntax is as follows:
pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue,...)
- in_shardings: optional, as jax will infer the sharding from the input; when inference cannot be done, it defaults to replication (i.e. None)
- out_shardings: also optional; when not specified, the sharding is inferred by GSPMD's sharding propagation.
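A minimal sketch of passing both arguments explicitly to jax.jit. The 1D mesh, the function row_sums and the sharding names are made up for illustration, and the array is sized so the example runs on any device count.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ('x',))
n = len(jax.devices())

in_rows = NamedSharding(mesh, P('x', None))  # 2D input: rows split over 'x'
out_vec = NamedSharding(mesh, P('x'))        # 1D output: split over 'x'

def row_sums(a):
    return jnp.sum(a * 2, axis=1)

# Shardings of the intermediate values are left to the compiler.
row_sums_jit = jax.jit(row_sums, in_shardings=(in_rows,), out_shardings=out_vec)

data = jax.device_put(jnp.arange(n * 2 * 4).reshape(n * 2, 4), in_rows)
out = row_sums_jit(data)
print(out.sharding)
```

Leaving both arguments out would give the same values; they only pin down how the input and output are laid out across devices.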
3.6. flax
Since jax transformations are functional, linen is designed around stateless functional transforms of the form \(f(v_{in}, x) \rightarrow v_{out}, y\), where \(v_{in}, v_{out}\) are variable collections and \(x, y\) are the input and output data.
3.6.1. linen
flax.linen.Module is a dataclass-like class: it only has hyperparameters as members (i.e. parameters are not stored inside)
from flax import linen as nn
from typing import Tuple
class Module(nn.Module):
    # this hyperparameter becomes an instance attribute, available through self.features immediately
    features: Tuple[int, ...] = (16, 4)
    # setup is not called when the object is constructed; it is called to bind the scope EVERY TIME init or apply is called
    def setup(self):
        # submodules can only be accessed inside init/apply, not directly from the unbound object
        self.dense1 = nn.Dense(self.features[0])
        self.dense2 = nn.Dense(self.features[1])
    # __call__ runs after the module has been bound via setup
    def __call__(self, x):
        return self.dense2(nn.relu(self.dense1(x)))
variables is a container of "variable collections"
{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ...}
  }
}
param is a special class of variable. It is immutable. See this post
# this will retrieve the value (if it exists), or initialize it and return it (if it does not)
p = self.param('param_name', init_fn, shape, dtype)
# is a convenient shorthand for this:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value
init/apply
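init initializes the variables of a model from a PRNG key and example inputs, and returns a frozendict of params. When init is jitted, the model is initialized lazily using only the shapes of the provided arguments, which avoids computing the forward pass with actual values.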
See the lifecycle doc for more details
from jax import random
model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input
# init
params = model.init(key2, x)
# apply
y = model.apply(params, x)
# apply with mutable states
y, some_updated_state = model.apply(params, x, mutable=[some_state_name])
3.6.2. serialization
usage example
from flax import serialization
# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
# load, note that a template (here params) is needed to reconstruct the structure from bytes
restored_params = serialization.from_bytes(params, bytes_output)
3.6.3. Rematerialization
The rematerialization/checkpointing feature is supported in jax, see this doc
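On the jax side, the entry point is jax.checkpoint (also exposed as jax.remat). This is a minimal sketch with a made-up layer function, showing that wrapping a function changes what is stored for the backward pass but not the gradient values.

```python
import jax
import jax.numpy as jnp

# Hypothetical "layer" used only for illustration.
def layer(x):
    return jnp.sin(jnp.sin(x))

def loss(x):
    return jnp.sum(layer(layer(x)))

def loss_remat(x):
    # Intermediates inside each checkpointed layer are recomputed during the
    # backward pass instead of being kept in memory.
    ckpt_layer = jax.checkpoint(layer)
    return jnp.sum(ckpt_layer(ckpt_layer(x)))

x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(loss)(x)
g_remat = jax.grad(loss_remat)(x)
```

The trade-off is extra compute in the backward pass in exchange for lower peak memory.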
3.7. optax
Optimizer
usage example
import optax
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)