0x424 Interface

Mostly my notes on the high-level interfaces of several frameworks

The character notations for common tensors are

  • B: batch size
  • T: time step
  • H: hidden vector
  • A: attention size
  • C: channel
  • L: number of layers

1. Pytorch

1.1. Tensor

1.1.1. Interface

1.1.1.1. Initialization
## int (torch.int)
torch.IntTensor([1,2,3])
torch.randint(low, high, (size1, size2, ...))

## long long (torch.long)
torch.LongTensor([1,2,3])

## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)


## double (torch.double)
torch.DoubleTensor([1,2,3])
1.1.1.2. Operation
# element-wise operation 
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)

# (2,3,4)
(a*a).shape

# mm: no broadcasting matrix multiplication
# strictly has the form: (n,m)x(m,p) = (n,p)

# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))

## bmm: batched version of mm
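# both operands must be 3-D with a matching batch dimension
# shape (10, 2, 4)
torch.bmm(torch.randn(10, 2, 3), torch.randn(10, 3, 4)).shape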

## matmul: lots of cases...
# following examples are from pytorch docs

# when both tensors are at most 2-D, this is the ordinary matrix multiplication
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])


# matmul
# when either operand has more than 2 dims, the batch dims are broadcast to a common shape and then batched matrix multiplication is applied

>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

1.1.1.3. Masking

Masking is important. Without proper masking, attention will not work. Recall my own training failure...

1.1.1.3.1. Length Masking

Translation between mask and length

# length to mask
length = torch.LongTensor([3,5,4])


# tensor([[ True,  True,  True, False, False],
#        [ True,  True,  True,  True,  True],
#        [ True,  True,  True,  True, False]])
mask = torch.arange(5).unsqueeze(0) < length.unsqueeze(1)

# to get reversed mask, use ~mask

# mask to length
length = torch.sum(mask, dim=1)

apply mask

# inplace updates
b = torch.randn(3,5)
mask = torch.tensor([[ True,  True,  True, False, False],
                     [ True,  True,  True,  True,  True],
                     [ True,  True,  True,  True, False]])

b.masked_fill_(~mask, float('-inf'))
#tensor([[ 1.3503, -0.2101,  0.5982,    -inf,    -inf],
#        [ 0.6641, -0.8612,  0.2389, -1.1343,  1.1640],
#        [ 1.1004, -1.2243, -1.1901, -0.5063,    -inf]])
1.1.1.4. Triangular Masking

for the transformer decoder (causal attention)

>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
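
A minimal sketch of applying this causal mask to attention scores before the softmax (the score tensor below is made up for illustration):

length = 3
scores = torch.randn(length, length)                     # (T, T) attention logits
causal = torch.tril(torch.ones(length, length)).bool()

# forbid attending to future positions, then normalize
scores = scores.masked_fill(~causal, float('-inf'))
attn = torch.softmax(scores, dim=-1)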
1.1.1.5. Padding

To pad, see the following code

# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]

# tensor([[1, 2, 3],
#         [3, 4, 0]])
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)

# for more complicated case, use the following snippet
def pad_list(xs, pad_value, max_len=-1):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
        max_len (int): Target length; -1 (default) pads to the longest tensor in xs.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)

    if max_len == -1:
        max_len = max(x.size(0) for x in xs)

    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad

To unpad

# get length
a = torch.tensor([[1, 2, 3],
                  [3, 4, 0]])

mask = (a!=0)
length = torch.sum(mask, dim=1)

1.2. Modules

The C binding implementation is here

1.2.1. Implementation

Source code is available here

It looks like each module object manages the following members:

  • _modules: submodules
  • _parameters: tensors with gradients; these are the tensors updated by optim
  • _buffers: tensors without gradients; they are not updated by optim but are stored in state_dict. Examples are the running mean and var in batch norm

state_dict retrieves both parameters and buffers
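
A minimal sketch illustrating the distinction (the Stats module below is made up):

import torch
import torch.nn as nn

class Stats(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))             # lands in _parameters
        self.register_buffer('running_mean', torch.zeros(3))   # lands in _buffers

m = Stats()
[n for n, _ in m.named_parameters()]   # ['weight']: only this is optimized
list(m.state_dict().keys())            # ['weight', 'running_mean']: both are saved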

When we assign a new member in __init__, the registration happens automatically because Module overrides __setattr__:

    # a simplified version of setattr from pytorch
    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

        # register a parameter
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            self.register_parameter(name, value)
        else:

            # register a module
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                modules[name] = value

Those members are later resolved by __getattr__, which explicitly checks _modules, _parameters, and _buffers.
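
A simplified sketch of that lookup, mirroring the simplified __setattr__ above (not the exact pytorch source):

    # a simplified version of getattr from pytorch
    def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
        for store in ('_parameters', '_buffers', '_modules'):
            members = self.__dict__.get(store)
            if members is not None and name in members:
                return members[name]
        raise AttributeError(f"module has no attribute '{name}'")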

1.2.2. LSTM

interface: nn.LSTM(input_size, hidden_size, num_layers, batch_first=False, bidirectional=False)

input: input, (h0, c0)

  • input: (B,T,H) when batch_first
  • (h0, c0): (num_directions*L, B, H)

output: output, (hn, cn)

  • output: (B,T,num_directions*H) when batch_first
  • (hn, cn): (num_directions*L, B, H)

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
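
A bidirectional, batch_first sketch (with made-up sizes) to confirm the shapes above:

rnn = nn.LSTM(10, 20, num_layers=2, batch_first=True, bidirectional=True)
x = torch.randn(4, 7, 10)        # (B, T, input_size)
output, (hn, cn) = rnn(x)        # h0/c0 default to zeros

output.shape                     # (4, 7, 40) = (B, T, num_directions*H)
hn.shape                         # (4, 4, 20) = (num_directions*L, B, H)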

1.2.3. CNN

interface: nn.Conv1d(in_channels, out_channels, kernel_size)

input: (B,Ci,Ti) output: (B,Co,To)

>>> m = nn.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 50)
>>> output = m(input)
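
The output length follows the usual convolution formula; applied to the example above:

# To = floor((Ti + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
#    = floor((50 + 0 - 2 - 1) / 2 + 1) = 24
output.shape   # torch.Size([20, 33, 24])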

1.3. Loss

1.4. Optimization

1.5. Distribution

2. Tensorflow

2.1. Tensor

tf.Tensor is an immutable tensor. A number of specialized tensors are available: see tf.Variable, tf.constant, tf.placeholder, and tf.RaggedTensor

tf.Variable is a mutable tensor; it requires an initial value and maintains state during execution.

v = tf.Variable(1.)
v.assign(2.)
v.assign_add(0.5)

tf.Module is a named container for tf.Variable objects; it exposes variables, trainable_variables, and submodules.
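
A minimal sketch of a tf.Module subclass (the Dense layer below is a made-up example):

import tensorflow as tf

class Dense(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def __call__(self, x):
    return tf.matmul(x, self.w) + self.b

layer = Dense(3, 2)
y = layer(tf.ones([1, 3]))
len(layer.trainable_variables)   # 2: w and b are tracked automatically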

2.2. Graph

A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation

2.2.1. Tensorflow 1

A tf.Session stores the state of a graph (i.e. the values of the variables).

A session may own resources, such as tf.Variable, tf.queue.QueueBase, and tf.compat.v1.ReaderBase

In Tf1, the graph is serialized using the GraphDef protobuf (here is the doc) and is saved using tf.compat.v1.train.Saver.

The serialized graph is stored inside the .meta file
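
A minimal Tf1-style sketch (via the compat API) of a session owning variable state; the tensor names below are made up:

import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

x = tf1.placeholder(tf1.float32, shape=[None, 3])
w = tf1.Variable(tf1.ones([3, 1]))
y = tf1.matmul(x, w)

with tf1.Session() as sess:
  sess.run(tf1.global_variables_initializer())
  print(sess.run(y, feed_dict={x: [[1., 2., 3.]]}))   # [[6.]]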

2.2.2. Tensorflow 2

2.2.2.1. tf.Function

A tf.Graph is not built by default in eager mode; tf2 builds one using the tf.function decorator, see the doc here

tf.function works in the following steps:

  • tf.function wraps a Python function, returning a GenericFunction object.
  • A GenericFunction manages a cache of ConcreteFunctions and picks the right one for your inputs.
  • A ConcreteFunction wraps a tf.Graph.

Tracing creates a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)

In Tf2, the graph is stored in the SavedModel format; inside saved_model.pb, the graph is serialized using the MetaGraphDef protobuf (which contains a GraphDef as a child)

To inspect nodes of a graph, use the following snippet

import tensorflow.compat.v2 as tf

model = tf.saved_model.load('path_to_saved_model')

graph = model.signatures['serving_default'].graph  # assumes the default serving signature
for node in graph.as_graph_def().node:
  print(node)

2.2.3. Operation (Node)

A tf.Operation is a node in a tf.Graph that takes zero or more Tensor objects as input and produces zero or more Tensor objects as output

See the list of raw_ops here

See this doc for control-flow ops implementation

2.2.4. Edge

An edge represents a dependency; it links one node's output to another node's input

2.2.5. Function

A function is a subroutine in the graph; it is itself a graph consisting of other ops
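
A hedged sketch: when one tf.function calls another, the inner one shows up as a FunctionDef in the graph's function library:

@tf.function
def inner(x):
  return x * 2

@tf.function
def outer(x):
  return inner(x) + 1

g = outer.get_concrete_function(tf.constant(1.0)).graph
# names of the subroutine graphs embedded in this graph
print([f.signature.name for f in g.as_graph_def().library.function])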

2.3. Checkpoint

There are two related concepts:

  • checkpoint: contains only values of parameters, no graph is included. It is typically only useful when source code that will use the saved parameter values is available.
  • SavedModel: includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint)

A checkpoint typically has two files (ref: stackoverflow)

  • index file: an immutable string-to-string table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of the tensor, the offset into that file, checksum, some auxiliary data, etc.
  • data file: a TensorBundle collection that saves the values of all variables.

To save/restore a checkpoint

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)

To inspect a checkpoint:

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)

reader.get_tensor(key)

3. Jax

Jax is basically a stack of compilers; it is based on the transformation and compilation of pure functional programs.

Here is a great blog comparing static/dynamic autodifferentiation libraries

3.1. lax

lax is a wrapper over XLA
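
A small illustration of the difference from jnp: lax ops map more directly onto XLA and do not do implicit type promotion (a sketch):

import jax.numpy as jnp
from jax import lax

jnp.add(1, 1.0)       # fine: jnp promotes the int to float
lax.add(1.0, 2.0)     # fine: dtypes already agree
# lax.add(1, 1.0)     # raises: lax requires matching dtypes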

3.1.1. Arithmetic

3.1.2. Control Flow

See this section
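
A minimal sketch of the structured control-flow primitives that stay traceable under jit (the function names below are made up):

import jax
import jax.numpy as jnp
from jax import lax

def abs_via_cond(x):
  # lax.cond(pred, true_fn, false_fn, operand)
  return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

def sum_via_fori(n):
  # lax.fori_loop(lower, upper, body_fn, init_val); body_fn takes (i, carry)
  return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

jax.jit(abs_via_cond)(jnp.array(-3.0))   # 3.0
sum_via_fori(5)                          # 0 + 1 + 2 + 3 + 4 = 10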

3.2. jaxpr

3.3. jnp

similar to the numpy syntax

3.3.1. packing and segments

Sometimes we want to do packing to use the TPU efficiently.

Here is an example illustrating packing from the tensor2tensor library. It concatenates sentences 'horizontally' and tracks each sentence with a segment id.

It can be combined with batching as well (batching concatenates padded sequences 'vertically').

Two input examples get combined to form an output example.
The input examples are:
{"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
{"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
The output example is:
{
                "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
  "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
      "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
              "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
  "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
      "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
}

3.3.2. paddings and masking

some padding-related code snippets

import numpy as np
import jax.numpy as jnp

# create paddings from ids
# seq      = [1, 2, 3, 0, 0]
# paddings = [0, 0, 0, 1, 1]
seq = jnp.array([1, 2, 3, 0, 0])
paddings = jnp.equal(seq, 0)

# apply paddings
masked_features = features * (1.0-paddings[:, :, jnp.newaxis])

# padding to length
seq_lengths = jnp.sum(1.0 - paddings, axis=-1)

# length to mask (True at valid positions; paddings is the negation, ~mask)
# DeviceArray([[ True,  True,  True,  True],
#              [ True,  True, False, False],
#              [ True, False, False, False]], dtype=bool)
lengths = jnp.array([4,2,1])
col_ids = jnp.arange(4)[jnp.newaxis, :]
col_ids < lengths[:, jnp.newaxis]


# Look-ahead mask or causal mask
def create_look_ahead_mask(size: int) -> np.ndarray:
  mask = np.ones([size, size])
  mask[np.tril_indices(size)] = 0.0
  return mask  # (seq_len, seq_len)


# (array([0, 1, 1]), array([0, 0, 1]))
np.tril_indices(2)

#array([[0., 1., 1.],
#       [0., 0., 1.],
#       [0., 0., 0.]])
create_look_ahead_mask(3)


def another_look_ahead_mask(size: int):
  # causal mask using col/row id comparison
  col_idx = jnp.tile(jnp.arange(size)[jnp.newaxis, :], [size, 1])
  row_idx = jnp.tile(jnp.arange(size)[:, jnp.newaxis], [1, size])
  mask = (row_idx < col_idx)
  return mask


# packed segment masking
# [B,T,1]
fst_segment_ids = segment_ids[:, :, jnp.newaxis]

# [B,1,T]
snd_segment_ids = segment_ids[:, jnp.newaxis, :]

jnp.not_equal(fst_segment_ids, snd_segment_ids)

3.4. Frameworks

3.4.1. flax

This colab covers most of the basics

3.4.1.1. linen

See this linen documentation

It is designed to do a stateless functional transform as follows, where \(v\) denotes the variables.

\[v_{out}, y = f(v_{in}, x)\]

Base class: flax.linen.Module

Param vs Variable

Module.param is a special case of Module.variable; it is immutable. See this post

p = self.param('param_name', init_fn, shape, dtype)
# is a convenient shorthand for this:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value

method init/apply

init initializes a model lazily using only the shapes of the provided arguments and avoids computing the forward pass with actual values; it returns a frozen dict of params

import flax.linen as nn
from jax import random

model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input

# init
params = model.init(key2, x)

# apply
y = model.apply(params, x)

# apply with mutable states
y, some_updated_state = model.apply(params, x, mutable=[some_state_name])
3.4.1.2. serialization

usage example

from flax import serialization

# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

# load
serialization.from_bytes(params, bytes_output)

3.4.2. optax

usage example

import jax
import optax
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

3.4.3. pax

Paxml

3.4.4. praxis

pax layer library