DeepSpeed 稀疏注意力
基于注意力的深度学习模型(如 Transformer)在捕获输入序列中标记之间的关系方面非常有效,即使跨越很长的距离也是如此。因此,它们被用于文本、图像和基于声音的输入,其中序列长度可以达到数千个标记。然而,尽管注意力模块在捕获长期依赖关系方面非常有效,但实际上,它们对长序列输入的应用受到注意力计算的计算和内存需求的限制,这些需求随着序列长度n
二次增长,O(n^2)
。
为了解决这个限制,DeepSpeed 提供了一套稀疏注意力内核——一项可以通过块稀疏计算将注意力计算的计算和内存需求降低几个数量级的关键技术。该套件不仅缓解了注意力计算的内存瓶颈,而且还高效地执行稀疏计算。它的 API 允许方便地与任何基于 Transformer 的模型集成。除了提供各种稀疏结构外,它还具有处理任何用户定义的块稀疏结构的灵活性。更具体地说,稀疏注意力 (SA) 可以被设计为计算附近标记之间的局部注意力,或者通过使用局部注意力计算的摘要标记来计算全局注意力。此外,SA 还允许随机注意力,或任何局部、全局和随机注意力的组合,如以下图形所示,分别使用蓝色、橙色和绿色块。因此,SA 将内存占用减少到O(wn)
,其中1 < w < n
是一个参数,其值取决于注意力结构。
这个库基于 PyTorch,并通过Triton 平台开发必要的内核;内核不是用 CUDA 编写的,这为未来 CPU/OpenCL/Vulkan 支持敞开了大门。该库是 DeepSpeed 的扩展,可以通过 DeepSpeed 使用,也可以单独使用。DeepSpeed 稀疏注意力内核处理的块稀疏计算在以下图形中分别说明了正向和反向传递。在这些图中,S
代表一个块稀疏矩阵
,D
代表一个密集矩阵
。
要了解有关稀疏性配置的更多信息,以及如何使用这个库,请查看我们的教程,它提供了有关它的详细信息。
性能结果
- 对超过 10 倍长序列的强大功能 在一个预训练实验中,我们以三种设置运行了 BERT 模型:密集、带激活检查点的密集和带激活检查点的稀疏 (SA)。与密集相比,SA 使 BERT base 和 large 的序列长度分别长 10 倍和 16 倍。下图显示了 BERT base 和 large 模型中可运行的最长序列长度;实验是在单个 NVIDIA V100 GPU-32GB 内存上以批次大小 1 执行的。
- 高达 6.3 倍的更快计算 我们使用BERT base/large 和Megatron GPT2 继续了针对不同批次大小和序列长度的预训练实验。在这个实验中,我们让训练继续进行 100 次迭代,并记录最后 30 次迭代的平均时间。与密集相比,SA 减少了总计算量并提高了训练速度:随着序列长度的增加,提升幅度更大,对于 BERT base 高达 6.3 倍,对于 BERT large 为 5.3 倍,对于 GPT2 为 6.1 倍。下图显示了这些结果。
- 更高的准确率 沿着稀疏注意力的路线相关的作品(稀疏 Transformer、Longformer、BigBird)已经证明了与完全注意力相当或更高的准确率。我们的经验与之吻合。除了更低的内存开销和更快的计算速度之外,我们还在生产中观察到 SA 达到更高的准确率和更快的收敛速度的情况。下图说明了基于 BERT 的生产模型在长文档理解(2,048 序列长度)方面的训练准确率。实验在三种设置下进行:从头开始的密集、从头开始的 SA 以及从使用密集(序列长度为 512)的检查点继续训练的 SA。我们已经观察到,对于从头开始的预训练,与密集相比,SA 以更高的准确率更快地收敛。此外,从预训练检查点继续训练的 SA 在时间和准确率方面都表现得更好。
- 与最先进的 Longformer 相比 我们将 SA 与 Longformer 进行了比较,Longformer 是一种最先进的稀疏结构和实现。在我们的实验中,SA 使用
Fixed
稀疏性,两种实现具有可比较的准确率。在系统性能方面,SA 在训练和推理方面都优于 Longformer。- 1.47 倍 更快的 Wikitext103 上的 MLM 预训练执行速度 我们根据笔记本 进行了实验,该笔记本由 Longformer 提供。在这个实验中,我们使用 RoBERTa-base 检查点预训练了一个 MLM 模型。这是在 8 个 V100-SXM2 GPU 上完成的。下表显示了结果的详细信息,其中使用 DeepSpeed 稀疏注意力显示出 1.47 倍的提速。
模型 | 本地窗口大小 | BPC | 训练步骤 | 每次迭代时间 | 时间改进 | 准确率改进 |
---|---|---|---|---|---|---|
RoBERTa 检查点 | 2.5326 | |||||
Longformer | 512 | 2.6535 | 0 | 1.47 | 1.01 | |
稀疏注意力 | 2.6321 | |||||
Longformer | 1.6708 | 3k | 1.6280 | 1.01 | ||
稀疏注意力 | 1.6613 | 1.1059 | ||||
Longformer | 64 | 5.7840 | 0 | 1.31 | 1.46 | |
稀疏注意力 | 3.9737 | |||||
Longformer | 2.0466 | 3k | 1.4855 | 1.09 | ||
稀疏注意力 | 1.8693 | 1.1372 |
- 3.13 倍 更快的 BERT-Base 推理执行速度 通过我们上面描述的长文档理解应用程序,我们还检查了在
2,048
序列长度和批次大小1
上测试 BERT 模型的不同窗口大小的推理时间。在这个实验中,我们注意到将 Bert Attention 替换为 DeepSpeed 稀疏注意力而不是 Longformer 注意力,速度提高了3.13 倍
。下表显示了完整的結果。
本地窗口大小 | 时间改进 |
---|---|
512 | 3.13 |
256 | 2.29 |
128 | 2.16 |
64 | 1.5 |
32 | 1.24 |
16 | 1.23 |
- 灵活处理任何块稀疏结构 DeepSpeed 稀疏注意力套件不针对任何特定的稀疏结构,而是使模型科学家能够以高效的系统支持探索任何块稀疏结构。目前,我们已经添加了流行的稀疏结构,例如
- Fixed(来自 OpenAI 稀疏 Transformer)
- BigBird(来自 Google)
- BSLongformer(Longformer 的块稀疏实现,来自 AI2)
我们还定义了一个模板来具有variable
结构(上图),它可以用来简单地定制任何块稀疏随机/局部/全局注意力模式。除了这个列表之外,用户还可以添加任何其他稀疏结构,如教程 部分所述。