激活检查点 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

激活检查点

激活检查点是一种通过清除某些层的激活并在向后传递期间重新计算它们来减少内存使用量的技术。实际上,这会用额外的计算时间来减少内存使用量。如果对模块进行了检查点检查,则在正向传递结束时,只有该模块的初始输入和该模块的最终输出会保留在内存中。 PyTorch 在向前传递期间,释放作为该模块内部计算一部分的任何中间张量。在检查点模块的向后传递过程中, PyTorch 重新计算这些张量。此时,该检查点模块之外的层已经完成了向后传递,因此检查点操作的峰值内存使用量会降低。

SMP v2 支持 PyTorch 激活检查点模块。apply_activation_checkpointing以下是 Hugging Face GPT-Neox 模型的激活检查点示例。

Hugging Face GPT-Neox 模型的 Checkpointing Transformer 层

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Hugging Face Gpt-Neox 模型的每其他 Transformer 层都要进行检查点检查

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

或者, PyTorch 还有检查点torch.utils.checkpoint模块,Hugging Face Transformers 模型的子集使用该模块。该模块也适用于 SMP v2。但是,它要求您有权访问模型定义才能添加检查点封装器。因此,我们建议您使用该apply_activation_checkpointing方法。