修改 PyTorch 训练脚本 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

修改 PyTorch 训练脚本

在 SageMaker 数据并行库 v1.4.0 及更高版本中,该库可用作 PyTorch 分布式软件包的后端选项。您只需要在训练脚本上导入一次库,并在初始化期间将其设置为 PyTorch 分布式后端。使用一行的后端规范,您可以保持 PyTorch 训练脚本不变,直接使用 PyTorch 分布式模块。要查找该库的最新 API 文档,请参阅 SageMaker Python SDK 文档中的适用于 PyTorch 的 SageMaker 分布式数据并行 API。要了解有关 PyTorch 分布式软件包和后端选项的更多信息,请参阅分布式通信软件包 – torch.distributed

重要

由于 SageMaker 分布式数据并行库 v1.4.0 及更高版本可用作 PyTorch 分布式的后端,因此 PyTorch 分布式软件包的以下 smdistributed API 已弃用。

如果您需要使用库的早期版本(v1.3.0 或更早版本),请参阅 SageMaker Python SDK 文档中的已存档 SageMaker 分布式数据并行库文档

使用 SageMaker 分布式数据并行库作为 torch.distributed 的后端

要使用 SageMaker 分布式数据并行库,您唯一需要做的就是导入 SageMaker 分布式数据并行库的 PyTorch 客户端 (smdistributed.dataparallel.torch.torch_smddp)。客户端将 smddp 注册为 PyTorch 的后端。使用 torch.distributed.init_process_group API 初始化 PyTorch 分布式进程组时,请确保为 backend 参数指定 'smddp'

import smdistributed.dataparallel.torch.torch_smddp import torch.distributed as dist dist.init_process_group(backend='smddp')
注意

smddp 后端目前不支持使用 torch.distributed.new_group() API 创建子进程组。您不能将 smddp 后端与其他进程组后端(例如 NCCL 和 Gloo)同时使用。

如果您已经有一个可以运行的 PyTorch 脚本并且只需要添加后端规范,请继续到步骤 2:使用 SageMaker Python SDK 启动 SageMaker 分布式训练作业主题中的为 PyTorch 和 TensorFlow 使用 SageMaker 框架估算器

如果您仍需要修改训练脚本以正确使用 PyTorch 分布式软件包,请按照本页的其余过程操作。

准备 PyTorch 训练脚本用于分布式训练

以下步骤提供了有关如何准备训练脚本,以便使用 PyTorch 成功运行分布式训练作业的更多提示。

注意

在 v1.4.0 中,SageMaker 分布式数据并行库支持 torch.distributed 接口的以下集合基元数据类型:all_reducebroadcastreduceall_gatherbarrier

  1. 导入 PyTorch 分布式模块。

    import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP
  2. 解析参数并定义批次大小参数(例如 batch_size=args.batch_size)后,添加两行代码以调整每个工作线程 (GPU) 的批次大小。PyTorch 的 DataLoader 操作不会自动处理分布式训练的批次大小调整。

    batch_size //= dist.get_world_size() batch_size = max(batch_size, 1)
  3. 使用 local_rank 将每个 GPU 固定到一个 SageMaker 数据并行库进程,这表示进程在给定节点中的相对秩。

    您可以从 LOCAL_RANK 环境变量中检索进程的秩。

    import os local_rank = os.environ["LOCAL_RANK"] torch.cuda.set_device(local_rank)
  4. 定义模型后,使用 PyTorch DistributedDataParallel API 对其进行包装。

    model = ... # Wrap the model with the PyTorch DistributedDataParallel API model = DDP(model)
  5. 在您调用 torch.utils.data.distributed.DistributedSampler API 时,请指定集群中所有节点上参与训练的进程 (GPU) 总数。这称为 world_size,您可以从 torch.distributed.get_world_size() API 中检索数量。此外,还要使用 torch.distributed.get_rank() API 指定每个进程在所有进程中的秩。

    from torch.utils.data.distributed import DistributedSampler train_sampler = DistributedSampler( train_dataset, num_replicas = dist.get_world_size(), rank = dist.get_rank() )
  6. 修改脚本以仅在领导进程(秩 0)上保存检查点。领导进程具有同步模型。这还可以避免其他进程覆盖检查点并可能损坏检查点。

    if dist.get_rank() == 0: torch.save(...)

以下示例代码显示了以 smddp 作为后端的 PyTorch 训练脚本的结构。

import os import torch # SageMaker data parallel: Import the library PyTorch API import smdistributed.dataparallel.torch.torch_smddp # SageMaker data parallel: Import PyTorch's distributed API import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP # SageMaker data parallel: Initialize the process group dist.init_process_group(backend='smddp') class Net(nn.Module): ... # Define model def train(...): ... # Model training def test(...): ... # Model evaluation def main(): # SageMaker data parallel: Scale batch size by world size batch_size //= dist.get_world_size() batch_size = max(batch_size, 1) # Prepare dataset train_dataset = torchvision.datasets.MNIST(...) # SageMaker data parallel: Set num_replicas and rank in DistributedSampler train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) train_loader = torch.utils.data.DataLoader(..) # SageMaker data parallel: Wrap the PyTorch model with the library's DDP model = DDP(Net().to(device)) # SageMaker data parallel: Pin each GPU to a single library process. local_rank = os.environ["LOCAL_RANK"] torch.cuda.set_device(local_rank) model.cuda(local_rank) # Train optimizer = optim.Adadelta(...) scheduler = StepLR(...) for epoch in range(1, args.epochs + 1): train(...) if rank == 0: test(...) scheduler.step() # SageMaker data parallel: Save model on master node. if dist.get_rank() == 0: torch.save(...) if __name__ == '__main__': main()

调整完训练脚本后,继续到 步骤 2:使用 SageMaker Python SDK 启动 SageMaker 分布式训练作业