Scaling and Distributed Fine-Tuning (DeepSpeed, FSDP, ZeRO)
Advanced distributed training strategies to scale fine-tuning across multiple GPUs and nodes while managing memory, communication, and fault tolerance.
6.3 ZeRO Partitions and Optimizations
ZeRO Partitions and Optimizations — The Memory-Slaying Spellbook
"If your GPU memory is a tiny apartment and your model is a hoarder, ZeRO is the minimalist intervention you desperately need." — Probably a very caffeinated researcher
You're already comfortable with the difference between data and model parallelism (we talked about that in 6.2), and you know the landscape of distributed training architectures (shoutout to 6.1). You may also have experimented with quantization and pruning to squeeze models smaller for inference. ZeRO is the missing Tetris master that lets you train massive models affordably by rearranging optimizer state, gradients, and parameters across devices instead of naively duplicating everything per GPU.
Quick refresher (in a sentence)
ZeRO (Zero Redundancy Optimizer) splits the heavy pieces of training state across data-parallel ranks so each GPU stores only a fraction of optimizer states, gradients, and/or parameters — dramatically lowering memory usage and enabling larger effective batch sizes or bigger models.
What are the ZeRO "partitions"? (The three magic things)
ZeRO partitions training state into three logical buckets. Each can be distributed independently:
- Optimizer state — e.g., Adam's moment estimates (m, v) and, in mixed-precision training, the fp32 master weights. These are usually the largest memory hogs.
- Gradients — the gradient tensors computed during the backward pass.
- Parameters — the model weights themselves.
Partitioning any of these removes redundancy across data-parallel replicas. You choose which to partition via ZeRO stage 1/2/3.
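To see why optimizer state is usually the biggest bucket, count bytes per parameter under standard mixed-precision Adam training (the accounting used in the ZeRO paper; exact numbers vary with your optimizer and precision choices):

```python
# Approximate bytes per model parameter with mixed-precision Adam,
# following the ZeRO paper's accounting.
PARAM_BYTES = 2          # fp16 weight
GRAD_BYTES = 2           # fp16 gradient
OPTIM_BYTES = 4 + 4 + 4  # fp32 master weight + Adam m + Adam v = 12

total = PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES  # 16 bytes per parameter
print(f"optimizer state share: {OPTIM_BYTES / total:.0%}")  # 75%
```

Three quarters of the replicated training state is optimizer bookkeeping, which is why partitioning it first (Stage 1) is such a cheap win.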
ZeRO stages: What they partition and why you care
| Stage | Partitioned state | Memory reduction | Typical use case |
|---|---|---|---|
| Stage 1 | Optimizer state | Medium | Lowest-friction win; far cheaper than full model parallelism. |
| Stage 2 | Optimizer state + Gradients | Larger | Great when gradients start dominating memory. |
| Stage 3 | Optimizer state + Gradients + Parameters | Max | Enables training truly huge models; pairs well with model parallelism and offloading. |
TL;DR: Stage 1 = partition the expensive optimizer stuff. Stage 2 = add gradients to the party. Stage 3 = partition parameters too, unlocking the largest models.
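The stages above can be turned into a back-of-the-envelope calculator. This sketch uses the ZeRO paper's accounting (2-byte fp16 params, 2-byte fp16 grads, 12-byte fp32 optimizer state per parameter) and ignores activations, buffers, and communication scratch space:

```python
def per_gpu_bytes(num_params: float, dp_degree: int, stage: int) -> float:
    """Rough per-GPU memory for model state under a given ZeRO stage.

    Each stage divides one more bucket across the dp_degree
    data-parallel ranks: stage 1 shards optimizer state, stage 2
    adds gradients, stage 3 adds the parameters themselves.
    """
    p, g, o = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 3:
        p /= dp_degree
    if stage >= 2:
        g /= dp_degree
    if stage >= 1:
        o /= dp_degree
    return p + g + o

# A 7B-parameter model on 8 GPUs:
for s in (0, 1, 2, 3):
    print(f"stage {s}: {per_gpu_bytes(7e9, 8, s) / 1e9:.1f} GB per GPU")
# baseline (stage 0) is 112 GB per GPU; stage 3 shrinks it to 14 GB.
```

The formula makes the trend obvious: each stage divides one more term by the data-parallel degree, so savings grow with cluster size.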
How it actually works (high-level plumbing)
- Instead of every GPU keeping a full copy of optimizer states/gradients/params, ZeRO slices each tensor across ranks (a.k.a. sharding). Each rank becomes responsible for a subset of the tensors.
- Communication patterns matter:
- All-gather: reconstruct full parameters when you need to compute a forward pass or apply an update that requires the full parameter (Stage 3 often needs this).
- Reduce-scatter: aggregate gradients while avoiding full replication — used to efficiently combine partial gradient contributions.
- Good implementations overlap communication and compute (e.g., reduce-scatter for gradient aggregation overlapping with backward computation), minimizing wall-clock overhead.
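The two collectives compose into an all-reduce: reduce-scatter leaves each rank holding one fully reduced shard, and a subsequent all-gather reassembles the full reduced tensor on every rank. A single-process simulation (pure Python, with lists standing in for per-GPU tensors) makes the data movement concrete:

```python
def reduce_scatter(per_rank_tensors):
    """Each rank i ends up with only the element-wise sum of shard i."""
    n = len(per_rank_tensors)                 # ranks == shards
    shard = len(per_rank_tensors[0]) // n
    return [
        [sum(t[i * shard + j] for t in per_rank_tensors) for j in range(shard)]
        for i in range(n)
    ]

def all_gather(per_rank_shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in per_rank_shards for x in shard]
    return [list(full) for _ in per_rank_shards]

# Two ranks, each holding a local gradient of length 4.
grads = [[1, 2, 3, 4], [10, 20, 30, 40]]
shards = reduce_scatter(grads)    # rank 0: [11, 22], rank 1: [33, 44]
gathered = all_gather(shards)     # both ranks: [11, 22, 33, 44]
print(shards, gathered[0])
```

This is exactly the ZeRO Stage 2 gradient path: reduce-scatter during backward (each rank keeps only the shard it owns), then no gather at all until Stage 3 parameters need reconstructing for the next forward.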
Offloading variants
- CPU offload: Move optimizer states or parameters to host memory to reduce GPU memory pressure (useful when GPU RAM < model needs).
- NVMe offload (ZeRO-Infinity): Spill to NVMe when CPU RAM is insufficient. This buys scale but adds IO latency complexity.
Performance optimizations (the things Prof. Perf loves)
- Bucketing & contiguous allocations
- Pack many small tensors into big contiguous buffers to avoid fragmentation and reduce kernel/comm overhead.
- Comm/compute overlap
- Start communication (e.g., reduce-scatter) as soon as partial gradients are ready while later backward ops still compute.
- Fused kernels
- Fuse small ops (e.g., scaling + add) to reduce launch overhead.
- Sparse partition awareness
- If your model has sparse layers, avoid sharding them blindly: some shapes or layers may be better kept replicated.
- Parameter prefetching and lazy all-gather
- Only all-gather a param shard when needed for forward; free it after use. This reduces transient memory spikes.
- Mixed precision + dynamic loss scaling
- Use fp16/bfloat16 to reduce memory and increase throughput. ZeRO plays nicely with AMP but watch numerics in optimizer states.
- Gradient accumulation
- Combine micro-batches locally to reduce the frequency of global synchronization — helpful when communication is a bottleneck.
- Activation checkpointing
- Trade compute for memory by recomputing activations during backward; this pairs elegantly with ZeRO when activations are also large.
Real-world knobs you’ll twiddle (DeepSpeed examples)
Here's a minimal DeepSpeed config snippet for ZeRO Stage 3 with CPU offload and contiguous gradients.
```json
{
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "offload_param": { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  }
}
```
Play with bucket sizes and offload strategies depending on your NICs, CPU RAM, and NVMe. There's no universal magic number.
Where ZeRO sits relative to model/data parallelism
- ZeRO is primarily a data-parallel memory optimization — it reduces redundancy among data-parallel replicas. This means you can often scale to hundreds of billions of parameters without moving to aggressive model parallelism.
- That said, ZeRO Stage 3 is frequently combined with model parallel techniques (like tensor or pipeline parallelism) to handle very large models efficiently. Think of ZeRO as the glue that makes data-parallel training feasible at scales where naive replication would explode memory.
Question: what if you already quantized or pruned the model? Great — those methods reduce weight sizes and can reduce memory needs even further. ZeRO complements them: quantization shrinks parameter storage/inference cost, ZeRO reduces training-time redundancy.
Pitfalls and troubleshooting
- OOM during all-gather: Happens when temporary gather buffers spike. Fixes: reduce the all-gather bucket size (smaller transient buffers), enable contiguous buffers, or offload parameters.
- Comm bottleneck: If network throughput is the limiter, reduce synchronization frequency (gradient accumulation) or upgrade interconnects.
- Numerical instability with fp16: Keep optimizer states in fp32 or use dynamic loss scaling.
- Checkpointing & recovery: Full-model checkpointing with ZeRO Stage 3 requires careful handling because no rank holds the full model. Use library-provided checkpoint helpers.
Final scene: what should you try first?
- Start with ZeRO Stage 1 — easiest win. See immediate memory drop.
- If gradients still dominate, move to Stage 2.
- When you want the biggest model possible on your cluster, go Stage 3, add offload or ZeRO-Infinity if needed, and combine with activation checkpointing.
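For that first step, a minimal Stage 1 starting point might look like the following (same DeepSpeed config schema as the Stage 3 example above; the batch size and bucket size are illustrative values to tune for your hardware, not recommendations):

```json
{
  "train_micro_batch_size_per_gpu": 8,
  "zero_optimization": {
    "stage": 1,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8
  },
  "fp16": { "enabled": true }
}
```

Bumping `"stage"` to 2 or 3 is then a one-line change, which is much of ZeRO's appeal: you escalate partitioning only when memory pressure demands it.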
Key takeaways:
- ZeRO partitions optimizer state, gradients, and parameters to remove redundancy and unlock scale.
- Higher stage = more partitioning: more memory savings, but more communication complexity.
- Optimize buffers, overlap comm & compute, and consider offload before throwing hardware at the problem.
Final thought: quantization and pruning are like slimming the dragon; ZeRO is the dragon trainer who teaches it to lie down in your GPU’s tiny courtyard. Use both — and maybe some activation checkpointing yoga — and you’ll be training models that used to feel like mythical beasts.
Now go shard responsibly.