0x424 Interface
Mostly my notes on the high-level interfaces of several frameworks.
The character notation for common tensor dimensions:
- B: batch size
- T: time step
- H: hidden vector
- A: attention size
- C: channel
- L: layer size
1. Pytorch
1.1. Tensor
1.1.1. Interface
1.1.1.1. Initialization
## int (torch.int)
torch.IntTensor([1,2,3])
torch.randint(low, high, (size1, size2, ...))
## long long (torch.long)
torch.LongTensor([1,2,3])
## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)
## double (torch.double)
torch.DoubleTensor([1,2,3])
1.1.1.2. Operation
# element-wise operation
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)
# (2,3,4)
(a*a).shape
# mm: matrix multiplication without broadcasting
# strictly has the form: (n,m)x(m,p) = (n,p)
# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))
## bmm: batched version of mm
## matmul: lots of cases...
# the following examples are from the pytorch docs
# when both tensors have <= 2 dims, it is the normal matmul
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# matmul
# when either operand has more than 2 dims, the batch dimensions are broadcast to the same shape and a batched matrix multiplication is applied
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

1.1.1.3. Masking
Masking is important. Without proper masking, attention will not work. Recall my own training failure...
1.1.1.3.1. Length Masking
Translation between mask and length
# length to mask
length = torch.LongTensor([3,5,4])
# tensor([[ True, True, True, False, False],
# [ True, True, True, True, True],
# [ True, True, True, True, False]])
mask = torch.arange(5).unsqueeze(0) < length.unsqueeze(1)
# to get reversed mask, use ~mask
# mask to length
length = torch.sum(mask, dim=1)
Apply the mask
# inplace updates
b = torch.randn(3,5)
mask = torch.tensor([[ True, True, True, False, False],
                     [ True, True, True, True, True],
                     [ True, True, True, True, False]])
b.masked_fill_(~mask, float('-inf'))
#tensor([[ 1.3503, -0.2101, 0.5982, -inf, -inf],
# [ 0.6641, -0.8612, 0.2389, -1.1343, 1.1640],
# [ 1.1004, -1.2243, -1.1901, -0.5063, -inf]])
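A quick sketch of why the -inf fill matters for attention (the shapes and names below are hypothetical, not from the notes above): positions filled with -inf get exactly zero weight after softmax, so padded steps cannot receive attention.
import torch
import torch.nn.functional as F

# hypothetical (B, T) attention logits and a length mask for lengths [3, 5, 4]
scores = torch.randn(3, 5)
mask = torch.arange(5).unsqueeze(0) < torch.LongTensor([3, 5, 4]).unsqueeze(1)
weights = F.softmax(scores.masked_fill(~mask, float('-inf')), dim=-1)
# masked positions get weight 0, each row still sums to 1
# tensor([1., 1., 1.])
weights.sum(dim=-1)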
1.1.1.4. Triangular Masking
Used for the transformer decoder (causal self-attention).
>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
[ True, True, False],
[ True, True, True]])
1.1.1.5. Padding
To pad, see the following code
# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]
# tensor([[ 1, 2, 3],
# [ 3, 4, 0]])
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
# for more complicated cases, use the following snippet
def pad_list(xs, pad_value, max_len=-1):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    if max_len == -1:
        max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad
To unpad (recover lengths from the padded tensor)
# get length
a = torch.tensor([[ 1, 2, 3],
                  [ 3, 4, 0]])
mask = (a!=0)
length = torch.sum(mask, dim=1)
1.2. Modules
C binding implementation is here
1.2.1. Implementation
Source code is available here
It looks like each module object manages the following members:
_modules
: submodules
_parameters
: tensors with gradients; these are the ones updated by optim
_buffers
: tensors without gradients; they are not updated by optim but are stored in state_dict. Examples are the running mean and var in batchnorm
state_dict retrieves both parameters and buffers.
When we assign a new member in __init__, the registration is done automatically by overriding __setattr__:
# a simplified version of __setattr__ from pytorch
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # register a parameter
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        self.register_parameter(name, value)
    else:
        # register a module
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            modules[name] = value
Those members are later resolved by __getattr__, which checks _modules, _parameters, and _buffers explicitly.
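As a small illustration of the registration above (the Toy module below is my own example, not from the PyTorch source), parameters, buffers, and submodules assigned in __init__ all show up in state_dict:
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)                            # goes into _modules
        self.scale = nn.Parameter(torch.ones(1))               # goes into _parameters
        self.register_buffer('running_mean', torch.zeros(2))   # goes into _buffers

m = Toy()
# contains 'scale', 'running_mean', 'proj.weight', 'proj.bias'
list(m.state_dict().keys())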
1.2.2. LSTM
interface: nn.LSTM(input_size, hidden_size, num_layers, batch_first=False, bidirectional=False)
input: input, (h0, c0)
- input: (B, T, H) when batch_first
- (h0, c0): (Bidirection*L, B, H)
output: output, (hn, cn)
- output: (B, T, Bidirection*H)
- (hn, cn): (Bidirection*L, B, H)
rnn = nn.LSTM(10, 20, 2)                  # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)             # (T, B, input_size) since batch_first=False
h0 = torch.randn(2, 3, 20)                # (num_layers, B, hidden_size)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
1.2.3. CNN
interface: nn.Conv1d(in_channels, out_channels, kernel_size)
input: (B,Ci,Ti)
output: (B,Co,To)
>>> m = nn.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 50)
>>> output = m(input)   # (20, 33, 24): To = floor((50 - 3) / 2) + 1
1.3. Loss
1.4. Optimization
1.5. Distribution
2. Tensorflow
2.1. Tensor
tf.Tensor is an immutable tensor. A number of specialized tensors are available: see tf.Variable, tf.constant, tf.placeholder, and tf.RaggedTensor.
tf.Variable is a mutable tensor; it requires an initial value and maintains state during execution.
v = tf.Variable(1.)
v.assign(2.)
v.assign_add(0.5)
tf.Module is a named container for tf.Variables; it exposes variables, trainable_variables, and submodules.
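A minimal sketch of a tf.Module subclass (the Dense class below is my own example): variables assigned as attributes are tracked automatically and show up in trainable_variables.
import tensorflow as tf

class Dense(tf.Module):
    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

layer = Dense(3, 2)
# both w and b are tracked
len(layer.trainable_variables)  # 2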
2.2. Graph
A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
2.2.1. Tensorflow 1
A tf.Session stores the state of a graph (i.e. the values of the variables). A session may own resources, such as tf.Variable, tf.queue.QueueBase, and tf.compat.v1.ReaderBase.
In TF1, the graph is serialized using the GraphDef protobuf (here is the doc) and saved using tf.compat.v1.train.Saver; it is serialized inside the .meta file.
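A minimal TF1-style sketch of the above (the variable and the path are made up): Saver writes the variable values plus a .meta file holding the serialized graph.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.Variable(1.0, name='x')
saver = tf.train.Saver()   # i.e. tf.compat.v1.train.Saver
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # writes model.ckpt.index / model.ckpt.data-* and model.ckpt.meta (the serialized graph)
    saver.save(sess, '/tmp/model.ckpt')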
2.2.2. Tensorflow 2
2.2.2.1. tf.Function
A tf.Graph is not built by default in eager mode; TF2 builds a graph using the tf.function decorator, see the doc here.
tf.function works in the following steps:
- tf.function wraps a Python function, returning a GenericFunction object.
- A Function manages a cache of ConcreteFunctions and picks the right one for your inputs.
- A ConcreteFunction wraps a tf.Graph.
Tracing creates a tf.Graph and wraps it in a ConcreteFunction, also known as a trace.
@tf.function
def f(x):
    return x + 1

isinstance(f.get_concrete_function(1).graph, tf.Graph)  # True
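To see the ConcreteFunction cache in action (a small sketch, the function and variable names are mine): asking for concrete functions with different input signatures yields different traces.
import tensorflow as tf

@tf.function
def g(x):
    return x + 1

c_int = g.get_concrete_function(tf.TensorSpec([], tf.int32))
c_float = g.get_concrete_function(tf.TensorSpec([], tf.float32))
# one ConcreteFunction (trace) per input signature
c_int is c_float  # False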
In TF2, the graph is stored in the SavedModel format, in saved_model.pb; the graph is serialized using the MetaGraphDef protobuf (it contains GraphDef as a child).
To inspect nodes of a graph, use the following snippet
import tensorflow.compat.v2 as tf
model = tf.saved_model.load('path_to_saved_model')
graph = model.graph
for node in graph.as_graph_def().node:
    print(node)
2.2.3. Operation (Node)
A tf.Operation is a node in a tf.Graph that takes zero or more Tensor objects as input and produces zero or more Tensor objects as output.
See the list of raw_ops here
See this doc for control-flow ops implementation
2.2.4. Edge
An edge represents a dependency; it links one node's output to another node's input.
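A small sketch to see nodes and their incoming edges (the add function and the name 'c' are my own example), reusing the concrete-function graph from the previous section:
import tensorflow as tf

@tf.function
def add(a, b):
    return tf.add(a, b, name='c')

g = add.get_concrete_function(tf.constant(1.0), tf.constant(2.0)).graph
for op in g.get_operations():
    # each tf.Operation is a node; op.inputs are the tensors (edges) feeding it
    print(op.name, [t.name for t in op.inputs])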
2.2.5. Function
Functions are subroutines in the graph; a function is itself a graph consisting of other ops.
2.3. Checkpoint
There are two related concepts:
checkpoint
: contains only the values of parameters; no graph is included. It is typically only useful when the source code that will use the saved parameter values is available.
SavedModel
: includes a serialized description of the computation defined by the model, in addition to the parameter values (checkpoint).
A checkpoint typically has two files (ref: stackoverflow):
- index file: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of the tensor, the offset into that file, checksum, some auxiliary data, etc.
- data file: a TensorBundle collection that saves the values of all variables.
To save/restore a checkpoint
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
To inspect a checkpoint:
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
reader.get_tensor(key)
3. Jax
JAX is basically a stack of compilers; it is based on the transformation and compilation of functionally pure programs.
Here is a great blog comparing static/dynamic autodifferentiation libraries.
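A minimal sketch of that idea (the function and names are mine): transformations like grad and jit compose over pure functions.
import jax
import jax.numpy as jnp

def loss(w, x):
    # a pure function of its inputs
    return jnp.sum((x @ w) ** 2)

grad_fn = jax.jit(jax.grad(loss))   # transform, then compile
w = jnp.ones((3, 2))
x = jnp.ones((4, 3))
grad_fn(w, x).shape                 # (3, 2)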
3.1. lax
lax is a wrapper over XLA
3.1.1. Arithmetic
3.1.2. Control Flow
See this section
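A small sketch of one of the lax control-flow primitives (the cumulative-sum example is my own): lax.scan threads a carry through a sequence.
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    carry = carry + x
    return carry, carry          # (new carry, per-step output)

total, cumsum = lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
# total  -> 10.0
# cumsum -> [ 1.  3.  6. 10.]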
3.2. jaxpr
3.3. jnp
similar to the numpy syntax
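A tiny sketch of the similarity, and of the main difference (immutability); the arrays here are my own example.
import numpy as np
import jax.numpy as jnp

a = np.arange(4)
b = jnp.arange(4)                 # same API surface as numpy
np.sum(a * a), jnp.sum(b * b)     # both give 14
# jnp arrays are immutable: use a functional update instead of b[0] = 10
b = b.at[0].set(10)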
3.3.1. packing and segments
Sometimes we want to do packing to use the TPU efficiently.
Here is an example illustrating packing from the tensor2tensor library: it concatenates sentences 'horizontally' and tracks each sentence with a segment id.
It can be combined with batching as well (batching concatenates padded sequences 'vertically').
Two input examples get combined to form an output example.
The input examples are:
{"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
{"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
The output example is:
{
"inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
"inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
"inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
"targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
"targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
"targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
}
3.3.2. paddings and masking
Some padding-related code snippets
# create padding from ids
# seq = [1,2,3,0,0]
# padding = [0, 0, 0, 1, 1]
jnp.equal(seq, 0)
# apply paddings
masked_features = features * (1.0-paddings[:, :, jnp.newaxis])
# padding to length
seq_lengths = jnp.sum(1.0 - paddings, axis=-1)
# length to mask (True marks valid positions; paddings would be the complement)
# DeviceArray([[ True, True, True, True],
#              [ True, True, False, False],
#              [ True, False, False, False]], dtype=bool)
lengths = jnp.array([4,2,1])
col_ids = jnp.arange(4)[jnp.newaxis, :]
col_ids < lengths[:, jnp.newaxis]
# Look-ahead mask or causal mask (1/True marks future positions to be masked out)
def create_look_ahead_mask(size: int) -> np.ndarray:
    mask = np.ones([size, size])
    mask[np.tril_indices(size)] = 0.0
    return mask  # (seq_len, seq_len)

# (array([0, 1, 1]), array([0, 0, 1]))
np.tril_indices(2)

# array([[0., 1., 1.],
#        [0., 0., 1.],
#        [0., 0., 0.]])
create_look_ahead_mask(3)

def another_look_ahead_mask(size: int):
    # causal mask using col/row id comparison
    col_idx = jnp.tile(jnp.arange(size)[jnp.newaxis, :], [size, 1])
    row_idx = jnp.tile(jnp.arange(size)[:, jnp.newaxis], [1, size])
    mask = (row_idx < col_idx)
    return mask
# packed segment masking
# [B,T,1]
fst_segment_ids = segment_ids[:, :, jnp.newaxis]
# [B,1,T]
snd_segment_ids = segment_ids[:, jnp.newaxis, :]
jnp.not_equal(fst_segment_ids, snd_segment_ids)
3.4. Frameworks
3.4.1. flax
This colab covers most of the basics
3.4.1.1. linen
See this linen documentation
It is designed as a stateless functional transform: applying a module takes the variable collections \(v\) and the inputs, and returns the outputs together with the updated variables.
Modules are defined by subclassing flax.linen.Module.
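A minimal sketch of defining a module by subclassing (the MLP below is my own example, following the standard @nn.compact pattern):
import jax
from flax import linen as nn

class MLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

model = MLP(features=8)
params = model.init(jax.random.PRNGKey(0), jax.numpy.ones((4, 3)))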
Param vs Variable
Module.param is a special case of Module.variable: it is immutable. See this post.
p = self.param('param_name', init_fn, shape, dtype)
# is a convenient shorthand for this:
p = self.variable('params', 'param_name', lambda s, d: init_fn(self.make_rng('params'), s, d), shape, dtype).value
method init/apply
init initializes a model lazily using only the shapes of the provided arguments and avoids computing the forward pass with actual values; it returns a FrozenDict of params.
from jax import random
from flax import linen as nn

model = nn.Dense(features=5)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # Dummy input
# init
params = model.init(key2, x)
# apply
y = model.apply(params, x)
# apply with mutable states (some_state_name is a placeholder collection name)
y, some_updated_state = model.apply(params, x, mutable=[some_state_name])
3.4.1.2. serialization
usage example
from flax import serialization
# dump
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
# load
serialization.from_bytes(params, bytes_output)
3.4.2. optax
usage example
import jax
import optax

tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)