When GPT-4 was released in 2023, rumors suggested it contained over 1.7 trillion parameters. Training such a model requires something on the order of 25,000 A100 GPUs running for months, a feat that would be impossible without sophisticated distributed training systems. The challenge isn't merely computational; it's fundamentally a memory problem. A single 80GB A100 GPU can just fit the BF16 weights of a 40B-parameter model, with nothing left for gradients, optimizer states, or activations; with full training state, even a 7B model overflows it. This is the story of how systems researchers cracked the memory wall through a decade of innovations in data parallelism, ZeRO, tensor parallelism, and pipeline parallelism.

The Memory Consumption Breakdown

Understanding distributed training begins with understanding where GPU memory goes. During training, memory is consumed by four primary components:

Model Parameters: The weights of the neural network. For a model with $\Psi$ parameters stored in BF16 (2 bytes each), this requires $2\Psi$ bytes.

Gradients: The derivatives computed during backpropagation, also stored in BF16, adding another $2\Psi$ bytes.

Optimizer States: This is where memory consumption explodes. Mixed-precision training with Adam keeps an FP32 master copy of the weights plus two moment tensors per parameter (the first moment, or momentum, and the second moment, or variance), all stored in FP32 for numerical stability. This adds $12\Psi$ bytes (4 bytes × 3 tensors × $\Psi$ parameters).

Activations: Intermediate outputs saved during the forward pass for gradient computation. For a transformer with $L$ layers, hidden dimension $h$, $a$ attention heads, sequence length $s$, and batch size $b$, activation memory (without recomputation) scales approximately as $sbh\left(34 + 5\frac{as}{h}\right)$ bytes per layer.

For a 7B parameter model trained with mixed-precision Adam:

  • Parameters (BF16): 14 GB
  • Gradients (BF16): 14 GB
  • Optimizer states (FP32 master weights + momentum + variance): 84 GB
  • Total model states: 112 GB (before activations)

This explains why a 7B model can’t train on a single 80GB GPU—the optimizer states alone exceed available memory. Now extrapolate to 175B parameters: model states alone require 2.8 TB.
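These byte counts are easy to sanity-check in code. A minimal sketch, assuming BF16 parameters and gradients plus FP32 Adam state including a master copy of the weights (the function name is illustrative):

```python
def training_memory_gb(psi, bytes_param=2, bytes_grad=2, bytes_opt=12):
    """Model-state memory in GB, excluding activations.

    bytes_opt=12 assumes FP32 master weights (4) + momentum (4) + variance (4).
    """
    return psi * (bytes_param + bytes_grad + bytes_opt) / 1e9

print(training_memory_gb(7e9))    # 112.0 GB for a 7B model
print(training_memory_gb(175e9))  # 2800.0 GB (2.8 TB) for 175B
```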

Data Parallelism: The Starting Point

Data parallelism (DP) is conceptually simple: replicate the model across $N_d$ GPUs, split the batch into $N_d$ mini-batches, compute gradients independently, then synchronize via AllReduce.

# Pseudocode for data parallelism
for each GPU i:
    local_batch = batch[i::N_d]
    loss = model(local_batch)
    gradients = backward(loss)
    
# Synchronize gradients across all GPUs
gradients = all_reduce(gradients, average=True)
optimizer.step(gradients)

The problem? Every GPU holds a complete copy of parameters, gradients, and optimizer states. Memory per device is constant regardless of $N_d$: a 100B model demands the same ~1.6 TB of model states on every GPU, impossible for any single device.
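The AllReduce step in the pseudocode above can be mimicked on plain Python lists. A toy stand-in for the real collective, with simulated GPUs rather than NCCL:

```python
def all_reduce_mean(per_gpu_grads):
    """Elementwise mean across simulated GPUs: the result of an averaging AllReduce."""
    n = len(per_gpu_grads)
    length = len(per_gpu_grads[0])
    return [sum(g[i] for g in per_gpu_grads) / n for i in range(length)]

# Four "GPUs", each holding gradients computed on its own mini-batch:
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(all_reduce_mean(grads))  # [4.0, 5.0]
```

Every GPU then applies the identical averaged gradient, which is what keeps the $N_d$ model replicas in sync.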

ZeRO: Eliminating Redundancy Across Data Parallel Processes

Microsoft’s Zero Redundancy Optimizer (ZeRO), introduced in 2019, recognized a critical insight: data parallelism replicates identical optimizer states across every GPU. Why not partition them instead?

Stage 1: Optimizer State Partitioning (Pos)

ZeRO-1 partitions optimizer states across $N_d$ GPUs. Each GPU stores only $\frac{1}{N_d}$ of the optimizer states while maintaining full copies of parameters and gradients.

Memory per device: $2\Psi + 2\Psi + \frac{K\Psi}{N_d}$

Where $K=12$ for mixed-precision Adam (4 bytes each for the FP32 master weights, momentum, and variance). With $N_d=64$, optimizer states shrink from $12\Psi$ to $\frac{12\Psi}{64} = 0.1875\Psi$ bytes per GPU.

Result: 4× memory reduction, identical communication volume to standard DP.

Stage 2: Adding Gradient Partitioning (Pos+g)

ZeRO-2 extends partitioning to gradients. Since each GPU only updates its portion of optimizer states, it only needs the corresponding gradient shard. During backpropagation, gradients are reduce-scattered (each GPU receives only the shard it owns) instead of all-reduced.

Memory per device: $2\Psi + \frac{2\Psi}{N_d} + \frac{K\Psi}{N_d}$

Result: 8× memory reduction, still no additional communication overhead.

Stage 3: Full Parameter Partitioning (Pos+g+p)

ZeRO-3 completes the transformation by partitioning parameters themselves. Each GPU holds only $\frac{1}{N_d}$ of parameters. Before forward or backward passes, GPUs perform AllGather to collect needed parameters, then immediately discard them after use.

Memory per device: $\frac{2\Psi + 2\Psi + K\Psi}{N_d} = \frac{16\Psi}{N_d}$

Result: Linear memory reduction with $N_d$. A trillion-parameter model ($\Psi=10^{12}$) requires 16 TB total memory; with 1024 GPUs, each needs only 16 GB.

The trade-off? Communication volume increases by 50% due to additional AllGather operations during forward and backward passes.
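The per-stage formulas can be folded into one small calculator. A sketch following the $2\Psi + 2\Psi + K\Psi$ accounting with $K=12$; the function name and stage numbering are ours:

```python
def zero_memory_per_gpu_gb(psi, n_d, stage, k=12):
    """Per-GPU model-state memory (GB) under ZeRO, following 2Ψ + 2Ψ + KΨ."""
    if stage == 0:    # plain data parallelism: everything replicated
        bytes_per_param = 2 + 2 + k
    elif stage == 1:  # optimizer states sharded
        bytes_per_param = 2 + 2 + k / n_d
    elif stage == 2:  # gradients sharded too
        bytes_per_param = 2 + (2 + k) / n_d
    else:             # stage 3: parameters sharded as well
        bytes_per_param = (2 + 2 + k) / n_d
    return psi * bytes_per_param / 1e9

# A trillion-parameter model on 1024 GPUs with ZeRO-3:
print(zero_memory_per_gpu_gb(1e12, 1024, 3))  # 15.625 GB per GPU
```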

# ZeRO-3 forward pass pseudocode
def forward(x):
    for layer in model:
        # AllGather parameters for this layer
        params = all_gather(sharded_params[layer])
        
        # Compute forward pass
        x = layer(x, params)
        
        # Immediately free parameters
        free(params)
    
    return x

PyTorch FSDP: ZeRO for the Masses

PyTorch’s Fully Sharded Data Parallel (FSDP), released in 2022, brings ZeRO-style sharding to the broader community. FSDP wraps modules and automatically handles parameter sharding, gathering, and freeing.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = TransformerModel(...)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3-style full sharding
    device_id=torch.cuda.current_device(),
)

FSDP introduced several innovations beyond ZeRO:

FlatParameter Design: Parameters within an FSDP unit are flattened into a single contiguous tensor, so each unit issues one large AllGather or ReduceScatter instead of many small per-tensor transfers.

Compute-Communication Overlap: While computing forward pass for layer $i+1$, FSDP asynchronously AllGathers parameters for layer $i+2$, hiding communication latency.

Mixed Precision Sharding: FSDP can shard parameters in FP32 while computing in BF16, balancing memory efficiency with numerical stability.

Tensor Parallelism: Intra-Layer Model Parallelism

While ZeRO partitions state across data-parallel processes, tensor parallelism (TP) partitions individual layers across GPUs. NVIDIA’s Megatron-LM introduced the canonical approach for transformers.

Consider a linear layer $Y = XW$ where $X \in \mathbb{R}^{b \times d_{in}}$ and $W \in \mathbb{R}^{d_{in} \times d_{out}}$.

Column Parallel

Partition $W$ along columns: $W = [W_1, W_2, ..., W_N]$ where each $W_i$ has shape $d_{in} \times \frac{d_{out}}{N}$.

Each GPU $i$ computes $Y_i = XW_i$. Results are concatenated: $Y = [Y_1, Y_2, ..., Y_N]$.

No communication needed during forward pass—each GPU produces a slice of the output.

Row Parallel

Partition $W$ along rows: $W = [W_1^T, W_2^T, ..., W_N^T]^T$ where each $W_i$ has shape $\frac{d_{in}}{N} \times d_{out}$.

Each GPU $i$ receives $X_i$ (the $i$-th column split of $X$) and computes $Y_i = X_iW_i$.

Results must be summed via AllReduce: $Y = \sum_{i} Y_i$.
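Both schemes are easy to verify numerically. A toy check in pure Python with two simulated GPUs, showing that column-parallel concatenation and row-parallel summation each reproduce the full matmul (the helper names are made up for this sketch):

```python
def matmul(A, B):
    """Plain-Python matrix multiply."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def split_cols(M, n):
    """Split a matrix into n column slices."""
    w = len(M[0]) // n
    return [[row[i * w:(i + 1) * w] for row in M] for i in range(n)]

def split_rows(M, n):
    """Split a matrix into n row slices."""
    h = len(M) // n
    return [M[i * h:(i + 1) * h] for i in range(n)]

X = [[1.0, 2.0], [3.0, 4.0]]                       # batch=2, d_in=2
W = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 2.0]]   # d_in=2, d_out=4
Y = matmul(X, W)                                   # reference result

# Column parallel: each "GPU" gets a column slice of W; concatenate outputs.
slices = [matmul(X, Wi) for Wi in split_cols(W, 2)]
Y_col = [sum(rows, []) for rows in zip(*slices)]

# Row parallel: each "GPU" gets a row slice of W and a column slice of X;
# the partial products are summed (the AllReduce).
partials = [matmul(Xi, Wi) for Xi, Wi in zip(split_cols(X, 2), split_rows(W, 2))]
Y_row = [[sum(p[r][c] for p in partials) for c in range(4)] for r in range(2)]

print(Y == Y_col == Y_row)  # True
```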

The Megatron Transformer Block

Megatron combines column and row parallel to minimize communication:

Input X
   ↓
[Column Parallel Linear] → No communication
   ↓
[GELU Activation]
   ↓
[Row Parallel Linear] → AllReduce
   ↓
Output Y

For a transformer’s MLP block with two linear layers:

  1. First layer (Column Parallel): Each GPU computes partial output, no communication
  2. Second layer (Row Parallel): Each GPU has partial input, AllReduce at the end

This design requires only one AllReduce per transformer block, making TP communication-efficient within a node where NVLink provides 600 GB/s bandwidth.

class ColumnParallelLinear(nn.Module):
    def forward(self, x):
        # x: [batch, d_in] (replicated on every GPU)
        # self.weight: [d_in, d_out // N]
        out = x @ self.weight  # local computation
        return out  # no sync needed; each GPU holds a column slice of Y

class RowParallelLinear(nn.Module):
    def forward(self, x):
        # x: [batch, d_in // N] (this GPU's column slice of the input)
        # self.weight: [d_in // N, d_out]
        out = x @ self.weight  # partial sums over this GPU's slice of d_in
        out = all_reduce(out, op=torch.distributed.ReduceOp.SUM)
        return out

Pipeline Parallelism: Inter-Layer Model Parallelism

Pipeline parallelism (PP) partitions the model’s layers across GPUs. GPU 0 holds layers 1-4, GPU 1 holds layers 5-8, and so on. Data flows sequentially through the pipeline.

The challenge: How do we keep all GPUs busy when there’s sequential dependency between stages?

GPipe: Micro-batching

GPipe divides each batch into $M$ micro-batches. Each micro-batch flows through the pipeline sequentially.

Time:  0  1  2  3  4  5  6  7
GPU 0: F1 F2 F3 F4 -- -- -- --
GPU 1: -- F1 F2 F3 F4 -- -- --
GPU 2: -- -- F1 F2 F3 F4 -- --
GPU 3: -- -- -- F1 F2 F3 F4 --

Where F$i$ = forward pass of micro-batch $i$. Backward passes occur after all forwards complete.

The “bubble” (idle time) adds roughly $\frac{P-1}{M}$ on top of the useful compute time, where $P$ is pipeline depth and $M$ the number of micro-batches; equivalently, it is a fraction $\frac{P-1}{M+P-1}$ of total wall-clock time. With $M=8$ micro-batches and $P=4$ stages, bubble overhead is $3/8 = 37.5\%$ of compute time.
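The two ways of quoting the bubble cost can be computed side by side; a small sketch (the function names are ours):

```python
def bubble_overhead(p, m):
    """GPipe idle time as a fraction of useful compute time: (P-1)/M."""
    return (p - 1) / m

def bubble_fraction(p, m):
    """The same idle time as a fraction of total wall-clock time."""
    return (p - 1) / (m + p - 1)

print(bubble_overhead(4, 8))   # 0.375: the 37.5% overhead figure
print(bubble_fraction(4, 8))   # ~0.273 of total time
```

Either way, the lesson is the same: more micro-batches shrink the bubble, which is why $M \gg P$ is the standard recommendation.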

1F1B: One Forward, One Backward

PipeDream’s 1F1B schedule reduces bubbles by interleaving forward and backward passes:

Time:  0  1  2  3  4  5  6  7  8  9  10
GPU 0: F1 F2 F3 F4 -- -- -- B1 F5 B2 F6
GPU 1: -- F1 F2 F3 -- -- B1 F4 B2 F5 --
GPU 2: -- -- F1 F2 -- B1 F3 B2 F4 -- --
GPU 3: -- -- -- F1 B1 F2 B2 F3 B3 -- --

(Simplified: forward and backward are drawn as equal-length slots, and each backward must wait for the downstream stage's backward; in practice a backward pass takes roughly twice as long as a forward pass.)

Each GPU alternates between forward and backward passes for different micro-batches, maintaining steady throughput after an initial warmup phase.

PipeFill: Utilizing Bubble Time

Recent research like PipeFill (2024) proposes filling bubbles with independent auxiliary work, such as evaluation jobs or training smaller models, reclaiming a substantial share of otherwise idle GPU time at large pipeline depths.

3D Parallelism: The Production Standard

Modern LLM training combines all three: Data Parallelism × Tensor Parallelism × Pipeline Parallelism.

Consider training a 175B model on 1,024 GPUs:

  • Tensor Parallelism ($t=8$): each transformer layer split across the 8 GPUs within a node
  • Pipeline Parallelism ($p=4$): the model’s 96 layers split into 4 stages of 24 layers each
  • Data Parallelism ($d=32$): 32 replicas of the entire model
Total GPUs = t × p × d = 8 × 4 × 32 = 1,024

Memory per GPU:
- Parameters: 175B × 2 bytes / (t × p) = 350 GB / 32 ≈ 11 GB
- With ZeRO-3 across the data-parallel dimension: reduced by up to a further factor of d=32
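The layout arithmetic generalizes to any $(t, p, d)$ triple. A sketch counting only raw BF16 parameters (the function name is illustrative):

```python
def parallel_layout(total_params, t, p, d, bytes_per_param=2):
    """GPU count and per-GPU parameter memory (GB) for a (TP, PP, DP) layout.

    Counts only raw BF16 parameters; gradients, optimizer states, and
    activations come on top (and can be sharded further with ZeRO).
    """
    n_gpus = t * p * d
    params_per_gpu = total_params / (t * p)  # DP replicates; TP and PP divide
    return n_gpus, params_per_gpu * bytes_per_param / 1e9

gpus, gb = parallel_layout(175e9, t=8, p=4, d=32)
print(gpus, gb)  # 1024 GPUs, ~10.9 GB of parameters per GPU
```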

This is the architecture behind GPT-3, LLaMA, and virtually every frontier model.

Expert Parallelism: Scaling Mixture-of-Experts

Mixture-of-Experts (MoE) models like Mixtral-8×7B or DeepSeek-V3 introduce a fourth dimension: Expert Parallelism (EP).

In MoE, each token is routed to a subset of experts (typically 2 out of 8). Experts are distributed across GPUs:

# Expert parallelism pseudocode
def moe_forward(x, experts, router):
    # Compute routing decisions
    router_output = router(x)  # [batch, seq, num_experts]
    topk_indices = torch.topk(router_output, k=2, dim=-1).indices
    
    # All-to-All communication: send tokens to their expert GPUs
    x_dispatched = all_to_all(x, topk_indices)
    
    # Local expert computation
    expert_outputs = [experts[i](x_dispatched[i]) for i in local_experts]
    
    # All-to-All: route outputs back to original GPU
    x_combined = all_to_all(expert_outputs, reverse=True)
    
    return x_combined

The All-to-All communication pattern—where each GPU sends different data to every other GPU—is EP’s bottleneck. NVIDIA’s H100 provides 900 GB/s of NVLink bandwidth per GPU, and NVLink Switch extends that fabric across nodes, enabling EP at scale.
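Top-k routing is simple to simulate, and doing so shows why All-to-All payloads are data-dependent: per-expert token counts vary from batch to batch. A toy sketch in pure Python (not the actual Mixtral router):

```python
import random

def route_top2(scores):
    """Return the indices of the two highest-scoring experts for one token."""
    return sorted(range(len(scores)), key=lambda e: scores[e], reverse=True)[:2]

random.seed(0)
num_experts, num_tokens = 8, 32
counts = [0] * num_experts  # tokens dispatched to each expert
for _ in range(num_tokens):
    scores = [random.random() for _ in range(num_experts)]  # stand-in for router logits
    for e in route_top2(scores):
        counts[e] += 1

# Every token lands on exactly 2 experts, but the per-expert load
# (and hence each GPU's All-to-All payload) is uneven:
print(sum(counts))  # 64
```

This load imbalance is why production MoE systems add auxiliary load-balancing losses and capacity limits on top of the raw top-k routing shown here.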

Communication Optimization: ZeRO++ and Beyond

ZeRO-3’s communication overhead becomes significant at scale. ZeRO++ (2023) introduces three techniques to reduce communication by 4×:

Quantized Weight Communication (qwZ): AllGather parameters in block-quantized INT8 instead of BF16, halving bandwidth requirements. Quantization applies only to the communicated copies; the FP32 master weights are untouched.

Hierarchical Weight Partitioning (hpZ): Keep a secondary copy of the full parameter shards within each node, so the backward-pass AllGather runs over fast intra-node links instead of the cross-node network, trading a modest amount of extra memory for much cheaper communication.

Quantized Gradient Communication (qgZ): Apply quantization to gradient reduce-scatter, further cutting backward pass communication.

NCCL, NVIDIA’s collective communication library, employs ring-based AllReduce within nodes (maximizing NVLink bandwidth) and tree-based algorithms across nodes (minimizing latency hops). Understanding these algorithms helps debug distributed training bottlenecks.

Offloading: Breaking the GPU Memory Wall

When even ZeRO-3 can’t fit your model, offloading moves state to CPU memory or NVMe SSDs.

ZeRO-Offload: Moves optimizer states and the optimizer update itself to the CPU. The GPU performs forward/backward passes while the CPU updates optimizer states in parallel; this was enough to train 13B-parameter models on a single 32GB V100, at the cost of CPU-GPU transfer overhead.

ZeRO-Infinity: Extends offloading to NVMe SSDs. With modern NVMe (7 GB/s), even trillion-parameter models become trainable on modest clusters—theoretically, 16 TB of model states fit on a single SSD array.

The trade-off is throughput: ZeRO-Offload achieves ~30% of GPU-only throughput, while ZeRO-Infinity drops to 10-15%. These are last-resort options when memory, not speed, is the constraint.

Practical Guidelines: Choosing Your Strategy

Model Size    GPUs        Recommended Strategy
< 1B          1-8         FSDP with ZeRO-2
1-10B         8-64        FSDP with ZeRO-3
10-70B        64-512      3D Parallelism (TP=8, PP=4, DP)
70B-175B      512-2048    3D + ZeRO-3
175B+         2000+       3D + ZeRO-3 + ZeRO++
MoE (100B+)   512+        3D + Expert Parallelism

Key heuristics:

  • Tensor Parallelism: Use within-node only (NVLink required). $t \leq 8$ for 8-GPU nodes.
  • Pipeline Parallelism: Use across nodes. $p$ should divide evenly into number of layers.
  • Sequence Parallelism: For sequences > 8K tokens, add sequence parallelism to handle attention’s $O(n^2)$ memory.
  • Offloading: Only when absolutely necessary; expect 3-10× slowdown.

The Road Ahead

Distributed training continues evolving. FSDP2 in PyTorch 2.x improves async execution. Sequence parallelism now combines with Ring Attention for million-token contexts. And the community co-designs algorithms with hardware: FP8 training on Hopper and Blackwell, fifth-generation NVLink’s 1.8 TB/s per-GPU bandwidth, and CXL memory expansion all reshape what’s possible.

The trillion-parameter barrier was broken not by bigger GPUs, but by smarter systems. Understanding these architectures isn’t academic—it’s essential for anyone training the next generation of AI models.