
0x504 Distribution

1. Replication (Data Parallel)

1.1. Naive Data Parallel

The full model setup is replicated on every device. During training:

  • each device receives a shard of the batch
  • each device computes its own loss/gradient using its own data shard
  • gradients are synchronized (averaged via all-reduce) after each step

This works well as long as the model fits on a single device.
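
Below is a minimal sketch of this pattern in JAX, assuming a toy linear model and one batch shard already placed on each device; the model, shapes, and learning rate are illustrative, not from any particular source.

    import jax
    import jax.numpy as jnp
    from functools import partial

    def loss_fn(params, x, y):
        pred = x @ params["w"] + params["b"]      # toy linear model
        return jnp.mean((pred - y) ** 2)

    @partial(jax.pmap, axis_name="batch")
    def train_step(params, x, y):
        lr = 0.1                                          # illustrative learning rate
        grads = jax.grad(loss_fn)(params, x, y)           # local gradients from this device's shard
        grads = jax.lax.pmean(grads, axis_name="batch")   # synchronize: all-reduce average
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

    n_dev = jax.local_device_count()
    params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
    params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)  # replicate params
    x, y = jnp.ones((n_dev, 8, 4)), jnp.ones((n_dev, 8, 1))                    # one shard per device
    params = train_step(params, x, y)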

1.2. ZeRO

Model (ZeRO) shards parameters, gradients, and optimizer state across devices:

  • stage 1: Optimizer State Partitioning
  • stage 2: Add Gradient Partitioning
  • stage 3: Add Parameter Partitioning

According to the paper, a trillion-parameter model trained with an Adam-like optimizer in 16-bit precision requires approximately 16 terabytes (TB) of memory to hold the optimizer states, gradients, and parameters. Fully partitioned across 1024 GPUs, that is 16 TB / 1024 = 16 GB per GPU, which is well within a reasonable bound for a GPU (e.g., one with 32 GB of on-device memory).
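
The arithmetic behind that claim, as a quick sanity check (using the paper's mixed-precision Adam accounting of 16 bytes per parameter: 2 for fp16 weights, 2 for fp16 gradients, and 12 for the fp32 optimizer states):

    params = 1e12                      # one trillion parameters
    bytes_per_param = 2 + 2 + 12       # fp16 params + fp16 grads + fp32 optimizer states
    total = params * bytes_per_param   # ~16e12 bytes ~ 16 TB
    n_gpus = 1024
    per_gpu = total / n_gpus           # ~16 GB per GPU with full partitioning
    print(f"{total / 1e12:.0f} TB total, {per_gpu / 1e9:.0f} GB per GPU across {n_gpus} GPUs")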

[Figure: ZeRO]

2. Sharding (Tensor Parallel)

See this blog for some intuition.

Check this blog for an introduction to XLA sharding.

This section summarizes Google's XLA sharding frameworks (e.g. GSPMD, GShard, Mesh-TensorFlow, DTensor, ...) for both TF and JAX.

2.1. Sharding Completion

See Section 3.5 of the GSPMD paper.

2.2. Implementations

Model (GShard) An annotation API extension to XLA, which requires the user to annotate a few critical tensors in the model with partitioning policies. It generates one program (SPMD) for all shards:

  • replicate(tensor) annotates tensor to be replicated across partitions, and returns the annotated tensor.
  • split(tensor, split_dimension, num_partitions) annotates tensor to be partitioned along split_dimension, and returns the annotated tensor.
  • shard(tensor, device_assignment) generalizes split() to allow partitioning multiple dimensions and specifying the placement of each partition.

SPMD scales better than MPMD because MPMD must compile a separate program for each partition, which takes nontrivial time.

[Figure: GShard]

GShard was used to implement the Sparsely-Gated Mixture-of-Experts layers, where the top 2 experts are activated for each token (see the gating sketch after the figure below).

[Figure: MoE]
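
A minimal sketch of the top-2 gating step, ignoring GShard's expert capacity limits and auxiliary load-balancing loss; the shapes and the gating input below are illustrative.

    import jax
    import jax.numpy as jnp

    def top2_gating(logits):
        # logits: [tokens, num_experts] from a learned gating layer
        probs = jax.nn.softmax(logits, axis=-1)
        top2_vals, top2_idx = jax.lax.top_k(probs, k=2)   # pick the 2 best experts per token
        top2_vals = top2_vals / jnp.sum(top2_vals, axis=-1, keepdims=True)  # renormalize weights
        return top2_idx, top2_vals

    logits = jax.random.normal(jax.random.PRNGKey(0), (16, 8))   # 16 tokens, 8 experts
    experts, weights = top2_gating(logits)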

Model (GSPMD) GSPMD is a generalization of GShard's backend; it introduces a new API, mesh_split(tensor, device_mesh, dims_mapping) (see the sketch after the figure below).

[Figure: GSPMD]
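
For intuition, here is roughly what such an annotation looks like in JAX, which lowers sharding annotations to GSPMD/XLA; the mesh shape, axis names, and tensor sizes are illustrative.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # build a 2D device mesh; here all devices go on the "data" axis for simplicity
    devices = np.array(jax.devices()).reshape(-1, 1)
    mesh = Mesh(devices, axis_names=("data", "model"))

    @jax.jit
    def layer(x, w):
        y = x @ w
        # analogous to mesh_split(y, device_mesh, dims_mapping): rows sharded over "data",
        # columns over "model"; the compiler propagates shardings through the rest of the graph
        return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", "model")))

    x = jax.device_put(jnp.ones((8, 1024)), NamedSharding(mesh, P("data", None)))      # split batch dim
    w = jax.device_put(jnp.ones((1024, 4096)), NamedSharding(mesh, P(None, "model")))  # split output dim
    y = layer(x, w)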

Model (Megatron-LM) shards manually by inserting communication ops into the model. The sharding layout is the same as in Mesh-TensorFlow's example (a sketch follows below).
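
A hedged sketch of this Megatron-style layout for one MLP block, written with JAX collectives rather than Megatron's actual code: the first weight is split column-wise, the second row-wise, and a single all-reduce combines the partial outputs. Names and shapes are illustrative.

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.pmap, axis_name="model")
    def parallel_mlp(x, w1_shard, w2_shard):
        h = jax.nn.gelu(x @ w1_shard)       # column-parallel: no communication needed here
        y_partial = h @ w2_shard            # row-parallel: each device holds a partial sum
        return jax.lax.psum(y_partial, axis_name="model")  # one all-reduce combines the outputs

    p = jax.local_device_count()
    x = jnp.ones((p, 2, 16))                # activation replicated on every model shard
    w1 = jnp.ones((p, 16, 64 // p))         # W1 split along its output (column) dimension
    w2 = jnp.ones((p, 64 // p, 16))         # W2 split along its input (row) dimension
    y = parallel_mlp(x, w1, w2)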

2.3. Expert Parallel

Parallelism based on Mixture-of-Experts: experts are distributed across devices, and tokens are routed to the devices holding their assigned experts.

Model (Switch Transformer) simplifies MoE routing by sending each token to only the top-1 expert.

3. Pipeline (Vertical Model Parallel)

3.1. Naive Model Parallel

Spread layers across multiple devices vertically. For example, place layers 0-3 on device 0 and layers 4-7 on device 1.

Model (DistBelief) an early Google system that partitions large models across machines and trains them with asynchronous data parallelism (Downpour SGD).

3.2. Pipeline Parallelism

Naive model parallelism has low GPU utilization (only one device is active at a time), so we turn to pipelined execution.

Model (GPipe)

[Figure: GPipe pipeline]

GPipe consists of the following two components (reference: torchgpipe):

  • Pipeline parallelism: GPipe splits a model into multiple partitions and places each partition on a different device to increase the available memory capacity. It also splits each mini-batch into multiple micro-batches so that the partitions can work in parallel as much as possible (see the sketch after this list).
  • Checkpointing (re-materialization): checkpointing is applied to each partition to minimize the overall memory consumption of the model. During forward propagation, only the tensors at the boundaries between partitions are kept; all other intermediate tensors are discarded and recomputed during backpropagation when needed.
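
A toy sketch of the micro-batching idea (schedule only): the mini-batch is split into micro-batches that flow through the stage partitions. Real GPipe overlaps micro-batches across devices and wraps each partition with re-materialization (e.g. jax.checkpoint), both of which this sketch omits; the stage functions and shapes are placeholders.

    import jax.numpy as jnp

    def run_pipeline(stages, minibatch, n_micro=4):
        micro_batches = jnp.split(minibatch, n_micro)   # 1. split the mini-batch into micro-batches
        outputs = []
        for mb in micro_batches:                        # 2. send each micro-batch through the stages
            act = mb
            for stage in stages:                        #    each stage lives on its own device
                act = stage(act)                        #    (placement omitted in this sketch)
            outputs.append(act)
        return jnp.concatenate(outputs)                 # 3. reassemble the mini-batch output

    stages = [lambda x: x * 2, lambda x: x + 1]         # stand-ins for model partitions
    out = run_pipeline(stages, jnp.ones((8, 4)))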

4. Federated Learning

Algorithms run on decentralized edge devices using local data samples; parameter updates are aggregated on a central server (see the sketch after the figure below).

[Figure: federated learning]
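
A minimal sketch of the server-side aggregation step (FedAvg-style averaging of client updates); client training and the compression schemes described below are omitted, and all names and shapes are illustrative.

    import jax
    import jax.numpy as jnp

    def server_aggregate(global_params, client_updates):
        # average the parameter updates uploaded by the clients, leaf by leaf
        mean_update = jax.tree_util.tree_map(
            lambda *u: jnp.mean(jnp.stack(u), axis=0), *client_updates)
        return jax.tree_util.tree_map(lambda p, u: p + u, global_params, mean_update)

    global_params = {"w": jnp.zeros((4,))}
    client_updates = [{"w": jnp.ones((4,)) * i} for i in range(3)]  # toy local updates
    global_params = server_aggregate(global_params, client_updates)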

Model (structured updates, sketched updates) Communication efficiency is important; to reduce the uplink cost:

  • structured updates: the update is restricted to a structured form, either low-rank or a random mask
  • sketched updates: the full update is computed and then compressed before upload, using subsampling (the server averages the subsampled updates) and probabilistic quantization with random rotations (see the sketch after this list)
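
A toy sketch of the subsampling idea: keep only a random subset of the update's entries and rescale them so the compressed update is unbiased in expectation. The keep probability and rescaling here are illustrative, not taken from the paper.

    import jax
    import jax.numpy as jnp

    def random_mask_update(update, key, keep_prob=0.1):
        mask = jax.random.bernoulli(key, keep_prob, update.shape)   # keep ~10% of entries
        # rescale kept entries so the compressed update is unbiased in expectation
        return jnp.where(mask, update / keep_prob, 0.0)

    update = jax.random.normal(jax.random.PRNGKey(0), (1000,))      # toy local update
    compressed = random_mask_update(update, jax.random.PRNGKey(1))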

5. Inference

Model (ZeRO-Inference)

  • DeepSpeed Transformer: GPU only
  • heterogeneous inference: GPU + CPU + NVMe

6. Reference