在大语言模型(LLM)训练中,显存不足是一个普遍存在的问题。随着模型规模的不断增长,单个 GPU 的显存容量成为了训练大规模模型的主要瓶颈。DeepSpeed ZeRO(Zero Redundancy Optimizer)技术通过创新的数据分片策略,有效解决了这一问题,使得我们能够训练远超单卡显存上限的超大规模模型。
引言
随着大语言模型规模的快速增长,显存需求呈指数级增长。传统的分布式训练方法虽然能够利用多 GPU 进行训练,但每个 GPU 仍然需要存储完整的模型参数、梯度和优化器状态,这严重限制了可训练的模型规模。
ZeRO 技术通过数据并行与内存优化的结合,将模型训练中的大块数据(优化器状态、梯度和模型参数)分散到不同的 GPU 上,而非在每个 GPU 上都完整存储一份,从而显著降低了每个 GPU 的显存需求。
ZeRO 技术原理
传统数据并行的问题
在传统的数据并行训练中,每个 GPU 都需要存储:
- 模型参数:完整的模型权重
- 梯度:完整的梯度信息
- 优化器状态:如 Adam 优化器的动量、方差等状态
对于大规模模型,这些数据占用的显存非常庞大。例如,一个 175B 参数的模型使用 Adam 优化器时,仅优化器状态就需要约 700GB 显存(每个参数需要 4 个 float32 值)。
ZeRO 的核心思想
ZeRO 的核心思想是消除冗余存储,通过分片技术将原本每个 GPU 都需要存储的完整数据分散到多个 GPU 上,实现显存的线性扩展。
关键洞察:
- 在数据并行中,不同 GPU 上的模型参数是相同的
- 梯度在反向传播后需要进行 All-Reduce 操作
- 优化器状态与参数一一对应
基于这些观察,ZeRO 提出了分阶段的内存优化策略。
ZeRO 的三个阶段
ZeRO-Stage 1:优化器状态分片(Optimizer State Sharding)
原理:
将优化器状态分片存储在不同的 GPU 上,每个 GPU 只存储部分优化器状态。
具体做法:
- 假设有 $N$ 个 GPU,模型参数为 $P$
- 将优化器状态分成 $N$ 个分片,每个 GPU 存储 $P/N$ 个参数对应的优化器状态
- 在参数更新时,每个 GPU 只更新自己负责的那部分参数
内存节省:
- 优化器状态内存减少 $N$ 倍
- 对于 Adam 优化器,每个参数需要 4 个 float32 值,节省效果显著
ZeRO-Stage 2:梯度分片(Gradient Sharding)
原理:
在 Stage 1 的基础上,进一步将梯度分片存储。
具体做法:
- 每个 GPU 只计算和存储部分梯度
- 在反向传播结束时,通过 All-Reduce 操作收集完整的梯度
- 然后每个 GPU 只更新自己负责的参数部分
内存节省:
- 梯度内存减少 $N$ 倍
- 与 Stage 1 结合,总内存节省更加显著
ZeRO-Stage 3:参数分片(Parameter Sharding)
原理:
在 Stage 1 和 Stage 2 的基础上,进一步将模型参数分片存储。
具体做法:
- 模型参数也被分片存储在不同的 GPU 上
- 在训练过程中,当需要某个层的所有参数时,通过 All-Gather 操作将所需参数动态地收集到当前 GPU
- 这意味着在任何给定时间点,每个 GPU 上只完整存在模型参数的一部分
内存节省:
- 模型参数内存减少 $N$ 倍
- 实现了最大程度的内存优化
ZeRO 的具体实现
通信模式
ZeRO 使用了两种主要的通信模式:
All-Gather:用于参数收集
- 当需要某个层的完整参数时,从所有 GPU 收集该层的参数分片
- 通信开销:$O(P)$,其中 $P$ 是参数数量
All-Reduce:用于梯度聚合
- 在反向传播后,聚合所有 GPU 上的梯度分片
- 通信开销:$O(P)$
内存管理策略
按需加载机制:
- 参数只在需要时才加载到 GPU 显存
- 使用完毕后立即释放,避免长期占用显存
分片存储策略:
- 优化器状态:静态分片,训练过程中保持不变
- 梯度:动态分片,每次反向传播后重新分配
- 参数:动态分片,根据计算需求动态加载
计算流程
前向传播:
- 通过 All-Gather 收集当前层需要的参数
- 执行前向计算
- 释放不需要的参数
反向传播:
- 通过 All-Gather 收集当前层需要的参数
- 计算梯度
- 将梯度分片存储
- 释放参数
参数更新:
- 通过 All-Reduce 聚合所有梯度分片
- 每个 GPU 更新自己负责的参数部分
- 更新对应的优化器状态
ZeRO 的变体技术
ZeRO-Offload
原理:
对于模型训练中一些对性能不那么敏感,但内存占用大的部分(如优化器状态、甚至梯度),将其从 GPU 显存转移到 CPU 内存或硬盘(NVMe SSD)。
具体做法:
- 优化器状态存储在 CPU 内存中
- 梯度可以存储在 CPU 内存或 NVMe SSD 中
- 在需要时通过 PCIe 总线传输数据
优势:
- 进一步减少 GPU 显存需求
- 能够训练更大的模型
- 成本相对较低
劣势:
- 增加了 CPU-GPU 数据传输开销
- 训练速度可能有所下降
ZeRO-FSDP(Fully Sharded Data Parallelism)
原理:
ZeRO-FSDP 是 ZeRO-Stage 3 的完整实现,实现了优化器状态、梯度和模型参数的全面分片。
特点:
- 最大程度的内存优化
- 支持任意大小的模型训练
- 通信开销相对较高