# 0x424 Interface

Mostly my notes on the high-level interfaces of several frameworks.

The character notations for common tensor dimensions are:

• B: batch size
• T: time step
• H: hidden size
• A: attention size
• C: channel
• L: number of layers

## 1. PyTorch

### 1.1. Tensor

#### 1.1.1. Interface

##### 1.1.1.1. Initialization
```python
import torch

## int (torch.int)
torch.IntTensor([1,2,3])
# randint defaults to int64, so pass dtype to get torch.int
torch.randint(0, 10, (2, 3), dtype=torch.int)

## long long (torch.long)
torch.LongTensor([1,2,3])

## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)

## double (torch.double)
torch.DoubleTensor([1,2,3])
```
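To double-check which dtype each constructor above produces, here is a small sketch. It also uses `torch.tensor` with an explicit `dtype`, which I believe is the more common modern spelling than the typed constructors; note that `torch.randint` returns int64 unless a dtype is passed.

```python
import torch

# typed constructor vs. torch.tensor with an explicit dtype
a = torch.IntTensor([1, 2, 3])
b = torch.tensor([1, 2, 3], dtype=torch.int)
assert a.dtype == b.dtype == torch.int32

# torch.long is 64-bit int, torch.float 32-bit float, torch.double 64-bit float
assert torch.LongTensor([1]).dtype == torch.int64
assert torch.FloatTensor([1]).dtype == torch.float32
assert torch.DoubleTensor([1]).dtype == torch.float64

# randint defaults to int64 unless a dtype is given
assert torch.randint(0, 10, (2, 3)).dtype == torch.int64
```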

##### 1.1.1.2. Operation
```python
import torch

# element-wise operation
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)

# (2,3,4)
(a*a).shape

# mm: matrix multiplication without broadcasting
# strictly has the form: (n,m)x(m,p) = (n,p)

# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))

# bmm: batched version of mm, also without broadcasting
# strictly has the form: (b,n,m)x(b,m,p) = (b,n,p)

# shape (5,2,4)
torch.bmm(torch.randn(5,2,3), torch.randn(5,3,4))
```

matmul has lots of cases; the following examples are from the PyTorch docs.

```python
>>> # when both operands have <= 2 dims, it is the normal matmul
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])

>>> # when either operand has > 2 dims, the batch dims are broadcast
>>> # to the same shape and a batched matrix multiplication is applied
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
```
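The difference between matmul's broadcasting and bmm's strict shapes can be checked with a short sketch (the shapes are arbitrary choices of mine): matmul broadcasts the leading batch dims (10, 1) and (5,) to (10, 5), while bmm insists on two 3-D tensors with the same batch size.

```python
import torch

a = torch.randn(10, 1, 3, 4)
b = torch.randn(5, 4, 2)

# batch dims (10, 1) and (5,) broadcast to (10, 5); matrices are (3,4)x(4,2)
c = torch.matmul(a, b)
assert c.shape == (10, 5, 3, 2)

# bmm: same batch size required, no broadcasting
x = torch.randn(10, 3, 4)
y = torch.randn(10, 4, 5)
assert torch.bmm(x, y).shape == (10, 3, 5)

try:
    torch.bmm(x, torch.randn(1, 4, 5))  # batch 10 vs 1 is rejected
    raised = False
except RuntimeError:
    raised = True
assert raised
```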


Masking is important: without proper masking, attention also attends to padding (or future) positions and training will not work. Recall my own training failure...

```python
import torch

# length to mask
length = torch.LongTensor([3,5,4])
mask = torch.arange(length.max())[None, :] < length[:, None]

# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True,  True,  True, False]])
```



```python
import torch

# inplace updates
b = torch.randn(3,5)
mask = torch.tensor([[ True,  True,  True, False, False],
                     [ True,  True,  True,  True,  True],
                     [ True,  True,  True,  True, False]])
# fill the masked-out (False) positions with -inf in place
b.masked_fill_(~mask, float('-inf'))

# e.g.
# tensor([[ 1.3503, -0.2101,  0.5982,    -inf,    -inf],
#         [ 0.6641, -0.8612,  0.2389, -1.1343,  1.1640],
#         [ 1.1004, -1.2243, -1.1901, -0.5063,    -inf]])
```
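The reason to fill with -inf is that softmax then gives those positions exactly zero weight. A minimal sketch of a masked softmax over attention logits (random scores, one row per batch element):

```python
import torch

lengths = torch.tensor([3, 5, 4])
mask = torch.arange(5)[None, :] < lengths[:, None]  # True = real token
scores = torch.randn(3, 5)                          # (B, T) attention logits

# exp(-inf) = 0, so padded keys receive zero attention weight
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
```

One caveat: a row where every position is masked becomes all -inf, and softmax then returns NaN for that row.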


For the transformer decoder, the causal (lower-triangular) mask is:

```python
>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
```
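In a decoder with padded batches, the causal mask is usually combined with the key padding mask. A sketch under my own shape convention (B, T_query, T_key), where True means "may attend":

```python
import torch

T = 4
lengths = torch.tensor([4, 2])

# (T, T): query i may attend to key j <= i
causal = torch.tril(torch.ones(T, T)).bool()
# (B, 1, T): key j is a real (non-padded) token
key_valid = (torch.arange(T)[None, :] < lengths[:, None])[:, None, :]
# (B, T, T): both conditions must hold
mask = causal[None, :, :] & key_valid
```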


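To pad a list of variable-length tensors:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]
pad_sequence(a, batch_first=True)

# tensor([[1, 2, 3],
#         [3, 4, 0]])
```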

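For more complicated cases, use the following snippet:

```python
import torch

# name and signature are my reconstruction (hypothetical)
def pad_list(xs, pad_value=0, max_len=-1):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, *), (T_2, *), ..., (T_B, *)].
        pad_value (float): Value for padding.
        max_len (int): Padded length; -1 means the longest T_i.

    Returns:
        Tensor: Padded tensor (B, Tmax, *).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)

    if max_len == -1:
        max_len = max(x.size(0) for x in xs)

    # (B, Tmax, *) filled with pad_value, same dtype/device as xs[0]
    pad = xs[0].new_full((n_batch, max_len, *xs[0].shape[1:]), pad_value)

    for i in range(n_batch):
        n = min(xs[i].size(0), max_len)
        pad[i, :n] = xs[i][:n]

    return pad
```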


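```python
import torch

# get length
a = torch.tensor([[ 1,  2,  3],
                  [ 3,  4,  0]])
# tensor([3, 2]), assuming 0 is only used for padding
(a != 0).sum(dim=1)
```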



### 1.2. Modules

C binding implementation is here

#### 1.2.1. Implementation

Source code is available here

It looks like each module object manages the following members:

• _modules: submodules
• _parameters: tensors with gradients; these are optimized by the optimizer (torch.optim)
• _buffers: tensors without gradients; they are not updated by the optimizer but are stored in state_dict. Examples are running_mean and running_var in BatchNorm

state_dict() returns both parameters and buffers.

When we assign a new member in __init__, the module registers it automatically by overriding __setattr__:

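```python
from typing import Union

from torch import Tensor
from torch.nn import Module, Parameter

# a simplified version of Module.__setattr__ from pytorch
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # register a parameter
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        self.register_parameter(name, value)
    else:
        # register a module
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            modules[name] = value
        else:
            # everything else (buffers are also handled here in the real code)
            # is stored as a plain attribute
            object.__setattr__(self, name, value)
```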


Those members are later retrieved by __getattr__, which explicitly looks them up in _parameters, _buffers and _modules.
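A quick way to see the registration in action is a toy module (Tiny is a hypothetical example class). Buffers go through `register_buffer` explicitly, since a plain tensor assignment is not registered automatically:

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)                 # -> _modules
        self.scale = nn.Parameter(torch.ones(2))       # -> _parameters
        self.register_buffer("count", torch.zeros(1))  # -> _buffers
        self.note = "plain attribute"                  # -> ordinary __dict__ entry

m = Tiny()
print(list(m.state_dict().keys()))  # parameters and buffers, incl. submodules
```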

#### 1.2.2. LSTM

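interface: nn.LSTM(input_size, hidden_size, num_layers, batch_first=False, bidirectional=False)

input: input, (h0, c0)

• input: (B,T,H) when batch_first=True, otherwise (T,B,H)
• (h0, c0): (Bidirection*L, B, H), regardless of batch_first

output: output, (hn, cn)

• output: (B,T,Bidirection*H) when batch_first=True, otherwise (T,B,Bidirection*H)
• (hn, cn): (Bidirection*L, B, H)

where Bidirection is 2 if bidirectional=True, else 1.

```python
import torch
import torch.nn as nn

# batch_first=False here, so input is (T, B, H) = (5, 3, 10)
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)  # (Bidirection*L, B, H) = (1*2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
# output: (5, 3, 20); hn, cn: (2, 3, 20)
```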
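Padded batches and sequence lengths meet the LSTM through packing, so the LSTM skips padded steps. A sketch with a bidirectional, batch_first LSTM (the sizes are my own choices):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

B, T, H_in, H = 3, 5, 10, 20
x = torch.randn(B, T, H_in)        # padded batch, (B, T, H_in)
lengths = torch.tensor([5, 3, 2])  # true lengths (CPU int64 tensor)

lstm = nn.LSTM(H_in, H, num_layers=1, batch_first=True, bidirectional=True)
# enforce_sorted=False lets lengths come in any order
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (hn, cn) = lstm(packed)
# back to (B, T, 2*H); padded steps are filled with 0
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
```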


#### 1.2.3. CNN

interface: nn.Conv1d(in_channels, out_channels, kernel_size)

input: (B,Ci,Ti)
output: (B,Co,To)

```python
>>> m = nn.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 50)
>>> output = m(input)
```
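To is determined by the output-length formula in the nn.Conv1d documentation, To = floor((Ti + 2*padding - dilation*(kernel_size-1) - 1)/stride + 1). A sketch checking it against the example above (conv1d_out_len is a hypothetical helper):

```python
import torch
import torch.nn as nn

def conv1d_out_len(t_in, kernel_size, stride=1, padding=0, dilation=1):
    # floor((t_in + 2p - d(k-1) - 1) / s + 1), integer arithmetic
    return (t_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

m = nn.Conv1d(16, 33, 3, stride=2)
out = m(torch.randn(20, 16, 50))
print(out.shape, conv1d_out_len(50, 3, stride=2))  # (50 - 2 - 1) // 2 + 1 = 24
```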


## 2. TensorFlow

### 2.1. Tensor

tf.Tensor is an immutable tensor. A number of specialized tensors are available: see tf.Variable, tf.constant, tf.placeholder (TF1 only) and tf.RaggedTensor.

tf.Variable is a mutable tensor: it requires an initial value and maintains its state across executions.

```python
import tensorflow as tf

v = tf.Variable(1.)
v.assign(2.)
```


tf.Module is a named container for tf.Variables; it exposes the properties variables, trainable_variables and submodules.
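A minimal sketch of how those three properties collect things recursively (Linear and MLP are hypothetical example classes; the bias is made non-trainable only to show the difference between variables and trainable_variables):

```python
import tensorflow as tf

class Linear(tf.Module):
    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([in_features, out_features]), name="w")
        self.b = tf.Variable(tf.zeros([out_features]), trainable=False, name="b")

    def __call__(self, x):
        return x @ self.w + self.b

class MLP(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.l1 = Linear(3, 4)
        self.l2 = Linear(4, 2)

    def __call__(self, x):
        return self.l2(tf.nn.relu(self.l1(x)))

net = MLP(name="mlp")
# variables/trainable_variables/submodules are found by walking attributes
print(len(net.variables), len(net.trainable_variables), len(net.submodules))
```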

### 2.2. Graph

A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.

#### 2.2.1. TensorFlow 1

A tf.Session stores the state of a graph (i.e. the values of the variables).

A session may own resources, such as tf.Variable, tf.queue.QueueBase, and tf.compat.v1.ReaderBase.

In TF1, the graph is serialized as a GraphDef protobuf (here is the doc) and saved using tf.compat.v1.train.Saver.

It is serialized inside the .meta file (as part of a MetaGraphDef).
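The TF1 workflow above (build a graph, run it in a session that owns the variable values, save with Saver) can be sketched with the tf.compat.v1 API that still ships with TF2. This is a minimal sketch; the variable names and the temporary path are illustrative.

```python
import os
import tempfile

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # graph construction only: nothing is computed yet
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name="x")
    v = tf.compat.v1.get_variable("v", initializer=tf.constant(1.0))
    y = x + v
    init = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()

with tf.compat.v1.Session(graph=g) as sess:
    sess.run(init)                         # variable values live in the session
    out = sess.run(y, feed_dict={x: 2.0})
    # writes model.index, model.data-*, and model.meta (the graph)
    path = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model"))
```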

#### 2.2.2. TensorFlow 2

##### 2.2.2.1. tf.function

TF2 runs eagerly by default, so no tf.Graph is built; instead, TF2 builds a graph using the tf.function decorator, see doc here.

tf.function works in the following steps:

• tf.function wraps a Python function, returning a GenericFunction object.
• A Function manages a cache of ConcreteFunctions and picks the right one for your inputs.
• A ConcreteFunction wraps a tf.Graph.

Tracing creates a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.

```python
import tensorflow as tf

@tf.function
def f(x):
    return x + 1

# True
isinstance(f.get_concrete_function(1).graph, tf.Graph)
```
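The ConcreteFunction cache can be observed with a Python side effect, which only runs while tracing. A sketch (trace_count is my own counter): a second call with the same dtype and shape reuses the cached trace, while a new dtype triggers a retrace.

```python
import tensorflow as tf

trace_count = 0

@tf.function
def add_one(x):
    global trace_count
    trace_count += 1  # Python code: executes only during tracing
    return x + 1

add_one(tf.constant(1))    # int32 scalar: traced
add_one(tf.constant(2))    # same signature: cached ConcreteFunction reused
add_one(tf.constant(1.0))  # float32 scalar: new trace
print(trace_count)
```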


In TF2, the graph is stored in the SavedModel format: in saved_model.pb, the graph is serialized as a MetaGraphDef protobuf (which contains a GraphDef as a child).
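A minimal sketch for checking that layout, assuming a recent TF2: it exports a tiny tf.Module (Adder is a hypothetical example class) and parses saved_model.pb with the SavedModel proto to reach the MetaGraphDef and its GraphDef.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2

class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def __call__(self, x):
        return x + 1.0

export_dir = os.path.join(tempfile.mkdtemp(), "adder")
tf.saved_model.save(Adder(), export_dir)

# saved_model.pb holds a SavedModel proto with a list of MetaGraphDefs
sm = saved_model_pb2.SavedModel()
with open(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
    sm.ParseFromString(f.read())
graph_def = sm.meta_graphs[0].graph_def  # the GraphDef child

loaded = tf.saved_model.load(export_dir)
```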

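To inspect the nodes of a graph, use the following snippet:

```python
import tensorflow.compat.v2 as tf

@tf.function
def f(x):
    return x + 1

# any tf.Graph works here, e.g. model.graph
graph = f.get_concrete_function(tf.constant(1)).graph
for node in graph.as_graph_def().node:
    print(node)
```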


#### 2.2.3. Operation (Node)

A tf.Operation is a node in a tf.Graph that takes zero or more Tensor objects as input and produces zero or more Tensor objects as output.

See the list of raw_ops here

See this doc for control-flow ops implementation

#### 2.2.4. Edge

An edge represents a dependency: it links one node's output to another node's input.

#### 2.2.5. Function

Functions are subroutines in the graph: each one is itself a graph consisting of other ops.

### 2.3. Checkpoint

There are two related concepts:

• checkpoint: contains only values of parameters, no graph is included. It is typically only useful when source code that will use the saved parameter values is available.
• SavedModel: includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint)

A checkpoint typically has two kinds of files (ref: stackoverflow):

• index file: an immutable string-to-string table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of the tensor, the offset into that file, a checksum, some auxiliary data, etc.
• data file: a TensorBundle collection that stores the values of all variables.
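To see the two kinds of files on disk, here is a sketch that saves a one-variable tf.Module (Net is a hypothetical example class) and lists the checkpoint directory; I expect a ckpt-1.index file, ckpt-1.data-* shard(s), plus a small "checkpoint" bookkeeping file.

```python
import os
import tempfile

import tensorflow as tf

class Net(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([2, 3]))

ckpt_dir = tempfile.mkdtemp()
ckpt = tf.train.Checkpoint(net=Net())
save_path = ckpt.save(os.path.join(ckpt_dir, "ckpt"))  # e.g. .../ckpt-1

files = sorted(os.listdir(ckpt_dir))
print(files)

# keys follow the object graph: <attribute path>/.ATTRIBUTES/VARIABLE_VALUE
reader = tf.train.load_checkpoint(save_path)
shapes = reader.get_variable_to_shape_map()
```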

To save/restore a checkpoint

```python
import tensorflow as tf

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the model object.
checkpoint.restore(save_path)
```


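To inspect a checkpoint:

```python
import tensorflow as tf

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
```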



## 3. JAX

JAX is basically a stack of compilers: it is based on the transformation and compilation of pure functional programs.

Here is a great blog comparing static/dynamic automatic differentiation libraries.
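The "transform, then compile" idea in a few lines: a sketch with a toy squared-error loss of my own, where jax.grad transforms the pure function, jax.jit compiles the result with XLA, and jax.make_jaxpr shows the intermediate representation the transformations operate on.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # pure: the output depends only on the inputs, no side effects
    return jnp.mean((x @ w - y) ** 2)

# transformations compose: differentiate w.r.t. w, then compile
grad_fn = jax.jit(jax.grad(loss))

w = jnp.ones(3)
x = jnp.eye(3)
y = jnp.zeros(3)
g = grad_fn(w, x, y)  # d/dw mean(w**2) = 2w/3

print(jax.make_jaxpr(loss)(w, x, y))
```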

### 3.1. lax

lax is a library of primitive operations, many of which are thin wrappers over XLA operations.

See this section
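A small sketch of lax in use: lax.scan expresses a loop with a carried state (here a cumulative sum, compared against jnp.cumsum). Unlike jnp, lax also does no implicit type promotion, so mixing dtypes is an error.

```python
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry  # (new carry, per-step output)
    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys

xs = jnp.arange(1.0, 5.0)
print(cumsum(xs))

# jnp promotes types, lax does not:
jnp.add(jnp.float32(1), jnp.int32(1))      # fine, promoted to float32
# lax.add(jnp.float32(1), jnp.int32(1))    # TypeError: mismatched dtypes
```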

### 3.3. jnp

jnp mirrors the NumPy API.
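The main difference from NumPy that bites in practice is that JAX arrays are immutable; in-place indexing is replaced by the functional `.at[...]` update, which returns a new array. A minimal sketch:

```python
import jax.numpy as jnp
import numpy as np

x_np = np.zeros(3)
x_np[0] = 1.0              # NumPy: in-place update

x = jnp.zeros(3)
# x[0] = 1.0               # TypeError: JAX arrays are immutable
y = x.at[0].set(1.0)       # returns an updated copy; x is unchanged
```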

#### 3.3.1. packing and segments

Sometimes we want to pack sequences to use the TPU efficiently.

Here is an example illustrating packing from the tensor2tensor library. It concatenates sentences 'horizontally' and tracks which sentence each token belongs to with a segment id.

It can be combined with batching as well (batching concatenates padded sequences 'vertically').

Two input examples get combined to form an output example. The input examples are:

```
{"inputs": [8, 7, 1, 0], "targets": [4, 1, 0]}
{"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}
```

The output example is:

```
{
  "inputs":               [8, 7, 1, 2, 3, 4, 1, 0, 0, 0],
  "inputs_segmentation":  [1, 1, 1, 2, 2, 2, 2, 0, 0, 0],
  "inputs_position":      [0, 1, 2, 0, 1, 2, 3, 0, 0, 0],
  "targets":              [4, 1, 5, 6, 1, 0, 0, 0, 0, 0],
  "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
  "targets_position":     [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
}
```
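The packed layout above can be reproduced with a small sketch. This is not the tensor2tensor implementation (which packs greedily across many examples and several features); pack_examples is a hypothetical helper that packs a fixed list of examples into one row, assuming 0 is the padding id and only appears at the end.

```python
import numpy as np

def pack_examples(seqs, length):
    """Concatenate 0-padded sequences into one row of `length` tokens.

    Returns (tokens, segmentation, position), each padded with 0.
    """
    tokens, segmentation, position = [], [], []
    for segment_id, seq in enumerate(seqs, start=1):
        seq = [t for t in seq if t != 0]    # strip trailing padding
        tokens.extend(seq)
        segmentation.extend([segment_id] * len(seq))
        position.extend(range(len(seq)))    # position restarts per segment
    if len(tokens) > length:
        raise ValueError("examples do not fit into one packed row")
    pad = [0] * (length - len(tokens))
    return (np.array(tokens + pad),
            np.array(segmentation + pad),
            np.array(position + pad))

inputs, inputs_seg, inputs_pos = pack_examples([[8, 7, 1, 0], [2, 3, 4, 1]], 10)
targets, targets_seg, targets_pos = pack_examples([[4, 1, 0], [5, 6, 1]], 10)
```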


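```python
import numpy as np
import jax.numpy as jnp

# create padding from ids
seq = jnp.array([1, 2, 3, 0, 0])
# paddings = [0., 0., 0., 1., 1.]
paddings = jnp.equal(seq, 0).astype(jnp.float32)

# 3.0
seq_lengths = jnp.sum(1.0 - paddings, axis=-1)

# lengths to mask
# DeviceArray([[ True,  True,  True,  True],
#              [ True,  True, False, False],
#              [ True, False, False, False]], dtype=bool)
lengths = jnp.array([4,2,1])
col_ids = jnp.arange(4)[jnp.newaxis, :]
col_ids < lengths[:, jnp.newaxis]

# (array([0, 1, 1]), array([0, 0, 1]))
np.tril_indices(2)

# array([[0., 1., 1.],
#        [0., 0., 1.],
#        [0., 0., 0.]])
jnp.triu(jnp.ones((3, 3)), k=1)

# causal mask using col/row id comparison
size = 3
col_idx = jnp.tile(jnp.arange(size)[jnp.newaxis, :], [size, 1])
row_idx = jnp.tile(jnp.arange(size)[:, jnp.newaxis], [1, size])
# True above the diagonal (future positions), same pattern as above
col_idx > row_idx

# segment mask for packed sequences
segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 2, 0, 0, 0]])

# [B,T,1]
fst_segment_ids = segment_ids[:, :, jnp.newaxis]

# [B,1,T]
snd_segment_ids = segment_ids[:, jnp.newaxis, :]

# [B,T,T]: True where the two tokens belong to different segments
jnp.not_equal(fst_segment_ids, snd_segment_ids)
```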
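For decoder self-attention over packed sequences, the segment comparison, the padding check and the causal mask are typically combined. A sketch with a hypothetical helper, using the opposite convention from the cell above (True means "may attend") and assuming segment id 0 marks padding:

```python
import jax.numpy as jnp

def packed_causal_mask(segment_ids):
    """[B,T] segment ids -> [B,T,T] bool mask, True where query i may attend key j."""
    T = segment_ids.shape[-1]
    same_segment = segment_ids[:, :, jnp.newaxis] == segment_ids[:, jnp.newaxis, :]
    valid = segment_ids != 0
    both_valid = valid[:, :, jnp.newaxis] & valid[:, jnp.newaxis, :]
    causal = jnp.tril(jnp.ones((T, T))).astype(bool)
    return same_segment & both_valid & causal[jnp.newaxis]

mask = packed_causal_mask(jnp.array([[1, 1, 2, 2, 0]]))
```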


### 3.4. Frameworks

#### 3.4.1. flax

This colab covers most of the basics.

##### 3.4.1.1. linen

See this linen documentation

It is designed around stateless functional transforms of the following form, where $v$ are the variables:

$$v_{out}, y = f(v_{in}, x)$$

class (flax.linen.Module)

Param vs Variable

Module.param is a special case of Module.variable (a variable in the 'params' collection). It is immutable. See this post

```python
# inside a Module method (e.g. setup or an @nn.compact __call__)
p = self.param('param_name', init_fn, shape, dtype)
# is a convenient shorthand for this:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value
```


Methods init/apply

init initializes the model and returns its variables, with the params under the 'params' key (a FrozenDict in older Flax versions). When jitted, it initializes the model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values.

```python
from jax import random
import flax.linen as nn

model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input

# init
params = model.init(key2, x)

# apply
y = model.apply(params, x)

# apply with mutable states, e.g. BatchNorm's 'batch_stats' collection
y, some_updated_state = model.apply(params, x, mutable=['batch_stats'])
```

##### 3.4.1.2. serialization

usage example

```python
from flax import serialization

# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

# load: params serves as the target structure
serialization.from_bytes(params, bytes_output)
```


#### 3.4.2. optax

usage example

import optax
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)

for i in range(101):