0x521 Torch
1. API
1.1. Tensor
Some tensor manipulation snippets
Initialization
## int (torch.int)
torch.IntTensor([1,2,3])
torch.randint(low, high, (size1, size2, ...))
## long long (torch.long)
torch.LongTensor([1,2,3])
## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)
## double (torch.double)
torch.DoubleTensor([1,2,3])
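Alternatively, a small sketch using the factory functions with an explicit dtype (equivalent to the legacy constructors above):
# factory-style construction with an explicit dtype
torch.tensor([1, 2, 3], dtype=torch.long)
torch.zeros(2, 3, dtype=torch.double)
torch.full((2, 3), 7, dtype=torch.int)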
Operation
# element-wise operation
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)
# (2,3,4)
(a*a).shape
# mm: matrix multiplication without broadcasting
# strictly has the form: (n,m)x(m,p) = (n,p)
# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))
## bmm: batched version of mm
## matmul: behavior depends on the operands' dims
# the following examples are from the pytorch docs
# when both operands have dim <= 2, it is the normal vector/matrix product
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
# matmul
# when either operand has more than 2 dims, the batch dims are broadcast to a common shape and a batched matrix multiplication is applied
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
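The bmm mentioned above requires both operands to be 3-D with the same batch size; a minimal sketch:
>>> # batched matrix x batched matrix with torch.bmm (no broadcasting)
>>> torch.bmm(torch.randn(10, 3, 4), torch.randn(10, 4, 5)).size()
torch.Size([10, 3, 5])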

Masking
Masking is important. Without proper masking, attention will not work. Recall my own training failure...
Length Masking
Translation between mask and length
# length to mask
length = torch.LongTensor([3,5,4])
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True,  True,  True, False]])
mask = torch.arange(5).unsqueeze(0) < length.unsqueeze(1)
# to get reversed mask, use ~mask
# mask to length
length = torch.sum(mask, dim=1)
Apply mask
# in-place update
import numpy as np
b = torch.randn(3, 5)
mask = torch.tensor([[ True,  True,  True, False, False],
                     [ True,  True,  True,  True,  True],
                     [ True,  True,  True,  True, False]])
b.masked_fill_(~mask, -np.inf)
# tensor([[ 1.3503, -0.2101,  0.5982,    -inf,    -inf],
#         [ 0.6641, -0.8612,  0.2389, -1.1343,  1.1640],
#         [ 1.1004, -1.2243, -1.1901, -0.5063,    -inf]])
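A hedged sketch of why this matters for attention (the score shapes here are illustrative): fill the scores with -inf before softmax so padded positions get zero weight.
# reusing `mask` from the snippet above
scores = torch.randn(3, 5)                          # e.g. one query per row against 5 keys
scores = scores.masked_fill(~mask, float('-inf'))
attn = torch.softmax(scores, dim=-1)                # padded positions contribute 0 probability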
Triangular Masking
For the transformer decoder (causal self-attention)
>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
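A minimal sketch (assumption: masking raw attention scores directly) of applying the causal mask with masked_fill:
scores = torch.randn(3, 3)                           # (tgt_len, src_len) attention scores
causal = torch.tril(torch.ones(3, 3)).bool()
scores = scores.masked_fill(~causal, float('-inf'))  # block attention to future positions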
Padding
To pad a list of variable-length tensors, see the following code
# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]
# tensor([[ 1, 2, 3],
#         [ 3, 4, 0]])
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
# for more complicated cases, use the following snippet
def pad_list(xs, pad_value, max_len=-1):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
        max_len (int): Target length; -1 means pad to the longest tensor.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    if max_len == -1:
        max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad
To unpad
# get length (assuming 0 is the pad value)
a = torch.tensor([[1, 2, 3],
                  [3, 4, 0]])
mask = (a != 0)
length = torch.sum(mask, dim=1)
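A small sketch (assumption: the pad value 0 never appears in the real data) to recover the original list of tensors from the lengths:
unpadded = [row[:l] for row, l in zip(a, length)]
# [tensor([1, 2, 3]), tensor([3, 4])]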
1.2. Modules
C binding implementation is here
Source code is available here
It looks like each module object manages the following members:
_modules
: submodules
_parameters
: tensors with gradients; these are the ones to be optimized by optim
_buffers
: tensors without gradients; they will not be updated by optim but will be stored in state_dict. Examples are the running mean and var in batchnorm
state_dict retrieves both parameters and buffers.
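A minimal sketch (the module name and shapes are illustrative) showing that state_dict contains both the parameter and the buffer:
import torch
import torch.nn as nn

class Norm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))             # registered in _parameters
        self.register_buffer("running_mean", torch.zeros(dim))  # registered in _buffers

m = Norm(3)
print(list(m.state_dict().keys()))  # ['weight', 'running_mean']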
When we assign a new member in __init__, the registration happens automatically because Module overrides __setattr__:
# a simplified version of __setattr__ from pytorch
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # register a parameter
    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        self.register_parameter(name, value)
    else:
        # register a module
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            modules[name] = value
Those members are later resolved in __getattr__ by checking against _modules, _parameters, and _buffers explicitly.
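A rough, hedged sketch of that lookup (simplified, not the exact pytorch source):
# a simplified sketch of Module.__getattr__ (not the exact pytorch source)
def __getattr__(self, name):
    for store in ('_parameters', '_buffers', '_modules'):
        d = self.__dict__.get(store)
        if d is not None and name in d:
            return d[name]
    raise AttributeError(name)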
2. Trace
TorchDynamo captures the graph structure (fx.Graph) via dynamic Python bytecode transformation.
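A small sketch (assuming PyTorch >= 2.0; print_graph_backend is a custom backend defined here, not a built-in) to peek at the graph TorchDynamo captures:
import torch

def print_graph_backend(gm, example_inputs):
    # gm is the fx.GraphModule captured by TorchDynamo
    gm.graph.print_tabular()
    return gm.forward  # fall back to eager execution

@torch.compile(backend=print_graph_backend)
def f(x):
    return torch.sin(x) + 1

f(torch.randn(4))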
3. Transform
See fx doc
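A minimal sketch of capturing and transforming a graph with fx (the traced function is illustrative):
import torch
import torch.fx as fx

def f(x):
    return torch.relu(x) + 1

gm = fx.symbolic_trace(f)   # capture an fx.GraphModule
print(gm.graph)             # nodes can be rewritten here before recompiling
gm.recompile()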
4. Compile
TorchInductor compiles the graph to machine code by leveraging Triton and OpenMP
See this thread
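End to end, a minimal sketch (PyTorch >= 2.0; the model here is illustrative) using the default "inductor" backend:
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model)    # default backend is "inductor"
out = compiled(torch.randn(2, 8))  # first call triggers trace + compile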
4.1. Lower
There are two IRs introduced with PyTorch 2.0 (see the sketch after this list):
- ATen IR: higher-level, functional ops
- Prims IR: lower-level primitive ops
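A hedged sketch, assuming the experimental proxy_tensor tracer (make_fx) is available; the API may move between versions. It traces a function down to aten-level ops:
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.silu(x)

gm = make_fx(f)(torch.randn(3))  # graph nodes are aten-level ops
print(gm.graph)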
4.2. Compile
5. Distribution
6. Serialization
6.1. ONNX export
See doc
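A minimal export sketch (the model, file name, and opset below are illustrative, not from the doc):
import torch

model = torch.nn.Linear(4, 2)
dummy = torch.randn(1, 4)
torch.onnx.export(model, dummy, "linear.onnx", opset_version=17)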
6.2. HLO export
See this tutorial