
0x521 Torch

1. API

1.1. Tensor

Some tensor manipulation snippets.

Initialization

## int (torch.int)
torch.IntTensor([1,2,3])
torch.randint(low, high, (size1, size2, ...))

## long long (torch.long)
torch.LongTensor([1,2,3])

## float (torch.float)
torch.FloatTensor([1,2,3])
torch.randn(2,3)


## double (torch.double)
torch.DoubleTensor([1,2,3])
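
The constructors above are the legacy typed ones. An equivalent sketch using the generic factory functions with an explicit dtype (the dtype choices here are just illustrative):

# same values, explicit dtype
torch.tensor([1, 2, 3], dtype=torch.int)
torch.tensor([1, 2, 3], dtype=torch.long)
torch.tensor([1.0, 2.0, 3.0], dtype=torch.float)
torch.tensor([1.0, 2.0, 3.0], dtype=torch.double)

# convert an existing tensor
torch.randn(2, 3).to(torch.double)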

Operation

# element-wise operation 
# add, sub, mul, div (+, -, *, /)
a = torch.randn(2,3,4)

# (2,3,4)
(a*a).shape

# mm: matrix multiplication without broadcasting
# strictly has the form: (n,m)x(m,p) = (n,p)

# shape (2,4)
torch.mm(torch.randn(2,3), torch.randn(3,4))

## bmm: batched version of mm
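
# a quick sketch (illustrative shapes): both operands must be 3-D
# with the same batch size, no broadcasting
# shape (10, 2, 4)
torch.bmm(torch.randn(10, 2, 3), torch.randn(10, 3, 4)).shape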

## matmul: behavior depends on the dimensions of both operands
# the following examples are from the pytorch docs

# when both operands have <= 2 dims, it is the ordinary vector/matrix product
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])


# matmul
# when either operand has > 2 dims, the batch dimensions are broadcast
# to a common shape and then batched matrix multiplication is applied

>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
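
As a usage note, the Python @ operator dispatches to matmul for tensors, so the last case can also be written as:

>>> (torch.randn(10, 3, 4) @ torch.randn(4, 5)).size()
torch.Size([10, 3, 5])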

Masking

Masking is important. Without proper masking, attention will not work. Recall my own training failure...

Length Masking

Translation between mask and length

# length to mask
length = torch.LongTensor([3,5,4])


# tensor([[ True,  True,  True, False, False],
#        [ True,  True,  True,  True,  True],
#        [ True,  True,  True,  True, False]])
mask = torch.arange(5).unsqueeze(0) < length.unsqueeze(1)

# to get reversed mask, use ~mask

# mask to length
length = torch.sum(mask, dim=1)
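
The same translation wrapped into small helpers (the helper names are just for illustration):

def lengths_to_mask(lengths, max_len=None):
    # lengths: (B,) LongTensor -> (B, max_len) bool mask
    if max_len is None:
        max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)

def mask_to_lengths(mask):
    # mask: (B, T) bool -> (B,) LongTensor
    return mask.sum(dim=1)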

Apply the mask

# in-place update
b = torch.randn(3,5)
mask = torch.tensor([[ True,  True,  True, False, False],
                     [ True,  True,  True,  True,  True],
                     [ True,  True,  True,  True, False]])

b.masked_fill_(~mask, float("-inf"))
#tensor([[ 1.3503, -0.2101,  0.5982,    -inf,    -inf],
#        [ 0.6641, -0.8612,  0.2389, -1.1343,  1.1640],
#        [ 1.1004, -1.2243, -1.1901, -0.5063,    -inf]])
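
The usual reason for the -inf fill is a masked softmax: after softmax, the masked positions get zero weight. A minimal sketch with illustrative shapes (scores stands in for attention logits):

scores = torch.randn(3, 5)
scores = scores.masked_fill(~mask, float("-inf"))   # out-of-place variant
weights = torch.softmax(scores, dim=-1)             # masked positions -> 0.0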

Triangular Masking

For the transformer decoder (causal self-attention)

>>> length=3
>>> torch.tril(torch.ones((length, length))).bool()
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
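
For decoder self-attention, this causal mask is usually combined with the padding mask. A sketch using the lengths_to_mask helper sketched above (shapes are illustrative; torch.nn.Transformer.generate_square_subsequent_mask produces a float variant of the causal part):

causal = torch.tril(torch.ones(5, 5)).bool()        # (T, T)
pad = lengths_to_mask(torch.LongTensor([3, 5, 4]))  # (B, T)
combined = causal.unsqueeze(0) & pad.unsqueeze(1)   # (B, T, T)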

Padding

To pad, see the following code

# padding
a = [torch.tensor([1,2,3]), torch.tensor([3,4])]

# tensor([[ 1,  2,  3],
#         [ 3,  4,  0]])
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)

# for more complicated cases, use the following snippet
def pad_list(xs, pad_value, max_len=-1):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
        max_len (int): Target length; -1 (default) means the longest tensor in xs.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)

    if max_len == -1:
        max_len = max(x.size(0) for x in xs)

    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad

To unpad, first recover each sequence's length from the padding value

# get length
a = torch.tensor([[ 1,  2,  3],
                  [ 3,  4,  0]])

mask = (a!=0)
length = torch.sum(mask, dim=1)
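
Then slice each row back by its length (continuing the example above):

# [tensor([1, 2, 3]), tensor([3, 4])]
unpadded = [row[:l] for row, l in zip(a, length)]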

1.2. Modules

C binding implementation is here

Source code is available here

It looks like each module object manages the following members:

  • _modules: submodules
  • _parameters: tensors with gradients; these are what optim updates
  • _buffers: tensors without gradients; they are not updated by optim but are stored in state_dict. Examples are the running mean and var in batchnorm

state_dict retrieves both parameters and buffers
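
A minimal sketch illustrating the three dicts and what ends up in state_dict (the toy module is hypothetical):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)                   # -> _modules
        self.scale = torch.nn.Parameter(torch.ones(1))        # -> _parameters
        self.register_buffer("running_mean", torch.zeros(2))  # -> _buffers

m = Toy()
print(list(m._modules))     # ['linear']
print(list(m._parameters))  # ['scale']
print(list(m._buffers))     # ['running_mean']
print(list(m.state_dict())) # includes 'scale', 'running_mean', 'linear.weight', 'linear.bias'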

When we assign a new member in __init__, the registration happens automatically because __setattr__ is overridden:

    # a simplified version of setattr from pytorch
    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

        # register a parameter
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            self.register_parameter(name, value)
        else:

            # register a module
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                modules[name] = value

These members are later resolved in __getattr__ by checking _parameters, _buffers, and _modules explicitly.
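
A correspondingly simplified sketch of the lookup side (the real __getattr__ checks each of those dicts in turn):

    # a simplified version of getattr, mirroring the snippet above
    def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
        for store in ('_parameters', '_buffers', '_modules'):
            d = self.__dict__.get(store)
            if d is not None and name in d:
                return d[name]
        raise AttributeError(name)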

2. Trace

TorchDynamo captures the graph structure (torch.fx.Graph) through dynamic Python bytecode transformation.
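
A minimal sketch of looking at that graph, assuming PyTorch 2.x: a custom torch.compile backend receives the captured fx.GraphModule before any compilation happens.

import torch

def inspect_backend(gm, example_inputs):
    gm.graph.print_tabular()   # the fx graph captured by TorchDynamo
    return gm.forward          # run the captured graph as-is

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.relu(x) + 1.0

f(torch.randn(4))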

3. Transform

See fx doc
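
A minimal fx sketch: symbolically trace a function, then rewrite one op in the graph (the relu-to-gelu replacement is arbitrary, just to show the mechanics):

import torch
import torch.fx

def f(x):
    return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(f)

for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is torch.relu:
        node.target = torch.nn.functional.gelu   # swap the op
gm.recompile()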

4. Compile

TorchInductor compiles the graph to machine code by leveraging Triton and OpenMP.

See this thread
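
The same torch.compile entry point uses TorchInductor by default; passing the backend explicitly is equivalent (a sketch, PyTorch 2.x):

compiled = torch.compile(torch.nn.Linear(4, 4), backend="inductor")
compiled(torch.randn(2, 4))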

4.1. Lower

There are two IRs introduced with PyTorch 2.0:

  • ATen IR: higher-level, functional operators
  • Prims IR: lower-level primitive operators
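
A sketch of inspecting the ATen-level graph, assuming PyTorch 2.1+ where torch.export is available (lowering further toward Prims IR goes through decompositions):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x)

ep = torch.export.export(M(), (torch.randn(3),))
print(ep.graph)   # nodes call torch.ops.aten.* operators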

4.2. Compile

5. Distribution

6. Serialization

6.1. ONNX export

See doc
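
A minimal export sketch (the model and file name are illustrative):

import torch

model = torch.nn.Linear(4, 2)
dummy = torch.randn(1, 4)
torch.onnx.export(model, dummy, "linear.onnx",
                  input_names=["x"], output_names=["y"])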

6.2. HLO export

See this tutorial

7. Frameworks

7.2. Transformers