Parallelism Concepts in Training (WIP)

An introduction to the conceptual foundations of parallelism in modern large-scale deep learning

Let's learn it from the ground up: what we parallelize, why, how communication works, and how these strategies interplay inside frameworks like PyTorch Distributed, Megatron, and DeepSpeed.

We'll move from the physical level (devices, ranks, process groups) to the algorithmic level (data, model, pipeline, context, expert, etc.), with math and visuals.


1. Why Parallelism Exists in Deep Learning

There are two main bottlenecks in deep learning training:

| Resource | Bottleneck | Examples |
| --- | --- | --- |
| Memory | Parameters, activations, and optimizer states exceed a single GPU's memory (e.g., 100B+ parameters). | Transformer layers, attention KV cache |
| Compute | A single GPU cannot finish training in a reasonable time. | Large-batch training, long sequences |

To overcome both, we split the work among multiple devices (GPUs or nodes), but the split must:

  1. Preserve correctness: as if trained on a single GPU.
  2. Maximize throughput: avoid idle GPUs or redundant work.
  3. Minimize communication: since communication is expensive.

That is the essence of parallelism.

TODO: add activation, weight and optimizer state diagrams!

2. The Abstraction: Ranks, Process Groups, and Collectives

  • Each process that participates in distributed training has a rank. In PyTorch Distributed (torch.distributed), a rank is simply the unique ID of one process participating in a distributed job. Each process typically runs on one GPU.
  • All ranks together form a process group. Every rank has its own model copy (for DDP) or shard (for TP/FSDP/PP, etc.).
  • Ranks communicate through collective ops (all_reduce, all_gather, broadcast, etc.).
  • Think of a rank as β€œwho I am in the distributed world.”

These are implemented using high-performance backends (e.g. NCCL for CUDA, Gloo for CPU).
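
As a minimal sketch of this vocabulary (assuming the job is launched with torchrun, which sets RANK, WORLD_SIZE, and LOCAL_RANK in the environment; the toy all_reduce at the end is purely illustrative):

import os
import torch
import torch.distributed as dist

def main():
    # torchrun sets the environment variables that init_process_group reads by default.
    dist.init_process_group(backend="nccl")            # use "gloo" for CPU-only runs
    rank = dist.get_rank()                             # "who I am in the distributed world"
    world_size = dist.get_world_size()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # A toy collective: every rank contributes its rank id, all_reduce sums them.
    x = torch.ones(1, device="cuda") * rank
    dist.all_reduce(x, op=dist.ReduceOp.SUM)           # x now holds 0 + 1 + ... + (world_size - 1)
    print(f"rank {rank}/{world_size}: {x.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()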


3. Data Parallelism (DP)

Idea: Each rank holds the entire model but only a subset of the data.

Each forward/backward pass computes gradients locally, then synchronizes via an all-reduce across the data-parallel group.

πŸ”Ή Forward & Backward

Let $\theta$ be model parameters, $D_i$ the local data shard on rank $i$, and $L(\theta, D_i)$ the local loss.

Each rank computes:

$$g_i = \nabla_\theta L(\theta, D_i)$$

Then all ranks perform:

$$g = \frac{1}{N} \sum_{i=1}^N g_i$$

which is an all_reduce (the collective sums the $g_i$; dividing by $N$ gives the average, and DDP folds this division into the same step).

All ranks then update $\theta \leftarrow \theta - \eta g$.
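
Conceptually, this is what DDP does after every backward pass. A hedged sketch (real DDP buckets parameters and overlaps these all_reduces with the backward computation):

import torch.distributed as dist

def average_gradients(model, group=None):
    # Computes g = (1/N) * sum_i g_i in place on every rank's .grad buffers.
    world_size = dist.get_world_size(group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
            p.grad.div_(world_size)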

πŸ”Ή Characteristics

  • Memory: O(#params) per GPU.
  • Communication: one all_reduce per backward step.
  • Scales nearly linearly, as long as the model fits in a single GPU's memory.
  • PyTorch APIs:
    • torch.nn.parallel.DistributedDataParallel (DDP)
    • torch.distributed.fsdp (FSDP: sharded weights/gradients/optimizer states).

DP is ideal when the model fits on one GPU but you want faster training.


4. Model / Tensor Parallelism (TP)

TODO: add more math here!

Idea: Split the model's parameters (and computations) across ranks.
In this method, we slice weights and activations of an individual layer across ranks.

Instead of each GPU holding the entire weight matrix $W \in \mathbb{R}^{m\times n}$, you can split it across devices:

$$W = [W_1, W_2, \dots, W_p]$$

Each GPU $i$ computes $y_i = x W_i$; the partial outputs are then combined, e.g. concatenated via all_gather (column split) or summed via all_reduce / reduce_scatter (row split).

πŸ”Ή Example: Linear Layer

Input x  (batch Γ— d_model)
Weight W (d_model Γ— 4d_model)
  • Column parallel: each GPU stores a slice of W's columns β†’ partial outputs are concatenated (see the sketch below).
  • Row parallel: each GPU stores a slice of W's rows β†’ partial outputs are summed with an all_reduce.
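
A forward-only sketch of the column-parallel case (the class name is illustrative; a trainable version needs a differentiable gather such as torch.distributed.nn.all_gather, or DTensor):

import torch
import torch.distributed as dist

class ColumnParallelLinear(torch.nn.Module):
    # Each rank owns one column slice W_i of the full weight matrix.
    def __init__(self, d_model, d_out, group=None):
        super().__init__()
        self.group = group
        self.world = dist.get_world_size(group)
        assert d_out % self.world == 0
        self.weight = torch.nn.Parameter(torch.randn(d_model, d_out // self.world))

    def forward(self, x):
        y_local = x @ self.weight                              # (batch, d_out / p)
        shards = [torch.empty_like(y_local) for _ in range(self.world)]
        dist.all_gather(shards, y_local, group=self.group)     # collect every column shard
        return torch.cat(shards, dim=-1)                       # (batch, d_out)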

πŸ”Ή Communication Pattern

  • Forward: all_gather outputs from shards.
  • Backward: reduce_scatter gradients to shards.
  • Common collectives: all_reduce, all_gather, reduce_scatter.

πŸ”Ή Characteristics

| Metric | TP Behavior |
| --- | --- |
| Memory | Each GPU stores ~1/p of the parameters. |
| Compute | Each GPU does ~1/p of the matmul. |
| Communication | Per layer, on every forward and backward pass. |
| Efficiency | Increases GPU utilization for massive layers. |

PyTorch: torch.distributed.tensor.parallel.parallelize_module using DTensor and DeviceMesh.


5. Pipeline Parallelism (PP)

Idea: Split the layers of a model across ranks. Each rank is responsible for a contiguous block of layers (a pipeline stage).

πŸ”Ή Example

Rank 0 β†’ Layers [1–4]
Rank 1 β†’ Layers [5–8]
Rank 2 β†’ Layers [9–12]

You split each batch into microbatches and "stream" them through the pipeline.

time β†’
 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
 β”‚ Rank0 (1–4)  β”‚ Rank1 (5–8)  β”‚ Rank2 (9–12) β”‚
 β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
 β”‚ mb1 forward  β”‚              β”‚              β”‚
 β”‚ mb2 forward  β”‚ mb1 forward  β”‚              β”‚
 β”‚ mb3 forward  β”‚ mb2 forward  β”‚ mb1 forward  β”‚

Streaming microbatches reduces the idle time ("pipeline bubbles") and keeps all stages busy.
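
A naive GPipe-style forward loop for one stage makes the pattern concrete (a hedged sketch, not torch.distributed.pipelining; assumes one stage per rank and a fixed, known activation shape):

import torch
import torch.distributed as dist

def pipeline_forward(stage, microbatches, act_shape, device):
    rank, world = dist.get_rank(), dist.get_world_size()
    outputs = []
    for mb in microbatches:                  # on later ranks only the count matters
        if rank == 0:
            x = mb.to(device)
        else:
            x = torch.empty(act_shape, device=device)
            dist.recv(x, src=rank - 1)       # receive activations from the previous stage
        y = stage(x)
        if rank < world - 1:
            dist.send(y, dst=rank + 1)       # hand activations to the next stage
        else:
            outputs.append(y)                # the last stage keeps the results
    return outputs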

πŸ”Ή Communication

  • Point-to-point send/recv of activations and gradients between consecutive pipeline ranks.
  • Works best with equal per-stage compute time.

πŸ”Ή Characteristics

| Metric | PP Behavior |
| --- | --- |
| Memory | Each GPU holds only a fraction of the layers. |
| Communication | Between consecutive pipeline stages. |
| Latency | High startup overhead ("bubbles"). |
| Throughput | Improves with more microbatches. |

PyTorch: torch.distributed.pipelining, or PiPPy for automated partitioning.


6. Context / Sequence Parallelism (SP, CP)

Used primarily in large context transformers (e.g., 100k–1M tokens).

πŸ”Ή Sequence Parallelism (SP)

Split activations (and lightweight per-token ops like LayerNorm, dropout) along the sequence length dimension.

  • Each GPU processes a slice of tokens.
  • Operations that mix information across tokens (most importantly attention) require communication across the sequence-parallel group, e.g. all_gather / reduce_scatter at the boundaries of the tensor-parallel region.

πŸ”Ή Context Parallelism (CP)

A more general scheme in which inputs and activations are partitioned along the sequence dimension end to end, and attention is computed by passing key/value blocks around a ring so that each rank attends over the full sequence while only ever holding its own shard plus one in-flight block.
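
To make the ring idea concrete, here is a hedged sketch (tensor shapes assumed to be (batch, heads, seq_shard, head_dim), no causal masking; real kernels fuse the exchange with an online softmax so the full K/V is never materialized on any rank):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def context_parallel_attention(q_local, k_local, v_local):
    rank, world = dist.get_rank(), dist.get_world_size()
    k_blocks, v_blocks = [k_local], [v_local]
    k_send, v_send = k_local.contiguous(), v_local.contiguous()
    for _ in range(world - 1):
        k_recv, v_recv = torch.empty_like(k_send), torch.empty_like(v_send)
        # Rotate the K/V blocks one hop around the ring.
        ops = [dist.P2POp(dist.isend, k_send, (rank + 1) % world),
               dist.P2POp(dist.irecv, k_recv, (rank - 1) % world),
               dist.P2POp(dist.isend, v_send, (rank + 1) % world),
               dist.P2POp(dist.irecv, v_recv, (rank - 1) % world)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        k_blocks.append(k_recv)
        v_blocks.append(v_recv)
        k_send, v_send = k_recv, v_recv
    # Local queries attend over keys/values from the entire sequence.
    k_full, v_full = torch.cat(k_blocks, dim=2), torch.cat(v_blocks, dim=2)
    return F.scaled_dot_product_attention(q_local, k_full, v_full)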

πŸ”Ή Benefits

  • Reduces activation memory linearly with the number of ranks in the SP/CP group.
  • Enables ultra-long sequences without increasing per-GPU memory.
  • Communication: ring-based attention exchange (e.g., Ring Attention, as used by Megatron-LM's context parallelism).

PyTorch (unstable): Context Parallelism tutorial, built on DTensor + custom attention kernels.


7. Expert Parallelism (EP) β€” Mixture of Experts (MoE)

In MoE layers, only a subset of "experts" (feed-forward networks) are active per token.

πŸ”Ή Structure

Tokens β†’ Router β†’ (Top-k Experts) β†’ Combine outputs

If you have E experts and G GPUs:

  • Each GPU holds one or more experts.
  • Tokens are dynamically routed to the right expert via an all-to-all exchange.

πŸ”Ή Steps

  1. Routing: gate network computes softmax scores β†’ top-k expert IDs.
  2. Dispatch: all-to-all send of token embeddings to corresponding expert ranks.
  3. Local Compute: each GPU processes its local tokens.
  4. Combine: another all-to-all to merge results back.

πŸ”Ή Communication

Two all_to_all ops per MoE layer.

A load-balancing auxiliary loss is added to encourage uniform expert usage.
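
A hedged sketch of the dispatch β†’ compute β†’ combine flow, assuming one expert per rank and top-1 routing (capacity limits, top-k, and gate-score weighting omitted; router and expert stand in for ordinary modules):

import torch
import torch.distributed as dist

def moe_layer(tokens, router, expert, group=None):
    world = dist.get_world_size(group)
    expert_ids = router(tokens).argmax(dim=-1)       # top-1 expert (== rank) per token

    # Sort tokens by destination rank so every rank receives one contiguous chunk.
    order = torch.argsort(expert_ids)
    tokens_sorted = tokens[order]
    send_counts = torch.bincount(expert_ids, minlength=world)

    # First, tell every rank how many tokens it will receive.
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=group)
    send_counts, recv_counts = send_counts.tolist(), recv_counts.tolist()

    # Dispatch: all_to_all #1 moves each token to the rank hosting its expert.
    recv_buf = tokens.new_empty(sum(recv_counts), tokens.shape[-1])
    dist.all_to_all_single(recv_buf, tokens_sorted,
                           output_split_sizes=recv_counts,
                           input_split_sizes=send_counts, group=group)

    out_local = expert(recv_buf)                     # local expert processes its tokens

    # Combine: all_to_all #2 returns the results to the tokens' origin ranks.
    back_buf = tokens.new_empty(tokens_sorted.shape)
    dist.all_to_all_single(back_buf, out_local,
                           output_split_sizes=send_counts,
                           input_split_sizes=recv_counts, group=group)

    out = torch.empty_like(back_buf)
    out[order] = back_buf                            # undo the sort
    return out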

PyTorch: MoE tutorial + blog on Expert Parallelism with DTensor.


8. Additional Forms

| Parallelism | What is Split | Common in |
| --- | --- | --- |
| Optimizer State / ZeRO | Parameters, gradients, and optimizer states across data-parallel ranks | DeepSpeed, FSDP |
| Activation Checkpointing | Time (compute) vs. memory trade-off, not true parallelism (see the sketch below) | All large models |
| Hybrid Parallelism (3D) | Combine DP + TP + PP (sometimes + EP + CP) | Megatron-LM, GPT-4, LLaMA-3 |
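
For the activation-checkpointing row, a minimal sketch with torch.utils.checkpoint (the wrapper class is illustrative): intermediate activations of the wrapped block are dropped during the forward pass and recomputed during backward.

import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    # Trades extra compute (a second forward of `block`) for activation memory.
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        return checkpoint(self.block, x, use_reentrant=False)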

9. How They Combine (3D or 4D Parallelism)

A "device mesh" formalism expresses these combinations:

$$\text{mesh shape} = (\text{DP}, \text{TP}, \text{PP}, \text{EP}, \text{CP})$$

Each dimension corresponds to a process group.

A single GPU (rank) belongs to one group per dimension.

Example: 384 GPUs

  • DP=4, TP=2, PP=12, EP=4 β†’ (4 Γ— 2 Γ— 12 Γ— 4 = 384)
  • Within each DP group β†’ gradient sync (all_reduce).
  • Within TP group β†’ tensor shards (reduce_scatter/all_gather).
  • Within PP group β†’ pipeline sends/recvs.
  • Within EP group β†’ MoE all-to-all routing.

This n-dimensional sharding view is exactly what PyTorch's DeviceMesh / DTensor abstracts.
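
With DeviceMesh this looks roughly as follows (a sketch for the 384-GPU example above; get_group returns the process group for one mesh dimension):

from torch.distributed.device_mesh import init_device_mesh

# 4-D mesh for the example above: DP=4, TP=2, PP=12, EP=4 (4 Γ— 2 Γ— 12 Γ— 4 = 384 ranks).
mesh = init_device_mesh("cuda", (4, 2, 12, 4), mesh_dim_names=("dp", "tp", "pp", "ep"))

dp_group = mesh.get_group("dp")   # gradient all_reduce happens here
tp_group = mesh.get_group("tp")   # all_gather / reduce_scatter of tensor shards
pp_group = mesh.get_group("pp")   # send/recv between pipeline stages
ep_group = mesh.get_group("ep")   # MoE all_to_all routing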


10. Underlying Communication Primitives

| Collective | Function | Used in |
| --- | --- | --- |
| all_reduce | Every rank gets the sum (or mean) of all inputs. | DDP, gradient averaging |
| reduce_scatter | Split the reduced result among ranks. | ZeRO, FSDP |
| all_gather | Gather tensors from all ranks. | TP forward pass |
| broadcast | One rank sends a tensor to all. | Initialization |
| all_to_all | Every rank exchanges a distinct chunk with every other rank (generalized scatter/gather). | MoE routing |
| send/recv | Point-to-point communication. | Pipeline parallelism |

These are implemented by backends such as NCCL, which uses GPU-direct DMA over NVLink and InfiniBand for efficiency.
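
One identity worth knowing: an all_reduce can be decomposed into a reduce_scatter followed by an all_gather at the same total bandwidth, which is what lets ZeRO/FSDP keep gradients and parameters sharded between those two steps. A hedged sketch (assumes the flattened tensor's length is divisible by the world size):

import torch
import torch.distributed as dist

def sharded_all_reduce(flat):
    world = dist.get_world_size()
    shard = torch.empty(flat.numel() // world, dtype=flat.dtype, device=flat.device)
    dist.reduce_scatter_tensor(shard, flat)     # each rank now owns one reduced shard
    dist.all_gather_into_tensor(flat, shard)    # reassemble the fully reduced tensor
    return flat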


11. Comparative Table

| Type | Split Axis | Communication | Memory Savings | Typical Use |
| --- | --- | --- | --- | --- |
| Data | Data batch | all_reduce grads | None | Model fits on one GPU |
| Tensor | Within layer (weights/acts) | all_gather / reduce_scatter | Yes | Huge layers (β‰₯10B params) |
| Pipeline | Layer depth | send/recv activations | Yes | Very deep networks |
| Context | Sequence length | ring all_gather / reduce_scatter | Yes | Long-context transformers |
| Expert (MoE) | Experts | all_to_all | Sparse compute | Scaling params efficiently |
| ZeRO/FSDP | States/params | reduce_scatter / all_gather | Yes | Memory optimization |

12. Conceptual Summary

Every parallelism strategy is a trade-off between what you replicate and what you shard.

| You replicate… | You shard… | You gain… |
| --- | --- | --- |
| Model weights | Data batches | Simplicity, near-linear speedup (DP) |
| Data (within the group) | Parameters (layers or tensors) | Fit bigger models (TP/PP) |
| Weights | Activations along the sequence | Longer contexts (CP/SP) |
| Non-expert layers | Experts (and route tokens to them) | Sparse efficiency (EP) |
| As little as possible | Optimizer states, gradients, parameters | Memory scaling (ZeRO/FSDP) |

13. Modern PyTorch Stack Summary

PyTorch has converged all these under DTensor and DeviceMesh:

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

# One mesh dimension per parallelism axis; tp, dp, pp are the group sizes.
mesh = init_device_mesh("cuda", (tp, dp, pp), mesh_dim_names=("tp", "dp", "pp"))
# Shard dim 0 of `tensor` across the "tp" dimension, replicate across "dp" and "pp".
sharded = distribute_tensor(tensor, mesh, placements=[Shard(0), Replicate(), Replicate()])

Higher-level integrations:

  • DDP / FSDP β†’ Data & state parallelism
  • TP / PP β†’ torch.distributed.tensor.parallel & torch.distributed.pipelining
  • EP (MoE) β†’ custom DTensor layouts + all_to_all routing
  • CP β†’ experimental ring-based attention kernels

14. Visual Overview

                   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
                   β”‚        GLOBAL MODEL          β”‚
                   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                          β”‚
         β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
         β–Ό                β–Ό                β–Ό
     [DP split]       [Model split]     [Sequence split]
   data batches β†’    layers/tensors β†’   tokens β†’

   all_reduce grads   all_gather acts    ring gather
         β”‚                β”‚                β”‚
         β–Ό                β–Ό                β–Ό
   Combine gradients   Combine outputs   Combine context

🧩 15. How They Work Together (Example: GPT-4-like system)

| Dimension | Example Value | Notes |
| --- | --- | --- |
| Data | 4 | DDP/FSDP for gradient sync across nodes |
| Tensor | 2 | Shard linear weights across 2 GPUs |
| Pipeline | 8 | Each stage = a contiguous block of transformer layers |
| Context | 1 | Optional sequence sharding |
| Expert | 4 | 4 experts per MoE layer |

Total GPUs = (4Γ—2Γ—8Γ—4 = 256).

Each rank participates in multiple groups simultaneously.


πŸ”š 16. Takeaway

Parallelism β‰  duplication β€” it's a structured partition of the model, data, and computation graph to achieve scale.

  • Data parallelism handles samples.
  • Tensor/model parallelism handles parameters.
  • Pipeline parallelism handles depth.
  • Context/sequence parallelism handles length.
  • Expert parallelism handles sparsity.
  • ZeRO/FSDP handles optimizer state.

Each introduces different synchronization patterns, all unified under PyTorch's distributed collectives and mesh abstractions.