0%

LLM 训练:ZeRO 技术详解

在大语言模型(LLM)训练中,显存不足是一个普遍存在的问题。随着模型规模的不断增长,单个 GPU 的显存容量成为了训练大规模模型的主要瓶颈。DeepSpeed ZeRO(Zero Redundancy Optimizer)技术通过创新的数据分片策略,有效解决了这一问题,使得我们能够训练远超单卡显存上限的超大规模模型。

引言

随着大语言模型规模的快速增长,显存需求呈指数级增长。传统的分布式训练方法虽然能够利用多 GPU 进行训练,但每个 GPU 仍然需要存储完整的模型参数、梯度和优化器状态,这严重限制了可训练的模型规模。

ZeRO 技术通过数据并行内存优化的结合,将模型训练中的大块数据(优化器状态、梯度和模型参数)分散到不同的 GPU 上,而非在每个 GPU 上都完整存储一份,从而显著降低了每个 GPU 的显存需求。

ZeRO 技术原理

传统数据并行的问题

在传统的数据并行训练中,每个 GPU 都需要存储:

  1. 模型参数:完整的模型权重
  2. 梯度:完整的梯度信息
  3. 优化器状态:如 Adam 优化器的动量、方差等状态

对于大规模模型,这些数据占用的显存非常庞大。例如,一个 175B 参数的模型使用 Adam 优化器时,仅优化器状态就需要约 700GB 显存(每个参数需要 4 个 float32 值)。

ZeRO 的核心思想

ZeRO 的核心思想是消除冗余存储,通过分片技术将原本每个 GPU 都需要存储的完整数据分散到多个 GPU 上,实现显存的线性扩展。

关键洞察

  • 在数据并行中,不同 GPU 上的模型参数是相同的
  • 梯度在反向传播后需要进行 All-Reduce 操作
  • 优化器状态与参数一一对应

基于这些观察,ZeRO 提出了分阶段的内存优化策略。

ZeRO 的三个阶段

ZeRO-Stage 1:优化器状态分片(Optimizer State Sharding)

原理
将优化器状态分片存储在不同的 GPU 上,每个 GPU 只存储部分优化器状态。

具体做法

  • 假设有 $N$ 个 GPU,模型参数为 $P$
  • 将优化器状态分成 $N$ 个分片,每个 GPU 存储 $P/N$ 个参数对应的优化器状态
  • 在参数更新时,每个 GPU 只更新自己负责的那部分参数

内存节省

  • 优化器状态内存减少 $N$ 倍
  • 对于 Adam 优化器,每个参数需要 4 个 float32 值,节省效果显著

ZeRO-Stage 2:梯度分片(Gradient Sharding)

原理
在 Stage 1 的基础上,进一步将梯度分片存储。

具体做法

  • 每个 GPU 只计算和存储部分梯度
  • 在反向传播结束时,通过 All-Reduce 操作收集完整的梯度
  • 然后每个 GPU 只更新自己负责的参数部分

内存节省

  • 梯度内存减少 $N$ 倍
  • 与 Stage 1 结合,总内存节省更加显著

ZeRO-Stage 3:参数分片(Parameter Sharding)

原理
在 Stage 1 和 Stage 2 的基础上,进一步将模型参数分片存储。

具体做法

  • 模型参数也被分片存储在不同的 GPU 上
  • 在训练过程中,当需要某个层的所有参数时,通过 All-Gather 操作将所需参数动态地收集到当前 GPU
  • 这意味着在任何给定时间点,每个 GPU 上只完整存在模型参数的一部分

内存节省

  • 模型参数内存减少 $N$ 倍
  • 实现了最大程度的内存优化

ZeRO 的具体实现

通信模式

ZeRO 使用了两种主要的通信模式:

  1. All-Gather:用于参数收集

    • 当需要某个层的完整参数时,从所有 GPU 收集该层的参数分片
    • 通信开销:$O(P)$,其中 $P$ 是参数数量
  2. All-Reduce:用于梯度聚合

    • 在反向传播后,聚合所有 GPU 上的梯度分片
    • 通信开销:$O(P)$

内存管理策略

按需加载机制

  • 参数只在需要时才加载到 GPU 显存
  • 使用完毕后立即释放,避免长期占用显存

分片存储策略

  • 优化器状态:静态分片,训练过程中保持不变
  • 梯度:动态分片,每次反向传播后重新分配
  • 参数:动态分片,根据计算需求动态加载

计算流程

前向传播

  1. 通过 All-Gather 收集当前层需要的参数
  2. 执行前向计算
  3. 释放不需要的参数

反向传播

  1. 通过 All-Gather 收集当前层需要的参数
  2. 计算梯度
  3. 将梯度分片存储
  4. 释放参数

参数更新

  1. 通过 All-Reduce 聚合所有梯度分片
  2. 每个 GPU 更新自己负责的参数部分
  3. 更新对应的优化器状态

ZeRO 的变体技术

ZeRO-Offload

原理
对于模型训练中一些对性能不那么敏感,但内存占用大的部分(如优化器状态、甚至梯度),将其从 GPU 显存转移到 CPU 内存或硬盘(NVMe SSD)。

具体做法

  • 优化器状态存储在 CPU 内存中
  • 梯度可以存储在 CPU 内存或 NVMe SSD 中
  • 在需要时通过 PCIe 总线传输数据

优势

  • 进一步减少 GPU 显存需求
  • 能够训练更大的模型
  • 成本相对较低

劣势

  • 增加了 CPU-GPU 数据传输开销
  • 训练速度可能有所下降

ZeRO-FSDP(Fully Sharded Data Parallelism)

原理
ZeRO-FSDP 是 ZeRO-Stage 3 的完整实现,实现了优化器状态、梯度和模型参数的全面分片。

特点

  • 最大程度的内存优化
  • 支持任意大小的模型训练
  • 通信开销相对较高