0x504 Distribution
- 1. Replication (Data Parallel)
- 3. Sharding (Tensor Parallel)
- 2. Pipeline (Vertical Model Parallel)
- 4. Federated Learning
- 5. Inference
- 6. Reference
1. Replication (Data Parallel)
1.1. Naive Data Parallel
The full training setup is replicated across devices. During training:
- each device receives a shard of the batch
- each device computes its own loss/gradient using its own data shard
- gradients are synchronized (e.g., averaged via all-reduce) after each step
This works well as long as the model fits on a single device
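The gradient synchronization above can be simulated in plain numpy (an illustrative sketch of the math only; real setups use e.g. torch.distributed or jax.pmap for the all-reduce):

```python
import numpy as np

# Simulated data parallelism for a linear model y = x @ w.
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 1))            # replicated on every "device"
x = rng.normal(size=(8, 4))            # full mini-batch
y = rng.normal(size=(8, 1))

def local_grad(xs, ys, w):
    # Gradient of 0.5 * ||xs @ w - ys||^2 w.r.t. w, on one data shard.
    return xs.T @ (xs @ w - ys)

# Each of 2 "devices" gets a shard of the batch and computes a local gradient.
shards = [(x[:4], y[:4]), (x[4:], y[4:])]
grads = [local_grad(xs, ys, w) for xs, ys in shards]

# "All-reduce": summing the per-device gradients recovers the full-batch gradient.
synced = sum(grads)
```

The final sum matches `local_grad(x, y, w)` computed on the whole batch, which is why synchronizing after each step keeps all replicas consistent.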
1.2. ZeRO
Model (ZeRO) shards the parameters, gradients, and optimizer states across devices
- stage 1: Optimizer State Partitioning
- stage 2: Add Gradient Partitioning
- stage 3: Add Parameter Partitioning
According to the paper, a trillion-parameter model trained with Adam in mixed precision requires approximately 16 terabytes (TB) of memory for the optimizer states, gradients, and parameters (16 bytes per parameter). Partitioned across 1024 devices, 16 TB works out to roughly 16 GB per device, which is well within a reasonable bound for a GPU (e.g., one with 32 GB of on-device memory)
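The paper's accounting can be reproduced with a small helper (an illustrative sketch, using the ZeRO paper's mixed-precision Adam breakdown: 2 B fp16 params + 2 B fp16 grads + 12 B fp32 optimizer states = 16 bytes/param):

```python
# Per-device memory under the three ZeRO stages.
def zero_bytes_per_device(n_params, n_devices, stage):
    params, grads, opt = 2 * n_params, 2 * n_params, 12 * n_params
    if stage >= 1:
        opt /= n_devices        # stage 1: partition optimizer states
    if stage >= 2:
        grads /= n_devices      # stage 2: also partition gradients
    if stage >= 3:
        params /= n_devices     # stage 3: also partition parameters
    return params + grads + opt

# A 1T-parameter model on 1024 devices at stage 3: ~15.6 GB per device.
gb = zero_bytes_per_device(1e12, 1024, stage=3) / 1e9
```

Note that stages 1 and 2 only shrink the optimizer-state and gradient terms; the full 2 bytes/param of fp16 parameters stays on every device until stage 3.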
3. Sharding (Tensor Parallel)
See this blog for some intuition, and this blog for an introduction to XLA sharding.
This section summarizes Google's XLA sharding frameworks (e.g. GSPMD, GShard, Mesh-TensorFlow, DTensor, ...) for both TF and JAX
3.1. Sharding Completion
See section 3.5 of GSPMD paper
3.2. Implementations
Model (GShard) An annotation API extension of XLA: the user annotates a few critical tensors in the model with partitioning policies, and the compiler generates a single program (SPMD) for all shards
- replicate(tensor) annotates tensor to be replicated across partitions, and returns the annotated tensor.
- split(tensor, split_dimension, num_partitions) annotates tensor to be partitioned along split_dimension, and returns the annotated tensor
- shard(tensor, device_assignment) generalizes split() to allow partitioning multiple dimensions and specifying the placement of each partition.
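The semantics of the first two annotations can be illustrated with a toy numpy simulation (conceptual only; the real GShard API annotates tensors for the XLA compiler rather than materializing shards like this):

```python
import numpy as np

# replicate(): every partition sees the full tensor.
def replicate(tensor, num_partitions):
    return [tensor] * num_partitions

# split(): each partition gets one slice along the chosen dimension.
def split(tensor, split_dimension, num_partitions):
    return np.array_split(tensor, num_partitions, axis=split_dimension)

x = np.arange(12).reshape(4, 3)
shards = split(x, split_dimension=0, num_partitions=2)   # two (2, 3) slices
copies = replicate(x, num_partitions=2)                  # two full copies
```

Concatenating the shards along the split dimension recovers the original tensor, which is the invariant the partitioner maintains.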
SPMD is preferable to MPMD because MPMD must compile a separate program for each partition, which takes nontrivial time.
GShard was used to implement the Sparsely-Gated Mixture-of-Experts layers, where the top-2 experts are activated per token
Model (GSPMD) A generalization of GShard's backend; it introduces a new API, mesh_split(tensor, device_mesh, dims_mapping)
Model (Megatron-LM) Manual sharding by inserting communication ops into the model. The sharding layout is the same as in Mesh-TensorFlow's example
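The Megatron-LM MLP layout can be checked numerically: split the first weight matrix by columns and the second by rows, so each device computes independently and only one all-reduce (a sum, here) is needed at the end. This is a numpy sketch of the math, not the actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))        # activations, replicated on both "devices"
A = rng.normal(size=(8, 16))       # first linear layer (column-parallel)
B = rng.normal(size=(16, 8))       # second linear layer (row-parallel)

relu = lambda t: np.maximum(t, 0)

# Device i holds A[:, i*8:(i+1)*8] and B[i*8:(i+1)*8, :].
partials = []
for i in range(2):
    h = relu(x @ A[:, i*8:(i+1)*8])         # local; no communication needed
    partials.append(h @ B[i*8:(i+1)*8, :])  # partial output on device i

y_parallel = sum(partials)                  # the single all-reduce
y_serial = relu(x @ A) @ B                  # unsharded reference
```

Because ReLU is elementwise, the column split commutes with the nonlinearity, which is exactly why Megatron places the split there.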
3.3. Expert Parallel
Parallelism based on Mixture-of-Experts layers: different experts are placed on different devices, and tokens are routed to the devices holding their assigned experts
Model (Switch Transformer)
2. Pipeline (Vertical Model Parallel)
2.1. Naive Model Parallel
Spread layers across multiple devices vertically. For example, place layers 0-3 on device 0 and layers 4-7 on device 1
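This placement can be sketched with stand-in layers (a toy simulation; real placements move modules to physical devices, e.g. with `.to("cuda:0")` in PyTorch):

```python
# "Device 0" holds layers 0-3, "device 1" holds layers 4-7; the activation
# is handed off between them at the partition boundary.
layers = [lambda x, i=i: x + i for i in range(8)]  # stand-in layers
device0, device1 = layers[:4], layers[4:]

def forward(x):
    for layer in device0:   # device 0 runs while device 1 sits idle
        x = layer(x)
    for layer in device1:   # then device 1 runs while device 0 sits idle
        x = layer(x)
    return x
```

The comments already hint at the weakness: at any moment only one device does useful work, which is what the pipeline section below addresses.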
Model (DistBelief)
2.2. Pipeline Parallelism
Naive model parallelism leaves GPUs mostly idle, since only one device is active at a time, which motivates pipelined execution
Model (GPipe)
GPipe consists of the following two techniques (reference: torchgpipe)
- pipeline parallelism: GPipe splits a model into multiple partitions and places each partition on a different device to gain memory capacity. It also splits each mini-batch into micro-batches so the partitions can work in parallel as much as possible.
- checkpointing (re-materialization): checkpointing is applied to each partition to minimize the model's overall memory consumption. During forward propagation, only the tensors at the boundaries between partitions are kept; all other intermediate tensors are discarded and recomputed during backpropagation when needed.
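The benefit of micro-batching can be quantified with an idealized schedule: with P partitions and M micro-batches, a sweep takes P + M - 1 pipeline steps instead of the P * M steps needed if each micro-batch traversed all partitions before the next one started. A back-of-the-envelope helper (illustrative, ignoring communication and uneven partition costs):

```python
def pipeline_steps(num_partitions, num_micro_batches):
    # Steps until the last micro-batch clears the last partition.
    return num_partitions + num_micro_batches - 1

def bubble_fraction(num_partitions, num_micro_batches):
    # Fraction of device-time slots left idle ("bubble") in the ideal schedule.
    total_slots = num_partitions * pipeline_steps(num_partitions, num_micro_batches)
    busy_slots = num_partitions * num_micro_batches
    return 1 - busy_slots / total_slots
```

For example, 4 partitions with 8 micro-batches take 11 steps with about a 27% bubble; increasing M shrinks the bubble, which is why GPipe favors many small micro-batches.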
4. Federated Learning
Algorithms run on decentralized edge devices using local data samples; parameter updates are aggregated on a central server.
Model (structured updates, sketched updates) Communication efficiency is important, so the uplink cost is reduced in two ways
- structured updates: the update is restricted to a pre-specified structure, e.g., a low-rank matrix or a random sparsity mask
- sketched updates: the full update is computed, then compressed via subsampling (averaging a randomly subsampled set of coordinates) and quantization with random rotations
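The two compression ideas can be sketched in numpy (hypothetical helper code illustrating the concepts, not the paper's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
update = rng.normal(size=100)          # a client's local model update

# Structured update via a random mask: only the ~10% of entries selected by
# the mask are ever learned/communicated; the rest are fixed at zero.
mask = rng.random(100) < 0.1
structured = np.where(mask, update, 0.0)

# Sketched update via subsampling: the full update is computed, then ~10% of
# coordinates are kept and rescaled by 1/p so the compressed update is an
# unbiased estimate of the original in expectation.
keep = rng.random(100) < 0.1
sketched = np.where(keep, update / 0.1, 0.0)
```

In both cases only the selected coordinates (plus the random seed for the mask) need to be uplinked, cutting the client-to-server traffic roughly tenfold in this sketch.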
5. Inference
Model (ZeRO-Inference)
- DeepSpeed Transformer: GPU only
- heterogeneous inference: GPU + CPU + NVMe