激活检查点 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

激活检查点

激活检查点(或梯度检查点)技术通过清除某些层的激活并在向后传递期间重新计算它们,来减少内存使用量。实际上,这是用额外的计算时间来换取内存使用量的减少。如果对模块执行了检查点操作,则在向前传递结束时,该模块的输入和输出将保留在内存中。在向前传递期间,任何本应是模块内部计算一部分的中间张量都会被释放。在有检查点的模块的向后传递过程中,会重新计算这些张量。此时,有检查点的模块之外的层已经完成其向后传递,因此检查点操作的峰值内存使用量可能会更低。

注意

此功能在 SageMaker 模型并行库 v1.6.0 及更高版本上对 PyTorch 可用。

如何使用激活检查点

使用 smdistributed.modelparallel,您可以按模块使用激活检查点。对于除 torch.nn.Sequential 之外的所有 torch.nn 模块,只有当从管道并行性的角度来看,模块树位于一个分区内时,才能对模块树执行检查点操作。对于 torch.nn.Sequential 模块,顺序模块内的每个模块树必须完全位于一个分区内,激活检查点才能起作用。使用手动分区时,请注意这些限制。

使用自动模型分区时,您可在训练作业日志中找到以 Partition assignments: 开头的分区分配日志。如果一个模块在多个秩(例如,一个后代属于一个秩,另一个后代处于不同的秩)上分区,则库会忽略对模块执行检查点的尝试,并发出一条警告消息,说明该模块没有检查点。

注意

SageMaker 模型并行库支持将重叠和非重叠操作 allreduce 与检查点操作结合使用。

注意

PyTorch 的原生检查点 API 与 smdistributed.modelparallel 不兼容。

示例 1:以下示例代码演示了当脚本中有模型定义时,如何使用激活检查点操作。

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

示例 2:以下示例代码演示了当脚本中有顺序模型时,如何使用激活检查点操作。

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

示例 3:以下示例代码演示了从库中导入预构建模型时(例如 PyTorch 和 Hugging Face Transformers),如何使用激活检查点。无论您是否对顺序模型执行检查点操作,请完成以下过程:

  1. 使用 smp.DistributedModel() 包装模型。

  2. 为顺序层定义一个对象。

  3. 使用 smp.set_activation_checkpointig() 包装顺序层对象。

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')