修改 PyTorch 训练脚本
在 SageMaker 数据并行库 v1.4.0 及更高版本中,该库可用作 PyTorch 分布式软件包
重要
由于 SageMaker 分布式数据并行库 v1.4.0 及更高版本可用作 PyTorch 分布式的后端,因此 PyTorch 分布式软件包的以下 smdistributed API
-
smdistributed.dataparallel.torch.distributed
已弃用。改为使用 torch.distributed软件包。 -
smdistributed.dataparallel.torch.parallel.DistributedDataParallel
已弃用。改为使用 torch.nn.parallel.DistributedDataParallelAPI。
如果您需要使用库的早期版本(v1.3.0 或更早版本),请参阅 SageMaker Python SDK 文档中的已存档 SageMaker 分布式数据并行库文档
使用 SageMaker 分布式数据并行库作为 torch.distributed
的后端
要使用 SageMaker 分布式数据并行库,您唯一需要做的就是导入 SageMaker 分布式数据并行库的 PyTorch 客户端 (smdistributed.dataparallel.torch.torch_smddp
)。客户端将 smddp
注册为 PyTorch 的后端。使用 torch.distributed.init_process_group
API 初始化 PyTorch 分布式进程组时,请确保为 backend
参数指定 'smddp'
。
import smdistributed.dataparallel.torch.torch_smddp import torch.distributed as dist dist.init_process_group(backend='smddp')
注意
smddp
后端目前不支持使用 torch.distributed.new_group()
API 创建子进程组。您不能将 smddp
后端与其他进程组后端(例如 NCCL 和 Gloo)同时使用。
如果您已经有一个可以运行的 PyTorch 脚本并且只需要添加后端规范,请继续到步骤 2:使用 SageMaker Python SDK 启动 SageMaker 分布式训练作业主题中的为 PyTorch 和 TensorFlow 使用 SageMaker 框架估算器。
如果您仍需要修改训练脚本以正确使用 PyTorch 分布式软件包,请按照本页的其余过程操作。
准备 PyTorch 训练脚本用于分布式训练
以下步骤提供了有关如何准备训练脚本,以便使用 PyTorch 成功运行分布式训练作业的更多提示。
注意
在 v1.4.0 中,SageMaker 分布式数据并行库支持 torch.distributedall_reduce
、broadcast
、reduce
、all_gather
和 barrier
。
-
导入 PyTorch 分布式模块。
import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP
-
解析参数并定义批次大小参数(例如
batch_size=args.batch_size
)后,添加两行代码以调整每个工作线程 (GPU) 的批次大小。PyTorch 的 DataLoader 操作不会自动处理分布式训练的批次大小调整。batch_size //= dist.get_world_size() batch_size = max(batch_size, 1)
-
使用
local_rank
将每个 GPU 固定到一个 SageMaker 数据并行库进程,这表示进程在给定节点中的相对秩。您可以从
LOCAL_RANK
环境变量中检索进程的秩。import os local_rank = os.environ["LOCAL_RANK"] torch.cuda.set_device(local_rank)
-
定义模型后,使用 PyTorch
DistributedDataParallel
API 对其进行包装。model = ... # Wrap the model with the PyTorch DistributedDataParallel API model = DDP(model)
-
在您调用
torch.utils.data.distributed.DistributedSampler
API 时,请指定集群中所有节点上参与训练的进程 (GPU) 总数。这称为world_size
,您可以从torch.distributed.get_world_size()
API 中检索数量。此外,还要使用torch.distributed.get_rank()
API 指定每个进程在所有进程中的秩。from torch.utils.data.distributed import DistributedSampler train_sampler = DistributedSampler( train_dataset, num_replicas = dist.get_world_size(), rank = dist.get_rank() )
-
修改脚本以仅在领导进程(秩 0)上保存检查点。领导进程具有同步模型。这还可以避免其他进程覆盖检查点并可能损坏检查点。
if dist.get_rank() == 0: torch.save(...)
以下示例代码显示了以 smddp
作为后端的 PyTorch 训练脚本的结构。
import os import torch # SageMaker data parallel: Import the library PyTorch API import smdistributed.dataparallel.torch.torch_smddp # SageMaker data parallel: Import PyTorch's distributed API import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP # SageMaker data parallel: Initialize the process group dist.init_process_group(backend='smddp') class Net(nn.Module): ... # Define model def train(...): ... # Model training def test(...): ... # Model evaluation def main(): # SageMaker data parallel: Scale batch size by world size batch_size //= dist.get_world_size() batch_size = max(batch_size, 1) # Prepare dataset train_dataset = torchvision.datasets.MNIST(...) # SageMaker data parallel: Set num_replicas and rank in DistributedSampler train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) train_loader = torch.utils.data.DataLoader(..) # SageMaker data parallel: Wrap the PyTorch model with the library's DDP model = DDP(Net().to(device)) # SageMaker data parallel: Pin each GPU to a single library process. local_rank = os.environ["LOCAL_RANK"] torch.cuda.set_device(local_rank) model.cuda(local_rank) # Train optimizer = optim.Adadelta(...) scheduler = StepLR(...) for epoch in range(1, args.epochs + 1): train(...) if rank == 0: test(...) scheduler.step() # SageMaker data parallel: Save model on master node. if dist.get_rank() == 0: torch.save(...) if __name__ == '__main__': main()
调整完训练脚本后,继续到 步骤 2:使用 SageMaker Python SDK 启动 SageMaker 分布式训练作业。