通信日志

在本教程中,我们介绍 DeepSpeed 通信日志记录,并提供其用法的示例。

概述

注意:所有记录通信的调用都是同步的,以提供准确的计时信息。 如果您的模型大量使用异步通信操作,这可能会影响性能。

记录通信调用对于确保网络资源得到充分利用至关重要。 DeepSpeed 通信记录器能够检测和记录在 deepspeed.comm 下启动的所有通信操作。 每个通信操作都可以在完成后的立即打印到控制台(通过 verbose 配置选项),或者可以在训练结束、一个时期结束、在 N 个训练迭代之后等,在客户端代码中使用 deepspeed.comm.log_summary()deepspeed.com.log_summary(show_straggler=True) 调用打印摘要。

用法

DeepSpeed 中的通信日志记录在 deepspeed 配置文件 中配置。 DeepSpeed 将自动记录所有操作 (prof_all) 或用户指定的 operations (prof_ops) 的通信。

配置设置

可以在 DeepSpeed 配置文件 中配置通信日志记录。 可以通过在 DeepSpeed 的配置 JSON 文件中添加以下字段来启用通信日志记录。 有关详细信息,请参阅 通信日志记录

"comms_logger": {
  "enabled": true,
  "verbose": false,
  "prof_all": true,
  "debug": false
}

目前有两种方法可以查看通信日志记录。

  1. 使用 verbose 配置选项打印所有通信操作。 请参阅 详细日志
  2. (推荐) 使用 deepspeed.comm.log_summary() 函数调用打印日志摘要。 请参阅 日志摘要

详细日志

如果选择了 enabled 配置选项,所有通信操作将立即打印到控制台。 此模式旨在进行详细的调试,不推荐大多数用户使用。 以下是 verbose 输出的示例片段。

[2022-06-26 01:39:55,722] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: reduce_scatter_tensor | time (ms): 9.46 | msg size: 678.86 MB | algbw (Gbps): 1204.52  | busbw (Gbps): 1129.23
[2022-06-26 01:39:56,470] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_into_tensor | time (ms): 0.11 | msg size: 6.0 MB | algbw (Gbps): 954.41  | busbw (Gbps): 894.76
[2022-06-26 01:39:56,471] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_into_tensor | time (ms): 0.08 | msg size: 6.0 MB | algbw (Gbps): 1293.47  | busbw (Gbps): 1212.63

对于高级用户,debug 选项会将每个通信操作的调用函数追加到该操作的 log_name。 请参阅 日志摘要,了解在启用 debug 的情况下 deepspeed.comm.log_summary() 调用的示例。

日志摘要

建议用户在训练里程碑(例如,每个时期或 N 次迭代)添加对 deepspeed.comm.log_summary() 的调用。 这使得能够进行高级通信日志记录,而无需筛选来自 verbose 的日志。

添加 DeepSpeed 通信日志摘要的步骤如下。

  1. 使用所需设置修改配置文件。
  2. (可选) 如果您的应用程序包含您希望记录的 torch.distributed 调用,请导入 deepspeed.comm 包,并将 torch.distributed 调用修改为使用 deepspeed.comm(注意:deepspeed.comm 的集体和点对点 API 与 torch.distributed 完全匹配)。
  3. 调用 deepspeed.comm.log_summary

有关示例用法,请参阅以下修改后的 DeepSpeedExamples/cifar 示例。

# Step 2: (Optional) Import deepspeed.comm
import deepspeed.comm as dist

# Note that any communication operations using `import torch.distributed as dist` calls can remain unchanged, and will be automatically logged under deepspeed.comm!
dist.all_reduce(tensor)

for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader):
        pre = time.time()
        inputs, labels = data[0].to(model_engine.local_rank), data[1].to(
            model_engine.local_rank)
        if fp16:
            inputs = inputs.half()
        outputs = model_engine(inputs)
        loss = criterion(outputs, labels)

        model_engine.backward(loss)
        model_engine.step()
        post = time.time()
    # Step 3: Call `deepspeed.comm.log_summary()`
    dist.log_summary()

以下是在使用 ZeRO-3 的 Megatron-DeepSpeed 完成 10 次迭代后截断的 deepspeed.comm.log_summary() 输出示例。

Comm. Op            Message Size        Count               Total Latency(ms)   Avg Latency(ms)     tput_avg (Gbps)     busbw_avg (Gbps)
broadcast
                    2.0 KB              146                 11.12               0.08                0.43                0.41
                    98.25 MB            1                   8317.12             8317.12             0.20                0.19
reduce_scatter_tensor
                    678.86 MB           40                  602.29              9.69                1468.06             1376.31

以下是在相同配置下启用 debug 的情况下对 deepspeed.comm.log_summary 的调用。

Comm. Op            Message Size        Count               Total Latency(ms)   Avg Latency(ms)     tput_avg (Gbps)     busbw_avg (Gbps)
broadcast | [Caller Func: _broadcast_model]
                    2.0 KB              146                 9.39                0.06                0.52                0.48
                    98.25 MB            1                   8540.60             8540.60             0.19                0.18
reduce_scatter_tensor | [Caller Func: reduce_scatter_fn]
                    678.86 MB           80                  1527.17             13.94               1211.75             1136.01

可以通过向 deepspeed.comm.log_summary() 调用提供可选参数 show_straggler=True 来显示落后者效应。 落后者效应定义为一个秩等待最慢的秩开始通信的时间。 对于每个集体,log_summary 将获取所有秩的最小集体时间,并按以下方式计算落后者效应。

straggler = sum(t_collectives - allreduce(t_collectives, MIN))

使用以下 log_summary 调用在上面的示例中打印落后者效应。

    dist.log_summary(show_straggler=True)

更新: