修改 PyTorch 训练脚本 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

修改 PyTorch 训练脚本

以下步骤介绍了如何将 PyTorch 训练脚本转换为利用 SageMaker 的分布式数据并行库。

库 API 设计为类似于 PyTorch 分布式数据并行 (DDP) API。有关 PyTorch 提供的每个数据并行 API 的其他详细信息,请参阅SageMaker 分布式数据并行 PyTorch API 文档.

注意

SageMaker 的分布式数据并行机制库支持开箱即用的自动混合精度 (AMP)。除了对训练脚本进行框架级修改之外,无需额外操作即可启用 AMP。如果渐变位于 FP16 中,则 SageMaker 数据并行度库会运行其AllReduce在 FP16 中的操作。有关将 AMP API 实施到训练脚本的更多信息,请参阅以下资源:

  1. 导入库的 PyTorch 客户端并对其进行初始化,然后导入模块进行分布式训练。

    import smdistributed.dataparallel.torch.distributed as dist from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP dist.init_process_group()
  2. 解析参数并定义批处理大小参数(例如batch_size=args.batch_size),添加 2 行代码来调整每个工作线程的批处理大小 (GPU)。PyTorch 的 DataLoader 操作不会自动处理分布式训练的批量调整大小。

    batch_size //= dist.get_world_size() batch_size = max(batch_size, 1)
  3. 将每个 GPU 固定到单个 SageMaker 数据并行库进程,使用local_rank-这是指在给定节点内进程的相对排名。

    这些区域有:smdistributed.dataparallel.torch.get_local_rank()API 为您提供设备的本地排名。领导节点为等级 0,工作线程节点为等级 1、2、3 等。这在下一个代码块中作为dist.get_local_rank().

    torch.cuda.set_device(dist.get_local_rank())
  4. 使用库的 DDP 包装 PyTorch 模型。

    model = ... # Wrap model with the library's DistributedDataParallel model = DDP(model)
  5. 修改torch.utils.data.distributed.DistributedSampler以包含集群的信息。Setnum_replicas设置为集群中所有节点上参与训练的 GPU 总数。这就是所谓的world_size. 你可以得到world_sizesmdistributed.dataparallel.torch.get_world_size()API。这在下面的代码中被调用为dist.get_world_size(). 另外,使用smdistributed.dataparallel.torch.get_rank(). 将其调用为dist.get_rank().

    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
  6. 修改脚本,以便仅在领导节点上保存检查点。领导节点具有同步模型。这还可以避免工作节点覆盖检查点并可能损坏检查点。

以下是用于使用库进行分布式训练的 PyTorch 训练脚本示例:

# SageMaker data parallel: Import the library PyTorch API import smdistributed.dataparallel.torch.distributed as dist # SageMaker data parallel: Import the library PyTorch DDP from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP # SageMaker data parallel: Initialize the library dist.init_process_group() class Net(nn.Module):     ...     # Define model def train(...):     ...     # Model training def test(...):     ...     # Model evaluation def main():          # SageMaker data parallel: Scale batch size by world size     batch_size //= dist.get_world_size()     batch_size = max(batch_size, 1)     # Prepare dataset     train_dataset = torchvision.datasets.MNIST(...)       # SageMaker data parallel: Set num_replicas and rank in DistributedSampler     train_sampler = torch.utils.data.distributed.DistributedSampler(             train_dataset,             num_replicas=dist.get_world_size(),             rank=dist.get_rank())       train_loader = torch.utils.data.DataLoader(..)       # SageMaker data parallel: Wrap the PyTorch model with the library's DDP     model = DDP(Net().to(device))          # SageMaker data parallel: Pin each GPU to a single library process.     torch.cuda.set_device(local_rank)     model.cuda(local_rank)          # Train     optimizer = optim.Adadelta(...)     scheduler = StepLR(...)     for epoch in range(1, args.epochs + 1):         train(...)         if rank == 0:             test(...)         scheduler.step()     # SageMaker data parallel: Save model on master node.     if dist.get_rank() == 0:         torch.save(...) if __name__ == '__main__':     main()

有关更多高级使用情况,请参阅SageMaker 分布式数据并行 PyTorch API 文档.

在完成训练脚本调整后,请继续执行下一主题:运行 SageMaker 分布式数据并行培训 Job使用 SageMaker Python 开发工具包。