
0x425 Implementation

1. Primitives

1.1. ISA

1.1.1. PTX

GPU primitives; these will be documented in more detail on the architecture page.

1.2. IR

IR is commonly represented with MLIR.

MLIR (Multi-Level Intermediate Representation) defines multiple dialects, and a program is progressively converted (lowered) through them towards machine code.

Relevant links

1.2.2. XLA HLO

HLO has a limited set of orthogonal, well-defined ops (fewer than 100) and can be lowered from a TensorFlow graph. The semantics of HLO are defined on the XLA semantics page.

1.3. Domain Specific Language / Compiler

Check this repo

1.3.1. CUDA

CUDA implementations related to deep learning

How to transpose efficiently

1.3.2. XLA

XLA (Accelerated Linear Algebra) is a domain-specific compiler.

Originally (without XLA), each operation in the computation graph has a predefined, precompiled kernel to which the executor dispatches.

XLA instead compiles the computation graph into a sequence of computation kernels specialized for the model. In particular, it can fuse multiple operations into a single kernel, improving both execution speed and memory usage.
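To illustrate why fusion helps, here is a toy pure-Python sketch (not XLA code; the kernel names are hypothetical) of `relu(x * w + b)`. The unfused version dispatches three separate "kernels", each writing a full intermediate buffer; the fused version computes each element in a single pass with no intermediates.

```python
from typing import List

# Hypothetical toy "kernels" (not XLA code): each one reads a full buffer and writes a new one.
def mul_kernel(x: List[float], w: float) -> List[float]:
    return [xi * w for xi in x]

def add_kernel(x: List[float], b: float) -> List[float]:
    return [xi + b for xi in x]

def relu_kernel(x: List[float]) -> List[float]:
    return [max(xi, 0.0) for xi in x]

def unfused(x: List[float], w: float, b: float) -> List[float]:
    # Three dispatches; t1 and t2 are intermediate buffers written to and read back from memory.
    t1 = mul_kernel(x, w)
    t2 = add_kernel(t1, b)
    return relu_kernel(t2)

def fused(x: List[float], w: float, b: float) -> List[float]:
    # One "kernel": each element is loaded once and the whole mul -> add -> relu chain
    # runs on it without storing intermediates.
    return [max(xi * w + b, 0.0) for xi in x]

x = [-2.0, -0.5, 0.0, 1.5]
print(fused(x, 2.0, 1.0))  # [0.0, 0.0, 1.0, 4.0]
```

Both give the same result; the fused version saves the memory traffic and allocation of `t1` and `t2`, which is where most of the speed and memory gain of fusion comes from for elementwise chains.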


Its design resembles LLVM's, with a separation between frontend and backend.

  • The frontend of XLA produces HLO (High Level Operations, XLA's compiler IR), defined here. This IR is hardware-independent and can be optimized.

  • The backend of XLA produces target-dependent hardware code. The CPU and GPU backends rely on LLVM: they emit the LLVM IR needed to represent the XLA HLO computation, then invoke LLVM to emit native code from this LLVM IR. For example, the CPU backend implementation for XLA can be found here.

1.3.3. Triton





There are a few runtimes for TensorFlow:


The default TensorFlow runtime serves as the reference runtime. It is essentially an interpreter for graphs.

It applies graph-level (not IR-level) transformations and optimizations with Grappler, then executes the graph with predefined op kernels.

The execution path is (probably):

tf.graph -> kernels
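The "interpreter over graphs with predefined kernels" idea can be sketched in a few lines. This is a toy model, not TensorFlow's executor: the graph format (a list of nodes assumed to be in topological order) and the `KERNELS` registry are hypothetical.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical registry of "precompiled" kernels keyed by op name (stand-in for TF op kernels).
KERNELS: Dict[str, Callable[..., float]] = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "exp": math.exp,
}

# A node is (output_name, op_name, input_names); nodes are assumed topologically sorted.
Node = Tuple[str, str, List[str]]

def interpret(graph: List[Node], feeds: Dict[str, float]) -> Dict[str, float]:
    values = dict(feeds)
    for out, op, inputs in graph:
        kernel = KERNELS[op]  # dispatch: one predefined kernel per op, no cross-op fusion
        values[out] = kernel(*(values[name] for name in inputs))
    return values

graph = [
    ("t", "mul", ["x", "w"]),
    ("y", "add", ["t", "b"]),
]
print(interpret(graph, {"x": 3.0, "w": 2.0, "b": 1.0})["y"])  # 7.0
```

Every op is dispatched individually and every intermediate (`t` here) is materialized, which is exactly what a compiler such as XLA tries to avoid.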


TFRT, a newer runtime, lowers the graph into BEF (Binary Executable Format, also defined with MLIR), which is then executed by the BEF executor.

The execution path is:

tf.graph -> BEF -> kernels


With XLA, the execution path is:

tf.graph -> XLA-HLO -> LLVM IR -> kernel

2. Automatic Differentiation

2.1. Forward Reverse Accumulation

Check this video

To follow the video, we use numerator-layout notation here:

Let \(F: R^n \to R\) be the composition \(F = D \circ C \circ B \circ A\), i.e. \(F(x) = D(C(B(A(x))))\), so that


\[y = D(c), c=C(b), b=B(a), a=A(x)\]

By the chain rule, the Jacobian \(F'(x) \in R^{1 \times n}\) is

\[F'(x) = \frac{\partial y}{\partial c}\frac{\partial c}{\partial b}\frac{\partial b}{\partial a}\frac{\partial a}{\partial x}\]

Since matrix multiplication is associative, the product can be evaluated in many orders. The two typical ones are:

2.1.1. Forward Accumulation

This is push-forward computation.

At every step it computes a derivative with respect to the input \(x\), e.g. \(\frac{\partial b}{\partial x}\).

\[F'(x) = (\frac{\partial y}{\partial c}(\frac{\partial c}{\partial b}(\frac{\partial b}{\partial a}\frac{\partial a}{\partial x})))\]

It combines naturally with the Jacobian-vector product (JVP):

\[F'(x)v = (\frac{\partial y}{\partial c}(\frac{\partial c}{\partial b}(\frac{\partial b}{\partial a}(\frac{\partial a}{\partial x} v))))\]

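To obtain the full Jacobian, we repeat the JVP once per column, setting \(v\) to each standard basis vector \(e_i\); stacking all of them amounts to \(v = \frac{\partial x}{\partial x} = I\):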

\[F'(x) = (\frac{\partial y}{\partial c}(\frac{\partial c}{\partial b}(\frac{\partial b}{\partial a}(\frac{\partial a}{\partial x} \frac{\partial x}{\partial x}))))\]
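The forward bracketing can be checked numerically. The sketch below uses small hypothetical Jacobian matrices `JA`, `JB`, `JC`, `JD` (standing in for \(\frac{\partial a}{\partial x}, \frac{\partial b}{\partial a}, \frac{\partial c}{\partial b}, \frac{\partial y}{\partial c}\), with \(n = 3\)), computes a JVP innermost-first, and assembles the full Jacobian from one JVP per standard basis vector.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]

def matvec(M: Matrix, v: Vector) -> Vector:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

# Hypothetical stage Jacobians: A: R^3 -> R^2, B and C: R^2 -> R^2, D: R^2 -> R.
JA: Matrix = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0]]  # da/dx (2x3)
JB: Matrix = [[2.0, 0.0], [1.0, 1.0]]            # db/da (2x2)
JC: Matrix = [[0.5, 0.0], [0.0, 3.0]]            # dc/db (2x2)
JD: Matrix = [[1.0, -1.0]]                       # dy/dc (1x2)

def jvp(v: Vector) -> Vector:
    # Forward accumulation: push v through the chain, innermost Jacobian first.
    return matvec(JD, matvec(JC, matvec(JB, matvec(JA, v))))

def full_jacobian_forward(n: int) -> Matrix:
    # One JVP per input dimension: column i of F'(x) is F'(x) e_i.
    columns = [jvp([1.0 if j == i else 0.0 for j in range(n)]) for i in range(n)]
    return [[col[r] for col in columns] for r in range(len(columns[0]))]

print(full_jacobian_forward(3))  # [[-2.0, -7.0, 3.0]]
```

Note that building this single-row Jacobian took \(n = 3\) JVPs.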

2.1.2. Reverse Accumulation

This is pull-back computation.

At every step it computes a derivative of the output \(y\), e.g. \(\frac{\partial y}{\partial b}\).

\[F'(x) = (((\frac{\partial y}{\partial c}\frac{\partial c}{\partial b})\frac{\partial b}{\partial a})\frac{\partial a}{\partial x})\]

Similarly, it combines naturally with the vector-Jacobian product (VJP):

\[v^TF'(x) = ((((v^T \frac{\partial y}{\partial c})\frac{\partial c}{\partial b})\frac{\partial b}{\partial a})\frac{\partial a}{\partial x})\]

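To build the full Jacobian, we build it one row at a time; since \(y\) is a scalar there is only one row, obtained with \(v^T = \frac{\partial y}{\partial y} = 1\):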

\[F'(x) = ((((\frac{\partial y}{\partial y} \frac{\partial y}{\partial c})\frac{\partial c}{\partial b})\frac{\partial b}{\partial a})\frac{\partial a}{\partial x})\]

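In neural networks, where the loss is a scalar and the Jacobian is therefore a single row vector, reverse accumulation is preferred: one reverse pass yields the whole gradient, while forward accumulation needs \(n\) passes.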
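The reverse bracketing can be checked numerically as well. The sketch below uses small hypothetical Jacobians `JA`, `JB`, `JC`, `JD` (standing in for \(\frac{\partial a}{\partial x}, \frac{\partial b}{\partial a}, \frac{\partial c}{\partial b}, \frac{\partial y}{\partial c}\), with \(n = 3\)) and applies the VJP from the left, outermost Jacobian first. Since \(y\) is a scalar, a single VJP with \(v^T = 1\) gives the whole row \(F'(x)\).

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]

def vecmat(u: Vector, M: Matrix) -> Vector:
    # Row vector times matrix: (u^T M)_j = sum_i u_i * M[i][j].
    return [sum(u[i] * M[i][j] for i in range(len(M))) for j in range(len(M[0]))]

# Hypothetical stage Jacobians: A: R^3 -> R^2, B and C: R^2 -> R^2, D: R^2 -> R.
JA: Matrix = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0]]  # da/dx (2x3)
JB: Matrix = [[2.0, 0.0], [1.0, 1.0]]            # db/da (2x2)
JC: Matrix = [[0.5, 0.0], [0.0, 3.0]]            # dc/db (2x2)
JD: Matrix = [[1.0, -1.0]]                       # dy/dc (1x2)

def vjp(u: Vector) -> Vector:
    # Reverse accumulation: pull u back through the chain, outermost Jacobian first.
    return vecmat(vecmat(vecmat(vecmat(u, JD), JC), JB), JA)

# y is scalar, so one VJP seeded with dy/dy = 1 yields the entire gradient row.
print(vjp([1.0]))  # [-2.0, -7.0, 3.0]
```

One reverse pass produces all three entries, whereas forward accumulation would need one JVP per input dimension for the same row.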

2.2. autograd implementation

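Comparison of forward and reverse mode:

  • reverse mode has a memory cost, since intermediate values from the forward pass must be stored for the backward pass; this cost scales with the depth of the program
  • forward mode requires \(n\) calls (one JVP per input dimension) to build the full gradient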
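A minimal reverse-mode autograd can be sketched with a recorded graph (a "tape"): the forward pass keeps every intermediate node alive, which is the memory cost listed above, and the backward pass visits the nodes in reverse topological order, applying the chain rule. The `Var` class and `backward` function are our own toy illustration, not PyTorch's or JAX's API.

```python
import math
from typing import List, Tuple

class Var:
    """A scalar node in a toy computation graph (hypothetical, for illustration)."""

    def __init__(self, value: float, parents: Tuple[Tuple["Var", float], ...] = ()):
        self.value = value
        self.parents = parents  # pairs of (parent, local derivative d self / d parent)
        self.grad = 0.0

    def __add__(self, other: "Var") -> "Var":
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other: "Var") -> "Var":
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

    def sin(self) -> "Var":
        return Var(math.sin(self.value), ((self, math.cos(self.value)),))

def backward(output: Var) -> None:
    # Post-order DFS gives a topological order of the recorded graph.
    order: List[Var] = []
    seen = set()

    def visit(node: Var) -> None:
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0  # dy/dy
    for node in reversed(order):  # every consumer is processed before its inputs
        for parent, local in node.parents:
            parent.grad += node.grad * local  # chain rule, summed over all paths

x, w = Var(2.0), Var(3.0)
y = (x * w + x).sin()  # y = sin(x*w + x)
backward(y)
print(x.grad, w.grad)  # dy/dx = (w + 1) cos(8) = 4 cos(8), dy/dw = x cos(8) = 2 cos(8)
```

Because `x` feeds two operations, its gradient is accumulated from both paths.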

2.3. Dual Numbers

Dual numbers augment a real number \(a\) to a pair \(a + b\epsilon\), where \(a, b\) are real and \(\epsilon^2 = 0\). This gives arithmetic rules such as

\[(x+x'\epsilon) + (y + y' \epsilon) = (x+y) + (x'+y')\epsilon\]
\[(x+x'\epsilon)(y + y'\epsilon) = (xy) + (xy' + yx')\epsilon\]
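These two rules are already enough for forward-mode differentiation of polynomials. A minimal sketch (the `Dual` class and `derivative` helper are our own illustration): seeding the input with \(x' = 1\) makes the \(\epsilon\) coefficient of the result equal to \(f'(x)\).

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class Dual:
    real: float  # x
    eps: float   # x', the coefficient of epsilon

    def __add__(self, other: Union["Dual", float]) -> "Dual":
        o = other if isinstance(other, Dual) else Dual(float(other), 0.0)
        # (x + x'e) + (y + y'e) = (x + y) + (x' + y')e
        return Dual(self.real + o.real, self.eps + o.eps)

    __radd__ = __add__

    def __mul__(self, other: Union["Dual", float]) -> "Dual":
        o = other if isinstance(other, Dual) else Dual(float(other), 0.0)
        # (x + x'e)(y + y'e) = xy + (xy' + yx')e, because e^2 = 0
        return Dual(self.real * o.real, self.real * o.eps + o.real * self.eps)

    __rmul__ = __mul__

def derivative(f, x: float) -> float:
    # Seed x' = 1 and read the derivative off the epsilon part.
    return f(Dual(x, 1.0)).eps

f = lambda x: 3 * x * x + 2 * x + 1  # f'(x) = 6x + 2
print(derivative(f, 2.0))            # 14.0
```

One evaluation gives both \(f(x)\) (the real part) and \(f'(x)\) (the \(\epsilon\) part), which is the forward-mode, one-input-direction-per-pass behaviour described above.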

2.3.1. Jax implementation

JAX builds on the dual-number approach:

  • forward mode is implemented with the dual-number approach (each primal value carries a tangent)
  • reverse mode is implemented using forward mode plus transposition (paper)

3. Efficiency

Check this blog


Model (Gradient Checkpoint)

  • some of the activations from the forward pass are discarded to save memory
  • the discarded activations are recomputed from the stored checkpoints when needed during the backward pass
  • check the gif here

See here for PyTorch's implementation.
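The store-less-and-recompute trade-off can be sketched on a chain of scalar layers. In this hypothetical example (not PyTorch's `torch.utils.checkpoint`), each layer is \(f_i(h) = \tanh(w_i h)\); the forward pass keeps only the input of every `k`-th layer, and the backward pass recomputes each segment's activations from its checkpoint before back-propagating through it.

```python
import math
from typing import Dict, List, Tuple

def layer(w: float, h: float) -> float:
    # Hypothetical toy layer: tanh(w * h).
    return math.tanh(w * h)

def layer_grad(w: float, h: float) -> float:
    # d/dh tanh(w h) = w * (1 - tanh(w h)^2)
    t = math.tanh(w * h)
    return w * (1.0 - t * t)

def forward_with_checkpoints(ws: List[float], x: float, k: int) -> Tuple[float, Dict[int, float]]:
    # Store only the inputs of layers 0, k, 2k, ... (the checkpoints); the rest are discarded.
    checkpoints: Dict[int, float] = {}
    h = x
    for i, w in enumerate(ws):
        if i % k == 0:
            checkpoints[i] = h
        h = layer(w, h)
    return h, checkpoints

def backward_with_recompute(ws: List[float], checkpoints: Dict[int, float], k: int) -> float:
    # Returns d(output)/dx, processing segments from the last one to the first.
    grad = 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, len(ws))
        # Recompute the discarded inputs of layers start .. end-1 from the checkpoint.
        acts = [checkpoints[start]]
        for i in range(start, end - 1):
            acts.append(layer(ws[i], acts[-1]))
        # Back-propagate through this segment using the recomputed activations.
        for i in range(end - 1, start - 1, -1):
            grad *= layer_grad(ws[i], acts[i - start])
    return grad

ws = [0.5, -1.2, 0.8, 1.5, -0.3, 0.9, 1.1]
y, ckpts = forward_with_checkpoints(ws, 0.7, k=3)
grad = backward_with_recompute(ws, ckpts, k=3)
print(sorted(ckpts), grad)  # checkpoints at layers 0, 3, 6: 3 stored activations instead of 7
```

With \(L\) layers and segment length \(k\), this stores about \(L/k\) checkpoints plus at most \(k\) recomputed activations at a time, at the cost of roughly one extra forward pass.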

3.2. Quantization

3.2.1. Post-training Quantization

From the TensorFlow website: post-training quantization is a conversion technique that can reduce model size while also improving CPU and hardware accelerator latency, with little degradation in model accuracy.

  • simplest form: weights are converted to 8-bit precision. At inference time, they are converted back to floating point and inference is performed in floating point.
  • dynamic range quantization: in addition, activations are dynamically quantized to 8 bits and computations are done in 8-bit precision.
  • full integer quantization: everything (weights and activations) is quantized to integers. A calibration process is needed to estimate the range (min, max) of the floating-point tensors, which requires a representative dataset.
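The calibration step essentially turns an observed (min, max) into a scale and a zero point. Below is a minimal sketch of asymmetric (affine) uint8 quantization; it illustrates the general scheme and is not TensorFlow Lite's exact implementation. The calibration list is made-up data standing in for a representative dataset.

```python
from typing import List, Tuple

def calibrate(samples: List[float], num_bits: int = 8) -> Tuple[float, int]:
    # Estimate the float range from representative data; include 0.0 so it is exactly representable.
    lo, hi = min(min(samples), 0.0), max(max(samples), 0.0)
    qmax = 2 ** num_bits - 1
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = round(-lo / scale)  # the integer that represents real 0.0
    return scale, zero_point

def quantize(x: float, scale: float, zero_point: int, num_bits: int = 8) -> int:
    q = round(x / scale) + zero_point
    return max(0, min(2 ** num_bits - 1, q))  # clamp to [0, 255] for uint8

def dequantize(q: int, scale: float, zero_point: int) -> float:
    return (q - zero_point) * scale

calib = [-1.0, -0.2, 0.3, 2.5, 1.7]  # hypothetical calibration data
scale, zp = calibrate(calib)
for x in [-1.0, 0.0, 0.3, 2.5]:
    q = quantize(x, scale, zp)
    print(x, q, round(dequantize(q, scale, zp), 4))
```

Values inside the calibrated range round-trip with an error of at most about one quantization step (`scale`); values outside it are clamped, which is why a representative dataset matters.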

3.2.2. Quantization-Aware Training

The pro is higher accuracy; the cons are that it requires a training pipeline, labeled data, and hyperparameter tuning.

3.3. Pruning

Model (LOTTERY TICKET HYPOTHESIS) after training and pruning, the values of the unpruned connections are reset to their initialization, and the resulting subnetwork is retrained.
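One pruning round of this procedure can be sketched as follows, assuming magnitude pruning (keep the largest trained weights) and weights flattened into a list. Training itself is left out, and `lottery_ticket_round` is a hypothetical helper name.

```python
from typing import List, Tuple

def lottery_ticket_round(
    init_weights: List[float], trained_weights: List[float], prune_fraction: float
) -> Tuple[List[float], List[int]]:
    """Hypothetical helper: magnitude-prune trained weights, then rewind survivors to init."""
    n = len(trained_weights)
    num_keep = n - int(prune_fraction * n)
    # Rank connections by trained magnitude; the top num_keep survive.
    ranked = sorted(range(n), key=lambda i: abs(trained_weights[i]), reverse=True)
    keep = set(ranked[:num_keep])
    mask = [1 if i in keep else 0 for i in range(n)]
    # Unpruned connections are reset to their initial values; pruned ones are fixed at zero.
    rewound = [w0 if m else 0.0 for w0, m in zip(init_weights, mask)]
    return rewound, mask

init = [0.1, -0.4, 0.3, 0.05, -0.2]
trained = [0.9, -0.05, 1.2, 0.01, -0.7]
weights, mask = lottery_ticket_round(init, trained, prune_fraction=0.4)
print(mask)     # [1, 0, 1, 0, 1]
print(weights)  # [0.1, 0.0, 0.3, 0.0, -0.2]
```

The rewound `weights` are then retrained with `mask` held fixed; repeating the round gives iterative pruning.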

3.4. Distillation

3.5. Ensemble

Model (model soup) averaging the weights of multiple models finetuned with different hyperparameter configurations often improves accuracy and robustness

Greedy soups, where models are sequentially added to the soup only if they improve accuracy on held-out data, outperform uniform averaging.
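A greedy soup can be sketched as below. Assumptions: each model's weights are a flat list of floats, `toy_accuracy` is a hypothetical stand-in for held-out accuracy, and (following our reading of the paper's recipe) candidates are visited in order of decreasing individual held-out accuracy and kept if the soup's accuracy does not drop.

```python
from typing import Callable, List

Weights = List[float]

def average(models: List[Weights]) -> Weights:
    # Uniform weight averaging, parameter by parameter.
    return [sum(ws) / len(models) for ws in zip(*models)]

def greedy_soup(models: List[Weights], evaluate: Callable[[Weights], float]) -> Weights:
    ranked = sorted(models, key=evaluate, reverse=True)  # best individual model first
    ingredients = [ranked[0]]
    best = evaluate(ranked[0])
    for candidate in ranked[1:]:
        score = evaluate(average(ingredients + [candidate]))
        if score >= best:  # keep the candidate only if held-out accuracy does not drop
            ingredients.append(candidate)
            best = score
    return average(ingredients)

def toy_accuracy(w: Weights) -> float:
    # Hypothetical held-out metric for illustration: peaks when every weight equals 1.
    return -sum((wi - 1.0) ** 2 for wi in w)

models = [[0.8, 1.3], [1.2, 0.6], [3.0, -2.0]]
soup = greedy_soup(models, toy_accuracy)
print(soup, toy_accuracy(soup), toy_accuracy(average(models)))
```

Here the outlier model `[3.0, -2.0]` is rejected, so the greedy soup scores better than the uniform average of all three.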

4. Distribution

4.1. Data Parallelism

Model (ZeRO) shards parameters, gradients, and optimizer states across devices:

  • stage 1: Optimizer State Partitioning
  • stage 2: Add Gradient Partitioning
  • stage 3: Add Parameter Partitioning

According to the paper, a trillion-parameter model trained with an optimizer like Adam in 16-bit precision requires approximately 16 terabytes (TB) of memory to hold the optimizer states, gradients, and parameters. With all three partitioned (stage 3) across 1024 GPUs, this is 16 TB / 1024 ≈ 16 GB per GPU, which is well within a reasonable bound for a GPU (e.g., one with 32 GB of on-device memory).
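The arithmetic can be reproduced from the per-parameter byte counts that, to our understanding, the ZeRO paper uses for mixed-precision Adam: 2 bytes each for fp16 parameters and gradients, plus \(K = 12\) bytes of optimizer state (fp32 parameter copy, momentum, and variance), i.e. 16 bytes per parameter. The sketch below computes per-device memory for each stage, ignoring activations and temporary buffers; the function name is hypothetical.

```python
def zero_bytes_per_device(num_params: int, num_devices: int, stage: int, k: int = 12) -> int:
    """Hypothetical helper: per-device bytes for params + grads + optimizer states under ZeRO."""
    params, grads, optim = 2 * num_params, 2 * num_params, k * num_params
    if stage == 0:  # plain data parallelism: everything is replicated
        return params + grads + optim
    if stage == 1:  # partition optimizer states
        return params + grads + optim // num_devices
    if stage == 2:  # + partition gradients
        return params + (grads + optim) // num_devices
    if stage == 3:  # + partition parameters
        return (params + grads + optim) // num_devices
    raise ValueError("stage must be 0, 1, 2, or 3")

psi, n = 10**12, 1024  # one trillion parameters on 1024 GPUs
for stage in range(4):
    print(stage, zero_bytes_per_device(psi, n, stage) / 1e9, "GB")
```

Stage 0 needs the full 16 TB on every device, while stage 3 needs 16 TB / 1024 = 15.625 GB per device, the "about 16 GB" figure quoted above.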


4.2. Model Parallelism

Model (DistBelief)

4.2.1. Tensor Parallelism

Model (GShard) an annotation API extension of XLA that requires the user to annotate only a few critical tensors in the model with partitioning policies. It generates a single program (SPMD) for all shards.

  • replicate(tensor) annotates tensor to be replicated across partitions, and returns the annotated tensor.
  • split(tensor, split_dimension, num_partitions) annotates tensor to be partitioned along split_dimension, and returns the annotated tensor
  • shard(tensor, device_assignment) generalizes split() to allow partitioning multiple dimensions and specifying the placement of each partition.


It was used to implement Sparsely-Gated Mixture-of-Experts layers in which the top-2 experts are activated for each token.
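Top-2 routing for a single token can be sketched in pure Python. This is a simplified illustration, not GShard's implementation: it omits expert capacity limits and the auxiliary load-balancing loss, and it assumes the two selected gate values are renormalized to sum to 1. The experts are hypothetical scalar functions.

```python
import math
from typing import Callable, List

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top2_moe(x: float, gate_logits: List[float], experts: List[Callable[[float], float]]) -> float:
    gates = softmax(gate_logits)
    # Pick the two experts with the highest gate values; only these two are evaluated.
    first, second = sorted(range(len(gates)), key=lambda i: gates[i], reverse=True)[:2]
    g1, g2 = gates[first], gates[second]
    denom = g1 + g2  # assumption: renormalize the two selected gates
    return (g1 / denom) * experts[first](x) + (g2 / denom) * experts[second](x)

# Hypothetical experts; in a real MoE layer each one is a feed-forward network.
experts = [lambda v: v + 1.0, lambda v: 2.0 * v, lambda v: -v, lambda v: v * v]
print(top2_moe(3.0, [0.0, 2.0, -1.0, 1.0], experts))
```

With these logits, experts 1 and 3 are selected, and experts 0 and 2 do no work for this token, which is what keeps MoE compute roughly constant as the number of experts grows.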


Model (GSPMD) GSPMD generalizes the backend of GShard. It introduces a new API, mesh_split(tensor, device_mesh, dims_mapping).
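To make `mesh_split` concrete, here is a sketch that computes which slice of a tensor each device would hold. Assumptions: `device_mesh` is a 2D list of device ids, `dims_mapping[i]` is the mesh dimension that tensor dimension `i` is split along (or `-1` if it is not split, so the tensor is replicated along unused mesh dimensions), and sizes divide evenly. It models only the resulting layout; `mesh_split_layout` is a hypothetical helper, not XLA's API.

```python
from typing import Dict, List, Tuple

def mesh_split_layout(
    shape: Tuple[int, ...], device_mesh: List[List[int]], dims_mapping: List[int]
) -> Dict[int, List[Tuple[int, int]]]:
    """Hypothetical helper: for each device id, a [start, stop) range per tensor dimension."""
    mesh_shape = (len(device_mesh), len(device_mesh[0]))
    layout: Dict[int, List[Tuple[int, int]]] = {}
    for r, row in enumerate(device_mesh):
        for c, device in enumerate(row):
            coord = (r, c)  # this device's position in the 2D mesh
            ranges = []
            for size, mesh_dim in zip(shape, dims_mapping):
                if mesh_dim == -1:
                    ranges.append((0, size))  # not split: every device holds the full dimension
                else:
                    chunk = size // mesh_shape[mesh_dim]
                    idx = coord[mesh_dim]
                    ranges.append((idx * chunk, (idx + 1) * chunk))
            layout[device] = ranges
    return layout

# Split dim 0 across mesh rows; dim 1 is not split, so devices in the same row hold identical shards.
print(mesh_split_layout((4, 6), [[0, 1], [2, 3]], [0, -1]))
# {0: [(0, 2), (0, 6)], 1: [(0, 2), (0, 6)], 2: [(2, 4), (0, 6)], 3: [(2, 4), (0, 6)]}
```

Mapping both tensor dimensions to mesh dimensions (`[0, 1]`) would instead give each of the four devices a distinct 2 x 3 tile.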


4.2.2. Pipeline Parallelism

Model (GPipe)


GPipe consists of the following two techniques (reference: torchgpipe):

  • pipeline parallelism: GPipe splits a model into multiple partitions and places each partition on a different device to use more total memory capacity. It also splits a mini-batch into multiple micro-batches so that the partitions can work in parallel as much as possible.
  • checkpointing (re-materialization): checkpointing is applied to each partition to minimize the overall memory consumption of the model. During forward propagation, only the tensors at the boundaries between partitions are remembered. All other intermediate tensors are discarded and recomputed during backpropagation when necessary.
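How micro-batches keep the partitions busy can be sketched by listing which (partition, micro-batch) pairs run at each forward clock tick. This assumes a simple fill-and-drain schedule in which partition `j` processes micro-batch `i` at tick `i + j`; it ignores the backward pass and recomputation, and the function name is hypothetical.

```python
from typing import List, Tuple

def gpipe_forward_schedule(num_partitions: int, num_microbatches: int) -> List[List[Tuple[int, int]]]:
    """Hypothetical helper: for each tick, the (partition, micro-batch) pairs that run concurrently."""
    num_ticks = num_partitions + num_microbatches - 1
    schedule = []
    for t in range(num_ticks):
        # Partition j works on micro-batch t - j, which partition j - 1 finished at tick t - 1.
        schedule.append([(j, t - j) for j in range(num_partitions) if 0 <= t - j < num_microbatches])
    return schedule

for t, step in enumerate(gpipe_forward_schedule(num_partitions=3, num_microbatches=4)):
    print(t, step)
```

With \(K = 3\) partitions and \(M = 4\) micro-batches, the forward pass takes \(M + K - 1 = 6\) ticks and each partition is busy for 4 of them; the idle "bubble" fraction \((K - 1)/(M + K - 1)\) shrinks as the number of micro-batches grows.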

4.3. Federated Learning

Algorithms run on decentralized edge devices using local data samples. Parameter updates are aggregated on a server.
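The text does not say which aggregation rule is used; a common choice is federated averaging (FedAvg), where the server averages client parameters weighted by the number of local samples. A minimal sketch, with parameters as flat lists and made-up client data sizes:

```python
from typing import List

def federated_average(client_params: List[List[float]], client_sizes: List[int]) -> List[float]:
    """Weighted average of client parameters, weighted by each client's number of local samples."""
    total = sum(client_sizes)
    return [
        sum(n * params[i] for params, n in zip(client_params, client_sizes)) / total
        for i in range(len(client_params[0]))
    ]

clients = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]  # hypothetical client parameters
sizes = [100, 300, 100]                         # hypothetical local dataset sizes
print(federated_average(clients, sizes))        # [2.4, 2.0]
```

Clients with more data pull the global model further towards their local solution; only parameters (or updates) leave the device, never the raw samples.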

federated learning

Model (structured updates, sketched updates) Communication efficiency is important. To reduce the uplink cost, two kinds of updates are used:

  • structured updates: the update is restricted to a predefined structure, either low rank or a random sparsity mask
  • sketched updates: the full update is computed and then compressed before being sent, using subsampling (the server averages the subsampled updates) and quantization combined with random rotations
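The random-mask idea can be sketched from the communication side: client and server regenerate the same sparsity mask from a shared seed, so the client uploads only the values at masked positions (no indices) and the server scatters them back. The seed-sharing scheme, `keep_prob`, and helper names are illustrative assumptions; for brevity the kept values are taken from a precomputed update, whereas in the structured variant the client would train only those entries.

```python
import random
from typing import List

def random_mask(seed: int, size: int, keep_prob: float) -> List[bool]:
    # Both sides regenerate the identical mask from the shared seed.
    rng = random.Random(seed)
    return [rng.random() < keep_prob for _ in range(size)]

def client_encode(update: List[float], seed: int, keep_prob: float) -> List[float]:
    mask = random_mask(seed, len(update), keep_prob)
    return [u for u, keep in zip(update, mask) if keep]  # values only, no indices

def server_decode(values: List[float], seed: int, size: int, keep_prob: float) -> List[float]:
    mask = random_mask(seed, size, keep_prob)
    it = iter(values)
    return [next(it) if keep else 0.0 for keep in mask]  # scatter back, zeros elsewhere

update = [0.5, -1.0, 0.25, 2.0, -0.75, 1.5, 0.1, -0.3]  # hypothetical client update
seed = 1234
payload = client_encode(update, seed, keep_prob=0.5)
restored = server_decode(payload, seed, len(update), keep_prob=0.5)
print(len(payload), "of", len(update), "values sent")
```

The uplink cost drops roughly by the factor `keep_prob`, plus the few bytes needed for the seed.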

4.4. Inference

Model (zero inference)

  • DeepSpeed Transformer: GPU only
  • heterogeneous inference: GPU + CPU + NVMe

5. Reference