DeepSpeed 的通用检查点:实用指南

DeepSpeed 通用检查点功能是一个强大的工具,能够以高效灵活的方式保存和加载模型检查点,从而在不同的模型架构、并行技术和训练配置之间实现无缝的模型训练延续和微调。本教程专为初学者和有经验的用户量身定制,提供了如何在 DeepSpeed 驱动的应用程序中利用通用检查点的分步指南。本教程将引导您完成创建 ZeRO 检查点、将其转换为通用格式以及使用这些通用检查点恢复训练的过程。这种方法对于利用预训练模型和促进跨不同设置的无缝模型训练至关重要。

通用检查点简介

DeepSpeed 中的通用检查点抽象了保存和加载模型状态、优化器状态和训练调度器状态的复杂性。此功能旨在通过最少的配置即可开箱即用,支持各种模型大小和类型,从小型模型到跨多个 GPU 和其他加速器训练的具有不同并行拓扑的大型分布式模型。

先决条件

在开始之前,请确保您具备以下条件

  • 已安装 DeepSpeed,可以通过 pip install deepspeed 进行安装。
  • 一个利用 DeepSpeed 进行分布式训练的模型训练脚本。

如何使用 DeepSpeed 通用检查点

请按照以下三个简单步骤操作

步骤 1:创建 ZeRO 检查点

利用 DeepSpeed 通用检查点的第一步是创建 ZeRO 检查点。ZeRO(零冗余优化器)是 DeepSpeed 中一种内存优化技术,可以高效训练大型模型。要创建 ZeRO 检查点,您需要

  • 使用 ZeRO 优化器通过 DeepSpeed 初始化模型。
  • 将模型训练到所需状态(迭代次数)。
  • 使用 DeepSpeed 的检查点功能保存检查点。

步骤 2:将 ZeRO 检查点转换为通用格式

拥有 ZeRO 检查点后,下一步是将其转换为通用格式。此格式设计为灵活,并与不同的模型架构和 DeepSpeed 配置兼容。要转换检查点

  • 使用 DeepSpeed 提供的 ds_to_universal.py 脚本。
  • 指定 ZeRO 检查点的路径以及通用检查点的所需输出路径。
python ds_to_universal.py --input_folder /path/to/zero/checkpoint --output_folder /path/to/universal/checkpoint

此脚本将处理 ZeRO 检查点并生成通用格式的新检查点。传递 --help 标志以查看其他选项。

步骤 3:使用通用检查点恢复训练

通用检查点准备就绪后,您现在可以恢复训练,可能使用不同的并行拓扑或训练配置。为此,请在 DeepSpeed 配置(json)文件中添加 --universal-checkpoint

总结

DeepSpeed 通用检查点简化了模型状态的管理,使得在不同训练会话和并行技术之间保存、加载和传输模型状态变得更加容易。通过遵循本教程中概述的步骤,您可以将通用检查点集成到您的 DeepSpeed 应用程序中,从而增强您的模型训练和开发工作流程。

有关更详细的示例和高级配置,请参阅 Megatron-DeepSpeed 示例

有关 DeepSpeed 通用检查点的深入技术细节,请参阅arXiv 手稿博客

祝您训练愉快!

更新日期: