
0x502 Frontend

Notes on (high?)-level frontend APIs for several frameworks

1. PyTorch

1.1. API

1.1.1. Tensor

Some tensor manipulation snippets

Initialization

import torch

## int (torch.int)
torch.IntTensor([1,2,3])
torch.randint(0, 10, (2, 3), dtype=torch.int)  # values in [0, 10); randint defaults to torch.long (int64)

## long long (torch.long)
torch.LongTensor([1,2,3])

## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)


## double (torch.double)
torch.DoubleTensor([1,2,3])

Operation

# element-wise operation 
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)

# (2,3,4)
(a*a).shape

# mm: matrix multiplication without broadcasting
# strictly has the form: (n,m)x(m,p) = (n,p)

# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))

## bmm: batched version of mm

## matmul: lots of cases...
# following examples are from pytorch docs

# when both operands are at most 2-D, it is the ordinary product (dot, matrix-vector or matrix-matrix)
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])


# matmul
# when either operand has more than 2 dims, the batch (non-matrix) dims are broadcast and a batched matrix multiplication is applied

>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
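
A minimal sketch contrasting torch.bmm with torch.matmul (the shapes are arbitrary): bmm only accepts two 3-D tensors with the same batch size and never broadcasts, while matmul also broadcasts a 2-D operand over the batch.

```python
import torch

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)

# bmm: strictly (B, n, m) x (B, m, p) -> (B, n, p), no broadcasting
out_bmm = torch.bmm(a, b)
print(out_bmm.shape)  # torch.Size([10, 3, 5])

# matmul computes the same thing for this 3-D x 3-D case
out_matmul = torch.matmul(a, b)
print(torch.allclose(out_bmm, out_matmul, atol=1e-5))  # True

# matmul broadcasts a 2-D operand over the batch dim; bmm would raise a RuntimeError here
w = torch.randn(4, 5)
print(torch.matmul(a, w).shape)  # torch.Size([10, 3, 5])
```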

Masking

Masking is important: without proper masking, attention will not work. Recall my own training failure...

Length Masking

Translation between mask and length

# length to mask
length = torch.LongTensor([3,5,4])


# tensor([[ True,  True,  True, False, False],
#        [ True,  True,  True,  True,  True],
#        [ True,  True,  True,  True, False]])
mask = torch.arange(5).unsqueeze(0) < length.unsqueeze(1)

# to get reversed mask, use ~mask

# mask to length
length = torch.sum(mask, dim=1)

Apply mask

# in-place update: fill the masked-out (padding) positions with -inf, e.g. before a softmax
b = torch.randn(3,5)
mask = torch.tensor([[ True,  True,  True, False, False],
                     [ True,  True,  True,  True,  True],
                     [ True,  True,  True,  True, False]])

b.masked_fill_(~mask, float('-inf'))
#tensor([[ 1.3503, -0.2101,  0.5982,    -inf,    -inf],
#        [ 0.6641, -0.8612,  0.2389, -1.1343,  1.1640],
#        [ 1.1004, -1.2243, -1.1901, -0.5063,    -inf]])

Triangular Masking

for the transformer decoder (causal mask): position i may only attend to positions j <= i

>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
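
Putting the length mask and the triangular mask together, here is a sketch of how they might be applied to attention logits before the softmax. The function name and the (B, T, T) layout are assumptions for illustration; it also assumes every length is at least 1, since a fully masked row would make softmax return NaN.

```python
import torch

def masked_attention_weights(scores: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper. scores: (B, T, T) raw logits (query x key); length: (B,) valid lengths."""
    T = scores.shape[-1]
    # padding mask over keys: (B, 1, T), True = valid key position
    key_mask = (torch.arange(T).unsqueeze(0) < length.unsqueeze(1)).unsqueeze(1)
    # causal mask: (1, T, T), True = this query may attend to this key
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).unsqueeze(0)
    mask = key_mask & causal  # broadcasts to (B, T, T)
    # out-of-place masked_fill: -inf where attention is not allowed
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1)

length = torch.LongTensor([3, 5, 4])
weights = masked_attention_weights(torch.zeros(3, 5, 5), length)
print(weights[0])  # row i spreads uniformly over keys 0..min(i, 2)
```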

Padding

To pad, see the following code

# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]

# tensor([[ 1,  2,  3],
#         [ 3,  4,  0]])
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)

# for more complicated case, use the following snippet
def pad_list(xs, pad_value, max_len=-1):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
        max_len (int): Padded length; -1 means the length of the longest tensor in xs.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)

    if max_len == -1:
        max_len = max(x.size(0) for x in xs)

    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad

To unpad

# get length (assumes 0 only ever appears as padding)
a = torch.tensor([[ 1,  2,  3],
                  [ 3,  4,  0]])

mask = (a != 0)
length = torch.sum(mask, dim=1)
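
Going all the way back to a list of unpadded tensors is then one slice per row. A sketch, assuming 0 is only ever used as the padding value and that padding is trailing:

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [3, 4, 0]])

length = (a != 0).sum(dim=1)  # tensor([3, 2])

# unpad: slice each row back to its true length
unpadded = [row[:n] for row, n in zip(a, length.tolist())]
# [tensor([1, 2, 3]), tensor([3, 4])]

# equivalent: take the valid entries (row-major order) and split them by length
unpadded_2 = torch.split(a[a != 0], length.tolist())
```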

1.1.2. Modules

C binding implementation is here

Source code is available here

It looks like each module object manages the following members:

  • _modules: submodules
  • _parameters: tensors that require gradients; these are the ones updated by the optimizer (torch.optim)
  • _buffers: tensors without gradients; they are not updated by the optimizer but are stored in state_dict (unless registered with persistent=False). Examples are running_mean and running_var in BatchNorm

state_dict retrieves both parameters and buffers
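
A small sketch of where each kind of member ends up, using a hypothetical toy module (RunningNorm is made up for illustration; the registration behavior is standard nn.Module):

```python
import torch
from torch import nn

class RunningNorm(nn.Module):
    """Hypothetical toy module: one parameter, one buffer, one submodule."""
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))               # -> _parameters
        self.register_buffer("running_mean", torch.zeros(dim))   # -> _buffers
        self.proj = nn.Linear(dim, dim)                          # -> _modules

m = RunningNorm(4)
print(list(m._parameters))  # ['scale']
print(list(m._buffers))     # ['running_mean']
print(list(m._modules))     # ['proj']

# only parameters (collected recursively) are handed to the optimizer
print([name for name, _ in m.named_parameters()])
# ['scale', 'proj.weight', 'proj.bias']

# state_dict holds parameters and buffers
print(list(m.state_dict()))
# ['scale', 'running_mean', 'proj.weight', 'proj.bias']
```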

When a new member is assigned in __init__, it is registered automatically because Module overrides __setattr__:

    # a simplified version of setattr from pytorch
    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

        # register a parameter
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            self.register_parameter(name, value)
        else:

            # register a module
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                modules[name] = value

These members are later retrieved by __getattr__, which explicitly looks them up in _parameters, _buffers and _modules (Python only calls __getattr__ when normal attribute lookup fails).
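
To make the mechanism concrete, here is a toy pure-Python imitation of this registration pattern. It is not PyTorch's code (Param, TinyModule, Linear and Net are hypothetical names); it only mirrors the idea that __setattr__ routes values into registries and __getattr__ reads them back.

```python
class Param:
    """Hypothetical stand-in for torch.nn.Parameter."""
    def __init__(self, value):
        self.value = value

class TinyModule:
    """Toy imitation of nn.Module attribute registration (not PyTorch's implementation)."""
    def __init__(self):
        # bypass our own __setattr__ while creating the registries
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, Param):
            self._parameters[name] = value
        elif isinstance(value, TinyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)  # plain attribute in __dict__

    def __getattr__(self, name):
        # only reached when normal lookup (instance __dict__, class) fails
        for registry in ("_parameters", "_modules"):
            entries = self.__dict__[registry]
            if name in entries:
                return entries[name]
        raise AttributeError(name)

    def named_parameters(self, prefix=""):
        for name, p in self._parameters.items():
            yield prefix + name, p
        for name, sub in self._modules.items():
            yield from sub.named_parameters(prefix + name + ".")

class Linear(TinyModule):
    def __init__(self):
        super().__init__()
        self.weight = Param(1.0)
        self.note = "plain attribute"

class Net(TinyModule):
    def __init__(self):
        super().__init__()
        self.fc = Linear()
        self.bias = Param(0.0)

net = Net()
print([name for name, _ in net.named_parameters()])  # ['bias', 'fc.weight']
print(net.fc.weight.value)  # 1.0, resolved through __getattr__
```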

1.2. Compile

Two libraries support this feature (torch.compile uses both by default)

TorchDynamo: captures the graph structure as an fx.Graph via dynamic Python bytecode transformation

TorchInductor: compiles the graph to machine code by leveraging Triton and OpenMP
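
One way to look at what TorchDynamo captures is to pass torch.compile a custom backend, which is simply a callable receiving the fx.GraphModule plus example inputs and returning a callable. A sketch, assuming PyTorch 2.x on a Python version that torch.compile supports; inspect_backend is a hypothetical name, and with the default backend the graph would instead be handed to TorchInductor.

```python
import torch

captured = []

def inspect_backend(gm, example_inputs):
    """Custom torch.compile backend; gm is the torch.fx.GraphModule captured by TorchDynamo."""
    captured.append(gm)
    print(gm.graph)    # the captured fx.Graph
    return gm.forward  # run the graph as-is, without Inductor code generation

def f(x):
    return torch.sin(x) + x * 2

compiled_f = torch.compile(f, backend=inspect_backend)
x = torch.randn(3)
out = compiled_f(x)  # first call: Dynamo analyzes the bytecode and calls the backend
```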

1.3. Distribution

1.4. Frameworks

1.4.1. Datasets

Dataset provides the map-style dataset, which supports fast random access by using Apache Arrow as its in-memory columnar format, cached on disk.

from datasets import load_dataset

# Many raw datasets are converted into an Arrow cache during loading
data_files = {"train": ["path/to/data.csv"]}
my_dataset = load_dataset("csv", data_files=data_files, split="train")

# map processes all rows immediately and builds a new cache
my_dataset = my_dataset.map(process_fn)

IterableDataset is the iterable dataset: it loads data from disk lazily (without Arrow conversion) and transformations are applied on the fly. See this doc

It can be created with load_dataset's streaming mode or from a generator

# streaming mode
imagenet = load_dataset("imagenet-1k", split="train", streaming=True)  # will start loading the data when iterated over
for example in imagenet:
    print(example)
    break


from datasets import IterableDataset

def my_generator(n):
    for i in range(n):
        yield {"col_1": i}

my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs={"n": 10})

1.4.2. Transformers

2. Tensorflow

The execution path is roughly

\[\text{python} \xrightarrow{\text{trace}} \text{tf.Graph} \xrightarrow{\text{jit compile}} \text{StableHLO} \xrightarrow{\text{jit compile}} \text{native}\]

2.1. API

2.1.1. Tensor

tf.Tensor is an immutable tensor. A number of specialized tensor types are available: see tf.Variable, tf.constant, tf.compat.v1.placeholder (TF1 only) and tf.RaggedTensor

tf.Variable is a mutable tensor: it requires an initial value and maintains its state across executions.

v = tf.Variable(1.)
v.assign(2.)
v.assign_add(0.5)

tf.Module is a named container for tf.Variables; it exposes variables, trainable_variables and submodules, all collected recursively.
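
A sketch of how tf.Module collects variables recursively; Dense and MLP are hypothetical toy modules:

```python
import tensorflow as tf

class Dense(tf.Module):
    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.ones([in_dim, out_dim]), name="w")
        self.b = tf.Variable(tf.zeros([out_dim]), trainable=False, name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

class MLP(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.layer = Dense(2, 3)

mlp = MLP()
print(len(mlp.variables))            # 2: w and b, found through the submodule
print(len(mlp.trainable_variables))  # 1: only w
print(mlp.submodules)                # a tuple holding the Dense instance
```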

2.1.2. Operation (Node)

A tf.Operation is a node in a tf.Graph that takes zero or more Tensor objects as input and produces zero or more Tensor objects as output

See the list of raw_ops here

See this doc for control-flow ops implementation

2.2. Tracing

When a function decorated with @tf.function is executed, it is traced into a tf.Graph (at call time, not at definition time). Shape inference also happens during tracing. The traced graph is cached.

Importantly, tf.Graph-level tracing allows unknown dimensions, e.g. by passing tf.TensorSpec([1, None]) in input_signature. This prevents TF from retracing the graph every time an input with a new concrete shape is fed.
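
A sketch of the retracing difference, counting traces with a Python side effect (which only runs while tracing). The function names are hypothetical; the exact retrace count of the unconstrained version can depend on the TF version and its reduce_retracing setting, so only "more than one" is claimed.

```python
import tensorflow as tf

trace_log = []  # Python side effects run during tracing only

@tf.function
def unconstrained(t):
    trace_log.append("unconstrained")
    return t * 2

@tf.function(input_signature=[tf.TensorSpec([1, None], tf.float32)])
def constrained(t):
    trace_log.append("constrained")
    return t * 2

for n in (1, 2, 3):
    x = tf.ones([1, n])
    unconstrained(x)  # new static shapes cause retracing
    constrained(x)    # the single [1, None] trace is reused

print(trace_log.count("unconstrained"))  # more than 1
print(trace_log.count("constrained"))    # 1
```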

2.2.1. TF1 Graph

A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation

In tensorflow 1, a tf.Session stores the state of a graph (i.e. the values of the variables).

A session may own resources, such as tf.Variable, tf.queue.QueueBase, and tf.compat.v1.ReaderBase

In TF1, the graph is serialized as a GraphDef protobuf (here is the doc) and saved with tf.compat.v1.train.Saver

It is serialized inside the .meta file (as part of a MetaGraphDef)
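
For contrast with TF2, a TF1-style sketch using the tf.compat.v1 API: the graph only describes the computation, while the session holds the variable's value and executes the graph. The names x, v and y are arbitrary.

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # TF1-style graph construction: nothing is executed here
    x = tf.compat.v1.placeholder(tf.float32, shape=(), name="x")
    v = tf.compat.v1.get_variable("v", initializer=tf.constant(2.0))
    y = tf.multiply(x, v, name="y")
    init = tf.compat.v1.global_variables_initializer()

# the session owns the variable's value; the graph only describes it
with tf.compat.v1.Session(graph=g) as sess:
    sess.run(init)
    out = sess.run(y, feed_dict={x: 3.0})
    print(out)  # 6.0

graph_def = g.as_graph_def()  # the GraphDef protobuf
print([node.name for node in graph_def.node])
```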

2.2.2. PolymorphicFunction

TF2 runs eagerly by default, so no tf.Graph is built unless a function is decorated with tf.function and traced; see the tf.function doc and PolymorphicFunction doc

Basically, tf.function creates a PolymorphicFunction (or GenericFunction), which can encapsulate several tf.Graphs; additionally passing jit_compile=True triggers XLA compilation

tf.function wraps a Python function and returns a PolymorphicFunction object. It manages a set of ConcreteFunctions and automatically picks the right one for your inputs. The accepted inputs can be constrained with input_signature using tf.TensorSpec

import tensorflow as tf

# f is a PolymorphicFunction
@tf.function
def f(x):
  return x + 1

# input_signature can be constrained with tf.TensorSpec
@tf.function(input_signature=[tf.TensorSpec([1, None])])
def constrained_foo(t):
  print("tracing...")
  return t

2.2.3. ConcreteFunction

ConcreteFunction: when a PolymorphicFunction (GenericFunction) is called with specific inputs, tracing creates a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.

A ConcreteFunction manages an AtomicFunction and captured inputs. The AtomicFunction in turn contains the actual tf.Graph (a FuncGraph), also reachable as concrete_function.graph

Signature of concrete function can be accessed by concrete_function.inputs and concrete_function.outputs. Both of them are SymbolicTensor.

# f.get_concrete_function(1) returns a ConcreteFunction
concrete_function = f.get_concrete_function(tf.constant(1, dtype=tf.int32))

# inputs/outputs of concrete_function are SymbolicTensor (i.e. not eager tensor)
assert tf.is_symbolic_tensor(concrete_function.inputs[0])
assert tf.is_symbolic_tensor(concrete_function.outputs[0])

print(concrete_function.function_def)
# signature {
#   name: "__inference_f_51110"
#   input_arg {
#     name: "x"
#     type: DT_INT32
#   }
#   output_arg {
#     name: "identity"
#     type: DT_INT32
#   }
# }
# node_def {
#   name: "add/y"
#   op: "Const"
#   attr {
#     key: "dtype"
#     value {
#       type: DT_INT32
#     }
#   }
#   attr {
#     key: "value"
#     value {
#       tensor {
#         dtype: DT_INT32
#         tensor_shape {
#         }
#         int_val: 1
#       }
#     }
#   }
# }...
2.2.3.1. AtomicFunction (FuncGraph)

AtomicFunction wraps a FuncGraph as its cached graph, which is the actual graph inside the ConcreteFunction. The inputs to an atomic function are typically the user inputs plus captured inputs (e.g. variables)

AtomicFunction is callable and can be accessed via concrete_function.inference_fn

# AtomicFunction
concrete_function.inference_fn

# FuncGraph
graph = concrete_function.inference_fn.graph

# or simply
graph = concrete_function.graph

print(graph.as_graph_def())
# node {
#   name: "x"
#   op: "Placeholder"
#   attr {
#     key: "_user_specified_name"
#     value {
#       s: "x"
#     }
#   }
#   attr {
#     key: "dtype"
#     value {
#       type: DT_INT32
#     }
#   }
#   attr {
#     key: "shape"
#     value {
#       shape {
#       }
#     }
#   }
# }

print(graph.operations)
# [<tf.Operation 'x' type=Placeholder>,
#  <tf.Operation 'add/y' type=Const>,
#  <tf.Operation 'add' type=AddV2>,
#  <tf.Operation 'Identity' type=Identity>]

2.2.4. Captured Objects

Many objects are captured during tracing. concrete_function.captured_inputs are implicitly passed to the function as extra arguments on every call after tracing.

See the following example:

import tensorflow as tf

a = tf.Variable([3.14, 3.14]) # captured as a ResourceHandle (i.e. just a pointer)
b = tf.constant(1.0) # captured as the actual tensor
c = 2. # embedded directly into the graph as a constant (Const node)

@tf.function
def f(x):
  print('tracing...')
  d = tf.constant(3.) # built into graph's Const node
  return x + a + b + c + d

#<tf.Tensor: shape=(2,), dtype=float32, numpy=array([9.14, 9.14], dtype=float32)>
f(0.) # trigger tracing, 9.14 = 3.14 + 1 + 2 + 3

# a, b are captured as inputs here:
# [<tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="Variable/14", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [2] ]")>>,
#  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
f.get_concrete_function(0.).captured_inputs

# rebinding the Python names b or c neither triggers retracing nor changes the result, for example
b = tf.constant(2.0)
f(0.) # no re-tracing, result is still 9.14
f.get_concrete_function(0.).captured_inputs # 2nd captured input is still 1.0

# assign on the variable does change the result (still without retracing), because the captured input for a variable is a ResourceHandle (pointer), not its value
a.assign([4.14, 4.14])
f(0.) # no re-tracing, 10.14 = 4.14 + 1 + 2 + 3

2.3. Transformation

2.3.1. Placer

2.3.2. Grappler

2.3.3. Partitioner

The partitioner splits the graph across devices and inserts communication primitives (e.g. Send/Recv) to transfer data between them.

The actual communication is implemented behind the Rendezvous interface. See this blog

2.4. Compile

2.4.1. StableHLO

A traced ConcreteFunction (tf.Graph) is compiled again whenever the concrete input shape changes: XLA needs static shapes, but a ConcreteFunction allows dynamic dimensions (None), for example the batch size.

Note that compilation can happen even without retracing; see the following example:

@tf.function(input_signature=[tf.TensorSpec([1, None])], jit_compile=True)
def f(t):
  print("tracing...")
  return tf.sin(t)

# tracing...
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,1]{1,0})->f32[1,1]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,]]))(stage='hlo')

# no retracing, but compilation is triggered again; notice the argument shape has changed
# HloModule a_inference_f_2574__.7, entry_computation_layout={(f32[1,2]{1,0})->f32[1,2]{1,0}}...
f.experimental_get_compiler_ir(tf.constant([[1.,2]]))(stage='hlo')

Compiled results will also get cached

2.4.2. XLA Compile

StableHLO to native code

2.5. Data

2.5.1. Example proto

A single datapoint can be represented using tf.train.Example proto, which serializes a single \((x,y)\) datapoint into binary format.

An Example essentially represents the following structure:

Dict[str,
     Union[List[bytes],
           List[int64],
           List[float]]]

where each feature name (string key) maps to a tf.train.Feature value, which holds a list of int64, float or bytes. The proto definition can be seen here

message Example {
  Features features = 1;
}

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
}

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}

message BytesList {
  repeated bytes value = 1;
}
2.5.1.1. Native Conversion

This section lists a few Example-related conversions

To convert native/numpy types into tf.train.Feature, we can use the following snippet

import tensorflow as tf

# native to feature
def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

feature = _int64_feature(1)

# feature to native
feature.int64_list.value

Converting to and from an Example is similar

feature = {
    'feature0': _int64_feature(feature0),
    'feature1': _int64_feature(feature1),
    'feature2': _bytes_feature(feature2),
    'feature3': _float_feature(feature3),
}

# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))

2.5.2. tf.io package

Each kind of byte sequence needs its own encoder/decoder; the tf.io package provides them

Tensor conversion: to convert a tensor into a tf.train.Feature, we can do the following

# convert to bytelist tensor
t = tf.constant(1)
serialized_tensor = tf.io.serialize_tensor(t)

feature_of_bytes = tf.train.Feature(
  bytes_list=tf.train.BytesList(value=[serialized_tensor.numpy()]))

# convert back; the dtype must be provided and must match
tf.io.parse_tensor(serialized_tensor, tf.int32)

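Example conversion

Serialized Example protos are parsed back into tensors with tf.io.parse_example (batched) or tf.io.parse_single_example (one record)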
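
A sketch of the round trip with tf.io.parse_single_example and its batched counterpart tf.io.parse_example; the feature names and values are made up, and the feature description must spell out the dtype and shape of every key:

```python
import tensorflow as tf

# build and serialize one Example (what TFRecord files usually hold)
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
    "score": tf.train.Feature(float_list=tf.train.FloatList(value=[0.5, 1.5])),
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"cat"])),
}))
serialized = example.SerializeToString()

# the feature description tells the parser each key's dtype and shape
feature_description = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "score": tf.io.FixedLenFeature([2], tf.float32),
    "name": tf.io.FixedLenFeature([], tf.string),
}
parsed = tf.io.parse_single_example(serialized, feature_description)
print(parsed["label"].numpy(), parsed["score"].numpy(), parsed["name"].numpy())

# parse_example is the batched version: it takes a vector of serialized protos
batch = tf.io.parse_example(tf.constant([serialized, serialized]), feature_description)
print(batch["score"].shape)  # (2, 2)
```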

2.5.3. tf.data

tf.data.Dataset API for input pipelines

2.5.3.1. Naive Dataset

use from_tensors to create a dataset with a single element, and from_tensor_slices to slice tensors along their first dimension into multiple elements

import numpy as np
import tensorflow as tf

# tf.data.Dataset.from_tensor_slices((X,y)) or tf.data.Dataset.from_tensor_slices(X)
dataset = tf.data.Dataset.from_tensor_slices(([np.array([1,2,3,4]), np.array([0,0,1,0])], np.array([1,0])))

# <_TensorSliceDataset element_spec=(TensorSpec(shape=(4,), dtype=tf.int64, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))>
print(dataset)


# (<tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 2, 3, 4])>, <tf.Tensor: shape=(), dtype=int64, numpy=1>)
#(<tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 0, 1, 0])>, <tf.Tensor: shape=(), dtype=int64, numpy=0>)
for elem in dataset:
    print(elem)

use from_generator to consume a python generator

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
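
Recent TF versions deprecate output_types/output_shapes in favor of a single output_signature argument. A sketch of the same counter with it:

```python
import tensorflow as tf

def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1

# output_signature replaces the older output_types / output_shapes pair
ds_counter = tf.data.Dataset.from_generator(
    count, args=[5], output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))

print([int(v) for v in ds_counter.as_numpy_iterator()])  # [0, 1, 2, 3, 4]
```
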
2.5.3.2. Batching

use batch or padded_batch

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 9)  # sequence length in [0, 8]
    yield i, np.random.normal(size=(size,))
    i += 1

ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))

# batch 2 datapoints, padding every sequence to length 8
# padded_shapes must mirror the element structure; if each sample is a dict, it can map each key to its padded shape ([] for scalars)
ds_series_batch = ds_series.shuffle(20).padded_batch(2, padded_shapes=([], [8]))

_, sequence_batch = next(iter(ds_series_batch))
print(sequence_batch.numpy())
# e.g. (values are random)
# [[ 0.0208  0.      0.      0.      0.      0.      0.      0.    ]
#  [-1.2131  0.523   1.083   0.3762 -1.1041 -1.6604 -2.3436  0.    ]]

2.5.3.3. Tokenize

can use the tokenizers in tensorflow_text

import tensorflow_text as text

tokenizer = text.WhitespaceTokenizer()
dataset = dataset.map(lambda x: tokenizer.tokenize(x))

2.5.3.4. TFRecordDataset

an Example proto can be serialized into / deserialized from a byte string

# convert example to string
string_example = example_proto.SerializeToString()

# convert back
example_proto = tf.train.Example.FromString(string_example)

TFRecord is a file format storing a sequence of byte sequences, it can be used to store example strings.

# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example.numpy())


# Read
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
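
Since the snippets above depend on helpers and variables defined elsewhere, here is a self-contained sketch of the whole round trip through a temporary TFRecord file. serialize_example here is a hypothetical helper with only two features, and it returns plain bytes, so no .numpy() is needed before writer.write:

```python
import os
import tempfile
import tensorflow as tf

def serialize_example(idx: int, text: bytes) -> bytes:
    """Hypothetical helper: pack two features into a serialized tf.train.Example."""
    feature = {
        "idx": tf.train.Feature(int64_list=tf.train.Int64List(value=[idx])),
        "text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[text])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()  # already bytes

filename = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
with tf.io.TFRecordWriter(filename) as writer:
    for i, t in enumerate([b"a", b"bb", b"ccc"]):
        writer.write(serialize_example(i, t))

records = []
for raw_record in tf.data.TFRecordDataset([filename]):
    example = tf.train.Example.FromString(raw_record.numpy())
    feats = example.features.feature
    records.append((feats["idx"].int64_list.value[0], feats["text"].bytes_list.value[0]))

print(records)  # [(0, b'a'), (1, b'bb'), (2, b'ccc')]
```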

2.5.4. tfds

See doc here

tfds provides ready-to-use datasets

tfds.load is roughly equivalent to the following:

import tensorflow_datasets as tfds

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)

2.6. Distribution

All devices visible to TensorFlow can be listed with

import logging

# this typically includes the CPU as well
logging.info("All TF devices: %s", tf.config.list_logical_devices())

2.7. Serialization

There are two related concepts:

  • checkpoint: contains only values of parameters, no graph is included. It is typically only useful when source code that will use the saved parameter values is available.
  • SavedModel: includes a serialized description of the computation at the tf.Graph level in addition to the parameter values (checkpoint)

2.7.1. Save

In TF2, a graph can be stored in the SavedModel format: in saved_model.pb, the graph is serialized as a MetaGraphDef protobuf (which contains a GraphDef as a child). In general this captures tf.Graph-level information (i.e. not StableHLO or compiled machine code)

A SavedModel typically saves a Trackable object (usually a tf.Module). The SavedModel contains

  • signatures (aliases of ConcreteFunctions): tf.functions decorated with an explicit input_signature are saved; otherwise the target signatures should be passed as an argument.
  • checkpoint (recursive attributes of trackable objects): the saved variable values

class Adder(tf.Module):

  def __init__(self):
    self.weight = tf.Variable(3.14, dtype=tf.float32)

  @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
  def __call__(self, x):
    return x + self.weight

adder = Adder()
tf.saved_model.save(adder, "/tmp/adder")

This will create the following files:

  • saved_model.pb: it is a MetaGraphDef protobuf file, which contains GraphDef as a child. Typically a small file.

  • variables: checkpoint directory; it contains files such as variables.index, variables.data-00000-of-00001. See the next section

2.7.2. Checkpoint

Checkpoint typically has two files (ref: stackoverflow)

  • index file: it is a string-string immutable table(tensorflow::table::Table). Each key is a name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which "data" file contains the content of a tensor, the offset into that file, checksum, some auxiliary data, etc.
  • data file: a TensorBundle collection that stores the values of all variables.

To inspect variables (ckpt) from a SavedModel

f = tf.saved_model.load("/tmp/adder")

# trackable attributes are accessible after loading
# this info is stored in MetaGraphDef.object_graph_def

# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>
f.weight

# it is also accessible from signature.variables
signature = f.signatures['serving_default']

# (<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.14>,)
print(signature.variables)

# variables can also be loaded with load_checkpoint
reader = tf.train.load_checkpoint("/tmp/adder/variables/variables")

# {'_CHECKPOINTABLE_OBJECT_GRAPH': tf.string, 'weight/.ATTRIBUTES/VARIABLE_VALUE': tf.float32}
dtype_from_key = reader.get_variable_to_dtype_map()
shape_from_key = reader.get_variable_to_shape_map()

key = 'weight/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
reader.get_tensor(key)

Note that only Trackable objects and their trackable attributes are saved in the checkpoint; non-trackable TF objects (e.g. tf.constant) are saved via concrete_function.captured_inputs, and native Python values are built directly into the graph

It is also possible to save/restore a checkpoint without using SavedModel

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
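
A self-contained sketch of the same save/restore cycle with a tf.Module instead of a Keras model; Counter is a hypothetical module and the checkpoint goes to a temporary directory:

```python
import os
import tempfile
import tensorflow as tf

class Counter(tf.Module):
    def __init__(self):
        super().__init__()
        self.count = tf.Variable(0)

module = Counter()
ckpt = tf.train.Checkpoint(model=module)
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

module.count.assign(5)
save_path = ckpt.save(prefix)  # .../ckpt-1, since save_counter starts at 1

module.count.assign(100)       # clobber the value...
ckpt.restore(save_path)        # ...and bring the saved one back
print(module.count.numpy())    # 5

# keys are object-graph paths, like 'weight/.ATTRIBUTES/VARIABLE_VALUE' above
print([name for name, _ in tf.train.list_variables(save_path)])
```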

2.7.3. Signature (ConcreteFunction)

To load a SavedModel and use its concrete functions:

f = tf.saved_model.load("/tmp/adder")

# use its signature (concrete function)
saved_concrete_function = f.signatures['serving_default']

# <ConcreteFunction (*, x: TensorSpec(shape=(), dtype=tf.float32, name='x')) -> Dict[['output_0', TensorSpec(shape=(), dtype=tf.float32, name='output_0')]] at ...>
print(saved_concrete_function)


# {'output_0': <tf.Tensor: shape=(), dtype=float32, numpy=4.1400003>}
saved_concrete_function(tf.constant(1.0))

# error: a concrete function cannot be called with inputs incompatible with its input_signature (int32 vs float32)
saved_concrete_function(tf.constant(1))

To inspect nodes of a graph from a signature

graph = f.signatures['serving_default'].graph
for node in graph.as_graph_def().node:
  print(node)

Note that the graph before saving and the graph after loading are not identical. Compare the following two graph_defs: the saved graph typically wraps most of the computation into a single raw_ops.PartitionedCall referring to some __inference___call___ function in the function library. See its doc

class Adder(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
  def __call__(self, x):
    return x + 3.14

# before saving
adder = Adder()
adder.__call__.get_concrete_function(tf.constant(1.0)).graph.as_graph_def()

# after saving
tf.saved_model.save(adder, "/tmp/adder")
saved_adder = tf.saved_model.load("/tmp/adder")
saved_adder_concrete_function = saved_adder.signatures['serving_default']
saved_adder_concrete_function.graph.as_graph_def()

Marking jit_compile=True produces the same graph before saving, but attaches an _XlaMustCompile=True attribute to the PartitionedCall during saving. This probably triggers JIT compilation after loading

3. Jax

Jax is basically a stack of compilers, based on transforming and compiling functionally pure programs.

Here is a great blog comparing static/dynamic autodifferentiation libraries

The execution path is roughly

\[\text{python} \xrightarrow{\text{trace}} \text{jaxpr} \xrightarrow{\text{lower}} \text{StableHLO} \xrightarrow{\text{xla compile}} \text{native}\]

3.1. API

3.1.1. lax

lax is a library of primitive operations underlying jnp; it is a thin wrapper over XLA, and JAX knows how to transform each of its operations

See this section about lax control flow
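
A sketch of three common lax control-flow primitives under jit; the functions themselves are toy examples:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumsum_scan(xs):
    # lax.scan: carry the running sum and emit it at every step
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys

@jax.jit
def sum_fori(xs):
    # lax.fori_loop(lower, upper, body_fun, init_val); xs.shape[0] is static under jit
    return lax.fori_loop(0, xs.shape[0], lambda i, acc: acc + xs[i], jnp.zeros((), xs.dtype))

@jax.jit
def relu_cond(x):
    # lax.cond branches on a traced predicate (a Python `if` would fail under jit)
    return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
print(cumsum_scan(xs))     # [ 1.  3.  6. 10.]
print(sum_fori(xs))        # 10.0
print(relu_cond(-3.0), relu_cond(2.0))  # 0.0 2.0
```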

3.1.2. primitive

we can define new primitives and implement their evaluation and transformation rules

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

def multiply_add_prim(x, y, z):
  return multiply_add_p.bind(x, y, z)
3.1.2.1. eval

the primal evaluation rule must be implemented

import numpy as np

def multiply_add_impl(x, y, z):
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)

To enable tracing, abstract eval should be implemented to inform how to infer shape and type

def multiply_add_abstract_eval(xs, ys, zs):
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

# with abstract evaluation, we can trace it correctly
# { lambda ; a:f32[] b:f32[] c:f32[]. let d:f32[] = multiply_add a b c in (d,) }
import jax
jax.make_jaxpr(multiply_add_prim)(2.0, 3.0, 10.)

3.1.2.2. differentiation

define a JVP rule (forward mode) and a transpose rule (needed for reverse mode)
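
Registering JVP and transpose rules directly on a Primitive goes through JAX internals whose module paths have changed across releases, so this sketch uses the public jax.custom_jvp decorator instead. Note this is a different, higher-level mechanism than primitive rules: JAX derives the transpose itself because the returned tangent is linear in the input tangents.

```python
import jax

@jax.custom_jvp
def multiply_add(x, y, z):
    return x * y + z

@multiply_add.defjvp
def multiply_add_jvp(primals, tangents):
    x, y, z = primals
    xt, yt, zt = tangents
    primal_out = multiply_add(x, y, z)
    # d(xy + z) = xt*y + x*yt + zt: linear in the tangents,
    # so JAX can transpose it to obtain reverse-mode gradients
    tangent_out = xt * y + x * yt + zt
    return primal_out, tangent_out

grads = jax.grad(multiply_add, argnums=(0, 1, 2))(2.0, 3.0, 10.0)
print(grads)  # gradients w.r.t. (x, y, z): 3.0, 2.0, 1.0
```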

3.1.3. jnp

similar to the numpy syntax

3.1.3.1. packing and segments

sometimes we want to pack sequences to use the TPU efficiently.

Here is an example illustrating packing from the tensor2tensor library. It concatenates sentences 'horizontally' and tracks each sentence with a segment id.

It can be combined with batching as well (batching concatenates padded sequences 'vertically')

Two input examples get combined to form an output example.
The input examples are:
{"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
{"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
The output example is:
{
                "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
  "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
      "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
              "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
  "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
      "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
}
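
A sketch of how the packed fields could be computed for the example above. pack is a hypothetical helper, not tensor2tensor's implementation; it assumes id 0 is padding and that all sequences fit in one row.

```python
import numpy as np

def pack(sequences, length):
    """Hypothetical helper: concatenate sequences into one row of size `length`.

    Returns token ids, 1-based segment ids (0 = padding) and per-segment positions.
    """
    tokens = np.zeros(length, dtype=np.int32)
    segmentation = np.zeros(length, dtype=np.int32)
    position = np.zeros(length, dtype=np.int32)
    offset = 0
    for seg_id, seq in enumerate(sequences, start=1):
        seq = [t for t in seq if t != 0]  # strip padding (id 0)
        n = len(seq)
        if offset + n > length:
            raise ValueError("sequences do not fit in one packed row")
        tokens[offset:offset + n] = seq
        segmentation[offset:offset + n] = seg_id
        position[offset:offset + n] = np.arange(n)
        offset += n
    return tokens, segmentation, position

inputs, inputs_seg, inputs_pos = pack([[8, 7, 1, 0], [2, 3, 4, 1]], length=10)
print(inputs)      # [8 7 1 2 3 4 1 0 0 0]
print(inputs_seg)  # [1 1 1 2 2 2 2 0 0 0]
print(inputs_pos)  # [0 1 2 0 1 2 3 0 0 0]
```
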
3.1.3.2. paddings and masking

some paddings related code snippets

# create padding from ids
# seq = [1,2,3,0,0]
# padding = [0, 0, 0, 1, 1]
jnp.equal(seq, 0)

# apply paddings
masked_features = features * (1.0-paddings[:, :, jnp.newaxis])

# padding to length
seq_lengths = jnp.sum(1.0 - paddings, axis=-1)

# length to paddings
# DeviceArray([[ True,  True,  True,  True],
#             [ True,  True, False, False],
#             [ True, False, False, False]], dtype=bool)
lengths = jnp.array([4,2,1])
col_ids = jnp.arange(4)[jnp.newaxis, :]
col_ids < lengths[:, jnp.newaxis]


# Look-ahead mask or causal mask
def create_look_ahead_mask(size: int) -> np.ndarray:
  mask = np.ones([size, size])
  mask[np.tril_indices(size)] = 0.0
  return mask  # (seq_len, seq_len)


# (array([0, 1, 1]), array([0, 0, 1]))
np.tril_indices(2)

#array([[0., 1., 1.],
#       [0., 0., 1.],
#       [0., 0., 0.]])
create_look_ahead_mask(3)


def another_look_ahead_mask(size: int):
  # causal mask using col/row id comparison
  col_idx = jnp.tile(jnp.arange(size)[jnp.newaxis, :], [size, 1])
  row_idx = jnp.tile(jnp.arange(size)[:, jnp.newaxis], [1, size])
  mask = (row_idx < col_idx)
  return mask


# packed segment masking
# [B,T,1]
fst_segment_ids = segment_ids[:, :, jnp.newaxis]

# [B,1,T]
snd_segment_ids = segment_ids[:, jnp.newaxis, :]

jnp.not_equal(fst_segment_ids, snd_segment_ids)
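
Combining the packed segment mask with the causal mask and key padding gives a single attention mask. A sketch, keeping the convention above that True means "masked out"; packed_causal_mask is a hypothetical name. Note the row for a padding query is fully masked, so a plain softmax over it would produce NaN unless handled.

```python
import jax.numpy as jnp

def packed_causal_mask(segment_ids):
    """Hypothetical helper. segment_ids: [B, T] int, 0 = padding. Returns [B, T, T] bool, True = masked."""
    T = segment_ids.shape[-1]
    # tokens from different segments must not attend to each other
    cross_segment = jnp.not_equal(segment_ids[:, :, jnp.newaxis], segment_ids[:, jnp.newaxis, :])
    # future keys are masked (row = query, col = key)
    causal = jnp.arange(T)[jnp.newaxis, :] > jnp.arange(T)[:, jnp.newaxis]
    # padding keys are masked as well
    padding = jnp.equal(segment_ids, 0)[:, jnp.newaxis, :]
    return cross_segment | causal[jnp.newaxis] | padding

segment_ids = jnp.array([[1, 1, 2, 2, 0]])
mask = packed_causal_mask(segment_ids)
print(mask[0].astype(jnp.int32))
```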

3.2. Tracing

Different transformations require different levels of abstraction, so some control flow may cause issues depending on the transformation; see doc here

The target function's arguments and return values should be arrays, scalars, or nested containers of them. Other values (e.g. str) raise an error during tracing

Note that tracing can happen just by using jnp, since some jnp functions are implemented with jit, for example (logs like the ones below are printed when compile logging is enabled with jax.config.update("jax_log_compiles", True)),

x = jnp.zeros((3,3))
# WARNING:jax._src.dispatch:Finished tracing + transforming broadcast_in_dim for pjit in 0.003103494644165039 sec
# WARNING:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
#WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.005272388458251953 sec
#WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.016413211822509766 sec

3.2.1. Tracing Abstraction

By default, tracing is done using ShapedArray abstraction level, which might be problematic when hitting control flow

Using static_argnums or static_argnames forces the tracer to trace with the concrete value instead of a ShapedArray: such arguments are treated as compile-time constants and constant-folded during tracing.

Changing a static argument triggers recompilation because it changes the signature. Whether a static argument has changed is decided by __hash__ and __eq__ (static arguments must be hashable)

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=1)
def f(x, a):

  # during tracing, this shows 1 as it is static
  print(a) 
  return jnp.sin(x)

f(jnp.zeros(1), 1) # this will compile
f(jnp.zeros(1), 2) # this compiles again: a is traced with its actual value, so the signature changes

@jax.jit
def g(x, a):

  # during tracing, this prints Traced<ShapedArray(int32[], weak_type=True)>
  print(a)
  return jnp.sin(x)

g(jnp.zeros(1), 1) # this will compile
g(jnp.zeros(1), 2) # no recompilation: a is traced as ShapedArray(int32[]), so the signature does not change
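
Where static arguments really matter is Python control flow on an argument. A sketch showing the failure with a traced argument, the static_argnames fix, and the lax.cond alternative; scale is a toy function:

```python
import jax
import jax.numpy as jnp

def scale(x, mode):
    # Python control flow on `mode` needs a concrete value at trace time
    if mode > 0:
        return x * 2
    return x

x = jnp.ones(3)

try:
    jax.jit(scale)(x, 1)  # `mode` is traced as ShapedArray(int32[]): cannot branch on it
except jax.errors.ConcretizationTypeError as err:
    print("tracing failed:", type(err).__name__)

# marking `mode` static makes it a compile-time constant (one compilation per value)
scale_static = jax.jit(scale, static_argnames="mode")
print(scale_static(x, mode=1))  # [2. 2. 2.]
print(scale_static(x, mode=0))  # [1. 1. 1.]

# alternatively keep `mode` traced and use lax.cond, so one compilation covers all values
scale_cond = jax.jit(lambda x, mode: jax.lax.cond(mode > 0, lambda v: v * 2, lambda v: v, x))
print(scale_cond(x, 1))  # [2. 2. 2.]
```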

3.2.2. jaxpr

A jaxpr is a sequence of primitive equations obtained by tracing; the result of this tracing process can be inspected with make_jaxpr

import jax

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

#foo
#=====
#invars: [a]
#outvars: [b]
#constvars: []
#equation: [a, 1] add [b] {}

#jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
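Because tracing runs the Python code once with abstract values, Python loops are unrolled into the jaxpr rather than kept as loops. A small sketch (cube_by_loop is a hypothetical example) makes this visible by listing the primitive name of each equation:

```python
import jax


def cube_by_loop(x):
    y = x
    for _ in range(2):  # plain Python loop, executed at trace time
        y = y * x
    return y


closed = jax.make_jaxpr(cube_by_loop)(3.0)
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitive_names)  # expected: two 'mul' equations, one per loop iteration
print(closed)
```

To keep a loop as a single primitive instead, it would be written with jax.lax.fori_loop or jax.lax.scan.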

3.3. Transformation

See this doc to understand how to add a new transformation

Many JAX transformations are implemented on top of jaxprs: a transformation (mostly) works as an interpreter that walks a jaxpr and produces another jaxpr.
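The interpreter idea can be sketched by evaluating a jaxpr equation by equation with Primitive.bind, following the pattern of JAX's custom-interpreter docs. The swaps dictionary (replacing sin with cos) is a hypothetical toy transformation, and reading literals through their .val attribute is an assumption about jaxpr internals that may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def eval_jaxpr_with_swaps(closed_jaxpr, swaps, *args):
    """Interpret a jaxpr, replacing any primitive found in `swaps`."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals (inline constants like 1.0) store their value in .val;
        # variables are looked up in the environment.
        return v.val if hasattr(v, "val") else env[v]

    def write(v, value):
        env[v] = value

    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        write(var, const)
    for var, arg in zip(jaxpr.invars, args):
        write(var, arg)

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        primitive = swaps.get(eqn.primitive, eqn.primitive)
        outvals = primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, value in zip(eqn.outvars, outvals):
            write(var, value)

    return [read(v) for v in jaxpr.outvars]


def f(x):
    return jax.lax.sin(x) + 1.0


closed = jax.make_jaxpr(f)(0.0)
x = jnp.array(0.0)
plain = eval_jaxpr_with_swaps(closed, {}, x)                                # sin(0) + 1 = 1
swapped = eval_jaxpr_with_swaps(closed, {jax.lax.sin_p: jax.lax.cos_p}, x)  # cos(0) + 1 = 2
print(plain, swapped)
```

A real transformation would emit new equations (staying inside a trace) rather than evaluating them eagerly, but the walk over eqns is the same.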

3.3.1. grad

JAX computes grad by first using jvp to obtain a (linearized) forward jaxpr, then transposing each equation of that forward jaxpr to obtain the gradient.

Consider the function \(f(x,y) = xy+y\) evaluated at \((x, y) = (2, 4)\). During grad, jvp produces the following equations, where the tangents xt and yt are traced abstractly:

a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt

transpose then interprets these equations in reverse order and produces the following expressions:

# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
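The hand-derived values can be cross-checked with a short sketch: after the loop above, xct = 4. and yct = 1. + 2. = 3., which should match jax.grad at the same point (2., 4.). jax.jvp with unit tangents gives the forward side of the same computation.

```python
import jax


def f(x, y):
    return x * y + y


# forward (jvp) side: push unit tangents through at (2., 4.)
_, tangent_x = jax.jvp(f, (2.0, 4.0), (1.0, 0.0))  # a = xt * 4.      -> 4.0
_, tangent_y = jax.jvp(f, (2.0, 4.0), (0.0, 1.0))  # 2. * yt + yt     -> 3.0
print(tangent_x, tangent_y)

# transposed (grad) side: should match xct and yct computed by hand
xct, yct = jax.grad(f, argnums=(0, 1))(2.0, 4.0)
print(xct, yct)  # 4.0 3.0
```

For this scalar-output function both sides agree because the gradient is just the row of partial derivatives, which the jvp recovers one tangent direction at a time.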

3.4. Compile

3.4.1. lower

A jaxpr can be lowered into HLO (XLA's IR; recent JAX versions emit StableHLO, as shown below). Note that lowering is only a translation, not an optimization.

import jax
import jax.numpy as jnp

@jax.jit
def one_plus_one():
  a = jnp.ones(1)
  b = jnp.ones(1)
  return a + b

print(one_plus_one.lower().as_text())
# module @jit_one_plus_one attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
#   func.func public @main() -> (tensor<1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
#     %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
#     %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1xf32>
#     %2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
#     %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1xf32>
#     %4 = stablehlo.add %1, %3 : tensor<1xf32>
#     return %4 : tensor<1xf32>
#   }
# }
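Because lowering only needs shapes and dtypes, a jitted function can be lowered from a jax.ShapeDtypeStruct without allocating any input data. A minimal sketch, where double is a hypothetical example function:

```python
import jax
import jax.numpy as jnp


def double(x):
    return x * 2.0


# describe the input abstractly: shape (3,), dtype float32, no actual data
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
lowered = jax.jit(double).lower(spec)

text = lowered.as_text()  # StableHLO module, not yet optimized by XLA
print(text)
```

This is handy for inspecting or exporting programs whose real inputs are too large to materialize on the current machine.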

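Lowering is achieved by implementing a lowering rule for each primitive. For example, the lowering rule for a custom multiply_add primitive emits a multiply followed by an add: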

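# hlo: the StableHLO MLIR dialect bindings shipped with jaxlib (internal module path)
from jax._src.lib.mlir.dialects import hlo

def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# registration for a multiply_add_p primitive defined elsewhere, e.g.:
# from jax.interpreters import mlir
# mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')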

3.4.2. compile

HLO can be further compiled by XLA into an executable. Most optimizations happen at this stage.

compiled = jax.jit(lambda x: x + 2).lower(3).compile()

# notice that the constant expression 1 + 1 is folded into 2 directly
print(one_plus_one.lower().compile().as_text())
# HloModule jit_one_plus_one, is_scheduled=true, entry_computation_layout={()->f32[1]{0:T(256)}}, allow_spmd_sharding_propagation_to_output={true}

# ENTRY %main.2 () -> f32[1] {
#   %constant.1 = f32[1]{0:T(256)} constant({2})
#   ROOT %copy.1 = f32[1]{0:T(256)} copy(f32[1]{0:T(256)} %constant.1)
# }
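The lower and compile steps can also be used ahead of time and the resulting executable called directly. A minimal sketch, with scale_and_shift as a hypothetical example function:

```python
import jax
import jax.numpy as jnp


def scale_and_shift(x):
    return x * 2.0 + 1.0


x = jnp.arange(4, dtype=jnp.float32)

lowered = jax.jit(scale_and_shift).lower(x)  # trace + lower to StableHLO (no optimization yet)
compiled = lowered.compile()                 # XLA optimizes and builds an executable

out = compiled(x)   # runs the executable directly, without re-tracing
print(out)          # [1. 3. 5. 7.]
print(compiled.as_text()[:200])  # optimized HLO text
```

The compiled object is specialized to the shapes and dtypes it was lowered with, so it should only be called with matching inputs.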

3.4.3. jit

The jit decorator combines all the previous steps (tracing, transforming, lowering, and compiling) and caches the compiled executable for each input signature.
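The caching can be observed with a Python side effect, which only runs while tracing. In this sketch, double and trace_count are hypothetical names used for illustration:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: only runs while tracing
    return x * 2


double(jnp.ones(3))  # first call: trace, lower, compile
double(jnp.ones(3))  # same shape/dtype: cached executable reused, no tracing
double(jnp.ones(4))  # new shape: traced and compiled again
print(trace_count)   # 2
```

The same mechanism explains the static-argument behavior above: a new static value is a new cache key.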

3.4.4. jax2tf

See jax2tf doc

from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# jax2tf.convert is a higher-order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
f_tf = jax2tf.convert(f_jax)

# For example you execute f_tf eagerly with valid TensorFlow inputs:
f_tf(np.random.random(...))

# Additionally you can use tools like `tf.function` to improve the execution
# time of your function, or to stage it out to a SavedModel:
f_tf_graph = tf.function(f_tf, autograph=False)

# this triggers jaxpr tracing, builds StableHLO, and wraps it in a tf.XlaCallModule op; the StableHLO module is serialized into a string attribute of the XlaCallModule node.
f_tf_graph.get_concrete_function(tf.constant(1.0)).graph.as_graph_def()

3.5. Distribution

3.5.1. Sharding

See jax distribution doc, jax.sharding doc, also check this kaggle blog

There are a few Sharding classes, typically subclasses of XLACompatibleSharding. A Sharding only specifies how data gets sharded; it does not represent the sharded data itself.

SingleDeviceSharding: a sharding that places its data on a single device (the default sharding).

import jax
import jax.numpy as jnp

# a newly created array gets this default sharding:
# it is placed on a single device (the default device)
x = jnp.zeros((8,16))

# SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
x.sharding

NamedSharding combines a Mesh and a PartitionSpec; it can be created manually as follows:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec

# create a 2x4 mesh from the 8 devices
devices = np.array(jax.devices()).reshape(2, 4)

# Declare a 2D mesh with axes `x` and `y`.
mesh = Mesh(devices, ('x', 'y'))

# a PartitionSpec describes how each dimension of the data is sharded over the mesh axes
# e.g. PartitionSpec('x', 'y') shards the first data dimension across the mesh's x axis and the second across its y axis
# its length can be at most the data's rank; dimensions it does not list are replicated
spec = PartitionSpec('x', 'y')

# combine the mesh and the partition spec into a NamedSharding
named_sharding = jax.sharding.NamedSharding(mesh, spec)

# shard the data and visualize
data = jnp.arange(8).reshape(2,4)
sharded_data = jax.device_put(data, named_sharding)

# this sharding info can be found in the array's attribute
print(sharded_data.sharding)
jax.debug.visualize_array_sharding(sharded_data)

┌──────────┬──────────┬──────────┬──────────┐
│          │          │          │          │
│  TPU 0   │  TPU 1   │  TPU 2   │  TPU 3   │
│          │          │          │          │
│          │          │          │          │
├──────────┼──────────┼──────────┼──────────┤
│          │          │          │          │
│  TPU 4   │  TPU 5   │  TPU 6   │  TPU 7   │
│          │          │          │          │
│          │          │          │          │
└──────────┴──────────┴──────────┴──────────┘

# another sharding
named_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('x', None))
sharded_data = jax.device_put(data, named_sharding)
jax.debug.visualize_array_sharding(sharded_data)

┌────────────────────────────────────────────────┐
│                                                │
│                  TPU 0,1,2,3                   │
│                                                │
│                                                │
├────────────────────────────────────────────────┤
│                                                │
│                  TPU 4,5,6,7                   │
│                                                │
│                                                │
└────────────────────────────────────────────────┘
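The examples above assume an 8-device TPU host. A device-count-agnostic sketch (it also runs on a single CPU device) shards the rows of an array over a 1D mesh and inspects each local shard through addressable_shards. The mesh axis name "x" and the array sizes are arbitrary choices for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))   # 1D mesh over all devices
sharding = NamedSharding(mesh, P("x", None))   # split rows over "x", replicate columns

data = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
sharded = jax.device_put(data, sharding)

# each device holds one (1, 4) row block
shard_shapes = [s.data.shape for s in sharded.addressable_shards]
print(shard_shapes)
```

The logical array is unchanged by sharding; only its physical layout across devices differs.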

To control device placement:

  • for input data, use jax.device_put to actually shard the data with a given Sharding
  • for intermediate results within a jit-decorated function, use jax.lax.with_sharding_constraint to force their placement

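import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# a 2x1 mesh built from the first two devices
mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
spec = P('x', 'y')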
named_sharding = jax.sharding.NamedSharding(mesh, spec)

data = jnp.array([[1,2],[3,4]])

# Array([[1, 2],[3, 4]], dtype=int32)
sharded_data = jax.device_put(data, named_sharding)

# NamedSharding(mesh=Mesh('x': 2, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=tpu_hbm)
print(sharded_data.sharding)

# resharding using another spec
newspec = P('y', 'x')
new_sharding = jax.sharding.NamedSharding(mesh, newspec)
new_sharded_data = jax.device_put(sharded_data, new_sharding)

# resharding in jit decoration
@jax.jit
def f(x):
  x = x + 1
  # force the intermediate result to be laid out with new_sharding
  y = jax.lax.with_sharding_constraint(x, new_sharding)
  return y

f(sharded_data)
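A device-count-agnostic sketch of with_sharding_constraint: the input starts with the default single-device sharding, and the constraint inside jit makes the output row-sharded over all devices. The function name add_one and the mesh axis "x" are hypothetical choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
row_sharding = NamedSharding(mesh, P("x", None))


@jax.jit
def add_one(x):
    y = x + 1
    # constrain the intermediate y to be row-sharded over the mesh
    return jax.lax.with_sharding_constraint(y, row_sharding)


x = jnp.zeros((n, 4))  # created with the default single-device sharding
out = add_one(x)
print(out.sharding)
```

Without the constraint, the compiler would be free to pick any layout for y (typically following the input).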

3.5.2. vmap

vmap is a vectorizing map: it turns a function written for single examples into one that maps over a batch axis. For example:

import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
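A quick check of the composition above: mapping the vector dot product twice should reproduce ordinary matrix multiplication. The input shapes here are arbitrary.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)       # map over rows of the first argument
mm = vmap(mv, (None, 1), 1)       # map over columns of the second argument

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)
out = mm(a, b)
print(out.shape)                  # (2, 4)
print(jnp.allclose(out, a @ b))   # True
```

In in_axes, None means "do not map this argument"; each mapped argument contributes the axis given by its integer.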

3.5.3. pmap

pmap provides manual (user-specified) SPMD programming: the user decides how the data is split across devices.

Applying pmap to a function compiles the function with XLA, then replicates it and executes each replica on its own XLA device in parallel.

import jax
import jax.numpy as jnp

# the size of the mapped (leading) dimension must not exceed n_devices
n_devices = jax.local_device_count()

# create a sharded array
a = jax.pmap(lambda x: x**2)(jnp.arange(n_devices))

# shardingspec
# PmapSharding(sharding_spec=ShardingSpec((Unstacked(2),), (ShardedAxis(axis=0),)), device_ids=[0, 1], device_platform=TPU, device_shape=(2,))
print(a.sharding)

# get info of each shard
jax.debug.visualize_array_sharding(a)
for i, shard in enumerate(a.global_shards):
    print(f"\nShard no: {i:>5}")
    print(f"Device: {str(shard.device):>32}")
    print(f"Data shape: {str(shard.data.shape):>8}")
    print(f"Data slices: {str(shard.index):>22}\n")
    print("="*75)
    print("")
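Because each replica runs independently, cross-device communication has to be written explicitly with collectives over a named axis. A minimal sketch using jax.lax.psum; the function normalize and the axis name "i" are hypothetical choices:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()


def normalize(x):
    # psum sums x across all replicas along the named axis "i"
    total = jax.lax.psum(x, axis_name="i")
    return x / total


xs = jnp.arange(1.0, n + 1.0)  # one scalar per device
out = jax.pmap(normalize, axis_name="i")(xs)
print(out, out.sum())          # the normalized values sum to 1
```

This explicit style is the main contrast with the automatic partitioning of jit described next.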

3.5.4. jit (pjit)

jit (which has absorbed pjit) provides automatic SPMD programming: the shardings of intermediate values are propagated automatically from the input/output shardings.

The syntax is as follows:

pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue,...)

  • in_shardings: optional, since JAX infers the sharding from the inputs; when inference cannot be done, it defaults to replication (i.e. None)
  • out_shardings: also optional; when not specified, the sharding is inferred by GSPMD's sharding propagation.
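A sketch of jit with explicit in_shardings and out_shardings, sized to whatever devices are available. The function column_sums and the mesh axis "x" are hypothetical; the body has no sharding annotations, so the compiler propagates shardings and inserts any needed cross-device communication.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
row_sharded = NamedSharding(mesh, P("x", None))
replicated = NamedSharding(mesh, P())  # empty spec: every device holds the full array


@partial(jax.jit, in_shardings=row_sharded, out_shardings=replicated)
def column_sums(x):
    return x.sum(axis=0)


x = jax.device_put(jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4), row_sharded)
out = column_sums(x)
print(out, out.sharding)
```
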

3.6. flax

Since JAX transformations are functional, linen is designed around stateless functional transforms of the form below, where \(v\) are the variable collections and \(x, y\) are the input and output data.

\[v_{out}, y = f(v_{in}, x)\]

3.6.1. linen

flax.linen.Module is a dataclass-like class that only holds hyperparameters as members (i.e. parameters are not stored inside the module object).

from flax import linen as nn
from typing import Tuple

class Module(nn.Module):
  # this hyperparameter becomes a dataclass field, available immediately as self.features
  features: Tuple[int, ...] = (16, 4)

  # setup is not run when the object is constructed; it runs lazily to bind the scope
  # every time the module is bound, i.e. on EVERY init or apply call
  def setup(self):
    # submodules defined here are only accessible inside init/apply, not on the unbound object
    self.dense1 = nn.Dense(self.features[0])
    self.dense2 = nn.Dense(self.features[1])

  # called after setup has bound the submodules
  def __call__(self, x):
    return self.dense2(nn.relu(self.dense1(x)))

The variables dict is a container of "variable collections", for example:

{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ...}
  }
}

A param is a variable in the special "params" collection; it is immutable during apply by default. See this post

# this retrieves the value if it exists, or initializes and returns it otherwise
p = self.param('param_name', init_fn, shape, dtype)

# is a convenient shorthand for this:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value

init/apply

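init creates the model's variables by running the module on the given example arguments. When init is jitted, it initializes the model lazily using only the shapes of the provided arguments, avoiding computing the forward pass with actual values. It returns the variables dict (a FrozenDict in older Flax versions).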

See the lifecycle doc for more details

import jax
from jax import random
from flax import linen as nn

model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input

# init
params = model.init(key2, x)

# apply
y = model.apply(params, x)

# apply with mutable states; "batch_stats" is an example collection name (e.g. for BatchNorm)
y, updated_state = model.apply(params, x, mutable=["batch_stats"])

3.6.2. serialization

usage example

from flax import serialization

# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

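# load: a template pytree with the same structure is needed to reconstruct from bytes
restored_params = serialization.from_bytes(params, bytes_output)
restored_from_dict = serialization.from_state_dict(params, dict_output)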

3.6.3. Rematerialization

Rematerialization (gradient checkpointing) is supported in JAX through jax.checkpoint (also available as jax.remat); see this doc
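A minimal sketch of jax.checkpoint: wrapping a function makes the backward pass recompute its intermediates instead of storing them, trading compute for memory without changing the gradient. The function block is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def block(x):
    # two stacked nonlinearities; the inner sin would normally be saved for backward
    return jnp.sin(jnp.sin(x))


# same function, but intermediates are recomputed during the backward pass
remat_block = jax.checkpoint(block)

x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(lambda v: block(v).sum())(x)
g_remat = jax.grad(lambda v: remat_block(v).sum())(x)
print(jnp.allclose(g_plain, g_remat))  # True
```

In Flax, nn.remat applies the same idea to a Module class.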

3.7. optax

Optimizer

usage example

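import jax
import optax

# assumes mse(params, x, y), params, x_samples, y_samples and learning_rate are defined elsewhere
tx = optax.sgd(learning_rate=learning_rate)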
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)