0x425 Implementation
1. Primitives
1.1. ISA
1.1.1. PTX
GPU primitives; see the architecture page for more details.
1.2. IR
1.2.1. MLIR
IR is typically represented with MLIR.
MLIR (Multi-Level Intermediate Representation) defines multiple dialects and progressively converts (lowers) them towards machine code.
Relevant links
1.2.2. XLA HLO
HLO has a limited set of orthogonal, well-defined ops (< 100) and can be lowered from a TensorFlow graph. The semantics of HLO are defined in the XLA semantics page.
1.3. Domain Specific Language / Compiler
Check this repo
1.3.1. CUDA
CUDA implementations related to deep learning
1.3.2. XLA
XLA (Accelerated Linear Algebra) is a domain-specific compiler.
Originally, each operation in the computation graph has a predefined and precompiled kernel that the executor dispatches to.
XLA is a compiler that compiles the computation graph into a sequence of computation kernels specialized for the model. In particular, it can fuse multiple operations into a single kernel, thereby improving both execution speed and memory usage.
Its design resembles LLVM's, with a frontend/backend separation.
- The frontend of XLA produces the HLO (High Level Operations, something like a compiler IR) defined here. This IR is hardware-independent and can be optimized.
- The backend of XLA produces the target-dependent hardware code. In practice, the backend depends on LLVM: it emits the LLVM IR necessary to represent the XLA HLO computation, then invokes LLVM to emit native code from this LLVM IR. For example, the CPU backend implementation for XLA can be found here.
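As a rough illustration (not part of XLA itself), the hardware-independent IR can be inspected from JAX, which also lowers through XLA; the snippet below assumes a recent jax install and its .lower()/.as_text() staging API.

```python
# Sketch: inspecting the hardware-independent IR that XLA's frontend
# produces, via JAX (which also lowers to XLA HLO). Assumes a recent jax.
import jax
import jax.numpy as jnp

def f(x):
    # matmul + tanh + add: candidates for fusion by XLA
    return jnp.tanh(x @ x) + 1.0

x = jnp.ones((4, 4))
print(jax.jit(f).lower(x).as_text())   # HLO / StableHLO text of the computation
```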
1.3.3. Triton
Frameworks
TensorFlow
Execution
There are a few runtimes for TensorFlow:
executor
The default runtime, which also serves as a reference. It is essentially an interpreter for graphs.
It performs some transformation/optimization with Grappler (at the graph level, not the IR level) and executes with predefined op kernels.
The execution path is (probably)
tf.graph -> kernels
TFRT
A new runtime that lowers the graph into the BEF format (also defined with MLIR), which is then executed by the BEF executor.
The execution path is
tf.graph -> BEF -> kernels
XLA
The execution path is
tf.graph -> XLA-HLO -> LLVM IR -> kernel
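A minimal sketch of opting into this path from user code, assuming TensorFlow 2.x (jit_compile=True routes the traced graph through XLA):

```python
# Sketch: opting a TF function into the XLA path (tf.graph -> XLA-HLO -> ...).
# Assumes TensorFlow 2.x; jit_compile=True asks XLA to compile the graph.
import tensorflow as tf

@tf.function(jit_compile=True)
def dense_relu(x, w, b):
    # matmul + bias + relu can be fused by XLA into fewer kernels
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal((8, 16))
w = tf.random.normal((16, 4))
b = tf.zeros((4,))
print(dense_relu(x, w, b).shape)  # (8, 4)
```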
2. Automatic Differentiation
2.1. Forward / Reverse Accumulation
Check this video
To follow the video, we use numerator-layout notation here.
Let \(F: R^n \to R\) with \(F = D \circ C \circ B \circ A\), i.e. \(F(x) = D(C(B(A(x))))\), and write \(a = A(x)\), \(b = B(a)\), \(c = C(b)\), \(y = D(c)\).
By the chain rule, the Jacobian \(F'(x) = D'(c)\, C'(b)\, B'(a)\, A'(x) \in R^{1 \times n}\).
There are many ways to order this chain of matrix multiplications; the two typical ones follow.
2.1.1. Forward Accumulation
This is push-forward computing
It computes \(\frac{\partial b}{\partial x}\) (the derivative of an intermediate value with respect to the input) at every step.
It combines conveniently with the Jacobian-vector product (JVP).
To obtain the full Jacobian, we can repeat the JVP for every column, i.e. once per input dimension.
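A minimal sketch of a JVP and the column-by-column Jacobian, assuming jax is available (the scalar-valued function F below is purely illustrative):

```python
# Sketch: forward accumulation as a Jacobian-vector product (JVP) in JAX.
# One JVP pushes a tangent vector v forward: returns (F(x), F'(x) @ v).
import jax
import jax.numpy as jnp

def F(x):  # R^n -> R
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(1.0, 4.0)           # n = 3
v = jnp.array([1.0, 0.0, 0.0])     # tangent = first basis vector
y, jvp_out = jax.jvp(F, (x,), (v,))
print(jvp_out)                     # first column (entry) of the Jacobian

# Full Jacobian: repeat the JVP for every basis vector (n calls).
eye = jnp.eye(x.shape[0])
jac = jnp.stack([jax.jvp(F, (x,), (e,))[1] for e in eye])
print(jac, jax.jacfwd(F)(x))       # the two should agree
```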
2.1.2. Reverse Accumulation
This is pull-back computing
It computes \(\frac{\partial y}{\partial b}\) (the derivative of the output with respect to an intermediate value) at every step.
Similarly, it combines conveniently with the vector-Jacobian product (VJP).
To build the full Jacobian, we build one row at a time, i.e. one VJP per output dimension.
For a neural network with a scalar loss, the Jacobian is a single row vector, so we prefer reverse accumulation.
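A minimal VJP sketch under the same assumptions as above (illustrative F, jax available); for a scalar output a single VJP already yields the whole row:

```python
# Sketch: reverse accumulation as a vector-Jacobian product (VJP) in JAX.
# One VJP pulls a cotangent u back: returns u @ F'(x), i.e. one row of the Jacobian.
import jax
import jax.numpy as jnp

def F(x):  # R^n -> R (scalar output, like a loss)
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(1.0, 4.0)
y, vjp_fn = jax.vjp(F, x)
(row,) = vjp_fn(1.0)               # cotangent u = 1.0 for the single output
print(row)                         # the whole 1 x n Jacobian in one call
print(jax.grad(F)(x))              # grad is the same row
```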
2.2. Autograd Implementation
Comparison of forward vs reverse:
- reverse-mode has a memory cost that scales with the depth of the program (intermediate values must be stored for the backward pass)
- forward-mode requires \(n\) calls (one per input dimension) to build the full Jacobian
2.3. Dual Numbers
Augment a real \(a\) to \(a + b\epsilon\), where \(a, b\) are real and \(\epsilon^2 = 0\). This gives primitives such as \((a + b\epsilon) + (c + d\epsilon) = (a + c) + (b + d)\epsilon\) and \((a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon\), so the \(\epsilon\) coefficient carries the derivative.
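A minimal plain-Python sketch of this idea (the Dual class and f below are illustrative, not any library's implementation):

```python
# Sketch: dual numbers a + b*eps with eps**2 = 0. Propagating the eps
# coefficient through arithmetic is exactly forward-mode AD.
class Dual:
    def __init__(self, a, b=0.0):
        self.a, self.b = a, b      # value, derivative coefficient

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.a + other.a, self.b + other.b)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a + b*eps)(c + d*eps) = ac + (ad + bc)*eps, since eps**2 = 0
        return Dual(self.a * other.a, self.a * other.b + self.b * other.a)

    __radd__, __rmul__ = __add__, __mul__

def f(x):
    return x * x * x + 2.0 * x     # f(x) = x^3 + 2x, f'(x) = 3x^2 + 2

out = f(Dual(2.0, 1.0))            # seed the derivative with b = 1
print(out.a, out.b)                # 12.0 14.0
```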
2.3.1. JAX Implementation
JAX is based on the dual-number approach:
- forward-mode is implemented using the dual-number approach
- reverse-mode is implemented using forward-mode plus transposition (paper)
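A small sketch of the forward-plus-transpose recipe using jax.linearize and jax.linear_transpose, assuming a recent jax (the function is illustrative; this is not JAX's internal implementation):

```python
# Sketch: "reverse-mode = forward-mode + transpose".
# linearize gives the (linear) JVP at x; linear_transpose turns that
# linear map into a VJP.
import jax
import jax.numpy as jnp

def F(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(1.0, 4.0)
y, f_jvp = jax.linearize(F, x)          # forward-mode: v -> F'(x) @ v
f_vjp = jax.linear_transpose(f_jvp, x)  # transpose:    u -> u @ F'(x)
(row,) = f_vjp(jnp.ones_like(y))
print(row)                              # matches reverse-mode grad
print(jax.grad(F)(x))
```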
3. Efficiency
3.1. Architecture Search
Check this blog
Gradient
Model (Gradient Checkpointing)
- part of the forward activations is discarded to save memory
- the discarded activations are recomputed when needed during the backward pass
- check the gif here
See here for PyTorch's implementation
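A minimal sketch using torch.utils.checkpoint, assuming a recent PyTorch (the block and shapes are illustrative):

```python
# Sketch: gradient checkpointing in PyTorch. Assumes a recent torch version
# where use_reentrant is passed explicitly.
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
)

x = torch.randn(32, 64, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed during
# backward, trading extra compute for lower memory.
y = checkpoint(block, x, use_reentrant=False)
loss = y.sum()
loss.backward()
print(x.grad.shape)
```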
3.2. Quantization
3.2.1. Post-training Quantization
From the TensorFlow website: post-training quantization is a conversion technique that can reduce model size while also improving CPU and hardware accelerator latency, with little degradation in model accuracy.
- simplest form: convert weights to 8-bit precision. At inference time, convert the 8-bit values back to floating point and perform float inference
- dynamic range quantization: activations are additionally quantized to 8 bits dynamically, and computations are done in 8-bit precision where possible
- full integer quantization: everything is quantized to integers. A calibration process is needed to estimate the float range (min, max), so a representative dataset is required (see the sketch after this list)
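A numpy sketch of the affine quantize/dequantize math and the min/max calibration step, assuming 8-bit asymmetric quantization (function names are illustrative, not the TFLite converter API):

```python
# Sketch: affine (asymmetric) 8-bit quantization behind full integer
# post-training quantization. "Calibration" here just estimates (min, max)
# over a representative dataset.
import numpy as np

def calibrate(representative_data):
    # estimate the float range, then derive scale / zero-point for uint8
    x_min, x_max = float(representative_data.min()), float(representative_data.max())
    scale = (x_max - x_min) / 255.0
    zero_point = int(round(-x_min / scale))
    return scale, zero_point

def quantize(x, scale, zero_point):
    return np.clip(np.round(x / scale + zero_point), 0, 255).astype(np.uint8)

def dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float32) - zero_point)

rng = np.random.default_rng(0)
rep_data = rng.normal(size=4096).astype(np.float32)   # representative dataset
scale, zp = calibrate(rep_data)

activations = rng.normal(size=16).astype(np.float32)
q = quantize(activations, scale, zp)
print(q)                                               # uint8 values
print(np.abs(activations - dequantize(q, scale, zp)).max())  # quantization error
```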
3.2.2. Quantization-Aware Training
Pro: achieves higher accuracy. Cons: requires a training pipeline, labeled data, and hyperparameter tuning.
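A sketch of the fake-quantization op with a straight-through estimator, which is the core trick in QAT; this is illustrative PyTorch, not the built-in QAT tooling of TF or PyTorch:

```python
# Sketch: "fake quantization" used during quantization-aware training,
# with a straight-through estimator (STE) for its gradient.
import torch

class FakeQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, zero_point):
        # quantize to the uint8 range, then immediately dequantize
        q = torch.clamp(torch.round(x / scale + zero_point), 0, 255)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # STE: treat round/clamp as the identity so gradients flow to x
        return grad_output, None, None

x = torch.randn(8, requires_grad=True)
y = FakeQuant.apply(x, 0.05, 128)   # training sees quantization noise
y.sum().backward()
print(x.grad)                       # straight-through: all ones
```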
3.3. Pruning
Model (Lottery Ticket Hypothesis) after pruning, each unpruned connection's value is reset to its initialization, and the sparse subnetwork is then retrained.
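A numpy sketch of one prune-and-rewind round on a single weight matrix, assuming simple global magnitude pruning (names and the 80% prune fraction are illustrative):

```python
# Sketch: lottery-ticket style pruning. Real usage trains, prunes by
# magnitude, rewinds survivors to their initialization, then retrains.
import numpy as np

rng = np.random.default_rng(0)
w_init = rng.normal(size=(256, 256)).astype(np.float32)   # saved initialization
w_trained = w_init + 0.1 * rng.normal(size=w_init.shape).astype(np.float32)

prune_fraction = 0.8
threshold = np.quantile(np.abs(w_trained), prune_fraction)
mask = (np.abs(w_trained) >= threshold).astype(np.float32)  # keep largest 20%

# Winning ticket: surviving weights rewound to their *initial* values.
w_ticket = mask * w_init
print("kept fraction:", mask.mean())   # ~0.2; retrain with this mask fixed
```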
3.4. Distillation
3.5. Emsemble
Model (model soup) averaging the weights of multiple models fine-tuned with different hyperparameter configurations often improves accuracy and robustness.
Greedy soup, where models are sequentially added to the soup only if they improve accuracy on held-out data, outperforms uniform averaging.
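A numpy sketch of uniform and greedy soups over toy state dicts, assuming an accuracy function evaluated on held-out data (everything below is illustrative):

```python
# Sketch: uniform vs. greedy model soups over state dicts.
import numpy as np

def average(state_dicts):
    return {k: np.mean([sd[k] for sd in state_dicts], axis=0)
            for k in state_dicts[0]}

def greedy_soup(state_dicts, accuracy):
    # Try models from best to worst; keep each one only if it helps.
    ranked = sorted(state_dicts, key=accuracy, reverse=True)
    soup = [ranked[0]]
    for sd in ranked[1:]:
        candidate = average(soup + [sd])
        if accuracy(candidate) >= accuracy(average(soup)):
            soup.append(sd)
    return average(soup)

# Toy example: "accuracy" prefers weights close to an unknown optimum w_star.
rng = np.random.default_rng(0)
w_star = rng.normal(size=8)
models = [{"w": w_star + 0.3 * rng.normal(size=8)} for _ in range(5)]
acc = lambda sd: -np.linalg.norm(sd["w"] - w_star)
print(acc(average(models)), acc(greedy_soup(models, acc)))
```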
4. Distribution
4.1. Data Parallelism
Model (ZeRO) shards parameters, gradients, and optimizer state across devices
- stage 1: Optimizer State Partitioning
- stage 2: Add Gradient Partitioning
- stage 3: Add Parameter Partitioning
According to the paper, a trillion-parameter model with an optimizer like Adam in 16-bit mixed precision requires approximately 16 terabytes (TB) of memory to hold the optimizer states, gradients, and parameters. Fully partitioned across 1024 GPUs, that is 16 TB / 1024 ≈ 16 GB per device, which is well within the memory of a single GPU (e.g., one with 32 GB of on-device memory).
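The arithmetic behind that figure, assuming the paper's mixed-precision Adam accounting (2 + 2 bytes for fp16 parameters and gradients, plus 12 bytes of fp32 optimizer state per parameter):

```python
# Sketch: the memory accounting behind ZeRO's 16 TB figure for a
# trillion-parameter model with Adam in mixed precision.
params = 1e12

bytes_per_param = (
    2      # fp16 parameter
    + 2    # fp16 gradient
    + 4    # fp32 master copy of the parameter
    + 4    # fp32 Adam momentum
    + 4    # fp32 Adam variance
)          # = 16 bytes per parameter

total = params * bytes_per_param              # ~16 TB
per_gpu_stage3 = total / 1024                 # fully partitioned across 1024 GPUs
print(total / 1e12, "TB total")
print(per_gpu_stage3 / 1e9, "GB per GPU")     # ~16 GB, fits a 32 GB device
```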
4.2. Model Parallelism
Model (DistBelief)
4.2.1. Tensor Parallelism
Model (GShard) an annotation API extension of XLA, which requires the user to annotate a few critical tensors in the model with partitioning policies. It generates a single program (SPMD) for all shards.
- replicate(tensor) annotates tensor to be replicated across partitions, and returns the annotated tensor.
- split(tensor, split_dimension, num_partitions) annotates tensor to be partitioned along split_dimension, and returns the annotated tensor
- shard(tensor, device_assignment) generalizes split() to allow partitioning multiple dimensions and specifying the placement of each partition.
It was used to implement the Sparsely-Gated Mixture-of-Experts layers, where the top-2 experts are activated.
Model (GSPMD) GSPMD generalizes the backend of GShard; it introduces a new API, mesh_split(tensor, device_mesh, dims_mapping).
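A sketch of the same mesh-plus-dims-mapping idea through JAX's sharding API, which is backed by GSPMD; this assumes a recent jax and that the sharded dimension divides evenly across the available devices (it runs even on a single device):

```python
# Sketch: annotating a tensor with a device mesh and a dims mapping,
# the same idea as GSPMD's mesh_split, via JAX's sharding API.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

x = np.arange(16.0).reshape(8, 2)
# dims_mapping: rows split across the "data" mesh axis, columns replicated
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
print(x_sharded.sharding)

# A jit'ed computation consumes the sharded array; the compiler emits one
# SPMD program for all shards and inserts the necessary collectives.
y = jax.jit(lambda a: a @ a.T)(x_sharded)
print(y.shape)
```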
4.2.2. Pipeline Parallelism
Model (GPipe)
GPipe consists of the following two techniques (reference: torchgpipe):
- pipeline parallelism: GPipe splits a model into multiple partitions and places each partition on a different device to use more aggregate memory capacity. It also splits a mini-batch into multiple micro-batches so the partitions can work in parallel as much as possible (see the schedule sketch after this list).
- checkpointing (re-materialization): checkpointing is applied to each partition to minimize the overall memory consumption of the model. During forward propagation, only the tensors at the boundaries between partitions are kept. All other intermediate tensors are discarded and recomputed during backpropagation when necessary.
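A plain-Python sketch of the resulting pipeline schedule, assuming K partitions and M micro-batches (no real devices involved):

```python
# Sketch: GPipe-style forward schedule. At clock t, partition k works on
# micro-batch t - k, so devices overlap instead of idling on the full batch.
K, M = 4, 8                      # partitions (devices), micro-batches

for t in range(K + M - 1):       # forward pass "clock cycles"
    active = []
    for k in range(K):
        m = t - k
        if 0 <= m < M:
            active.append(f"dev{k}:mb{m}")
    print(f"t={t}: " + "  ".join(active))   # shows pipeline fill / steady / drain
```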
4.3. Federated Learning
Algorithms run on decentralized edge devices using local data samples; parameter updates are aggregated on a server.
Model (structured updates, sketched updates) Communication efficiency is important; to reduce the uplink cost:
- structured updates: the update is restricted to a predefined structure, either low-rank or a random sparsity mask
- sketched updates: the full update is computed and then compressed before sending, using subsampling (the server averages the subsampled updates) and probabilistic quantization with random rotations
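A numpy sketch of a sketched update with subsampling and 1-bit probabilistic quantization, assuming a single client and omitting the random rotation (everything here is illustrative):

```python
# Sketch: client-side "sketched update": subsample the update and quantize
# it to 1 bit per coordinate before uplink.
import numpy as np

rng = np.random.default_rng(0)
update = rng.normal(size=10_000).astype(np.float32)   # client's full update

# Subsampling: send only a random 10% of coordinates, scaled by 1/0.1
# so the sparse estimate stays unbiased in expectation.
idx = rng.choice(update.size, size=update.size // 10, replace=False)
sub_values = update[idx] * 10.0

# 1-bit probabilistic quantization of the subsampled values.
lo, hi = sub_values.min(), sub_values.max()
p = (sub_values - lo) / (hi - lo)
bits = rng.random(p.shape) < p                        # one bit per coordinate
decoded = np.where(bits, hi, lo)

# Server reconstructs a sparse estimate from (idx, bits, lo, hi).
estimate = np.zeros_like(update)
estimate[idx] = decoded
print(float(np.mean(np.abs(estimate[idx] - sub_values))))  # compression error
```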
4.4. Inference
Model (zero inference)
- DeepSpeed Transformer: GPU only
- heterogeneous inference: GPU + CPU + NVMe