对具有模型并行性的模型执行检查点操作和微调
SageMaker 模型并行性库提供了检查点 API,用于保存按各种模型并行性策略划分的模型状态和优化器状态,并加载检查点,以便从要重新开始训练和微调的地方继续进行训练。这些 API 还支持部分或完整保存模型和优化器状态的选项。
对分布式模型执行检查点操作
根据在 PyTorch 和 TensorFlow 之间选择的框架以及您使用的 SageMaker 模型并行库的版本,选择以下主题之一。
主题
对分布式 PyTorch 模型执行检查点操作(对于 SageMaker 模型并行性库 v1.10.0 和更高版本)
SageMaker 模型并行库提供了检查点 API,用于保存和加载分布式模型状态及其优化器状态的完整或部分检查点。
注意
如果您使用 PyTorch 和 SageMaker 模型并行性库 v1.10.0 或更高版本,则建议使用这种检查点方法。
部分检查点
要保存使用模型并行性训练的模型的检查点,请使用 smdistributed.modelparallel.torch.save_checkpoint
partial=True
)。这将单独保存每个模型分区。在模型状态和优化器状态之外,您还可以通过 user_content
参数保存任何其他自定义数据。具有检查点的模型、优化器和用户内容保存为单独的文件。save_checkpoint
API 调用按以下结构创建检查点文件夹。
- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)
要从部分检查点恢复训练,请使用 smdistributed.modelparallel.torch.resume_from_checkpoint
partial=True
,并指定保存部分检查点时使用的检查点目录和标签。请注意,模型权重的实际加载在模型分区之后进行,在第一次运行经过 smdistributed.modelparallel.torch.step
修饰的训练步骤函数期间。
保存部分检查点时,库还会使用 .pt
文件扩展名,将模型分区决策保存为文件。反过来,在从部分检查点恢复时,库会将分区决策文件一起加载。分区决策一旦加载,您就无法更改分区。
以下代码片段显示了如何在 PyTorch 训练脚本中设置检查点 API。
import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "
/opt/ml/checkpoint/model_parallel
" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path
, tag=f"total_steps{total_steps}
", partial=True
, model=model
, optimizer=optimizer
, user_content=user_content
num_kept_partial_checkpoints=5
) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path
, partial=True
)
完整检查点
要保存最终模型构件用于推理用途,请使用 smdistributed.modelparallel.torch.save_checkpoint
API 和 partial=False
,这会组合模型分区以创建单个模型构件。请注意,这不会合并优化器状态。
要使用特定权重初始化训练,对于给定的完整模型检查点,您可以使用 smdistributed.modelparallel.torch.resume_from_checkpoint
API 和 partial=False
。请注意,这不会加载优化器状态。
注意
通常,对于张量并行性,state_dict
必须在原始模型实施和 DistributedModel
实施之间进行转换。或者,您可以将 state_dict
转换函数作为参数提供给 smdistributed.modelparallel.torch.resume_from_checkpoint
。但是,对于 现成支持的模型,库会自动处理此转换。
以下代码示例显示了如何使用检查点 API,对使用模型并行性训练的 PyTorch 模型执行完整检查点操作。
import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "
/opt/ml/checkpoint/model_parallel
" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path
, tag=f"total_steps{total_steps}
", partial=False
, model=model
, optimizer=optimizer
, user_content=user_content
num_kept_partial_checkpoints=5
) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path
, partial=False
)
对分布式 PyTorch 模型(适用于 v1.6.0 和 v1.9.0 之间的 SageMaker 模型并行性库)执行检查点操作
SageMaker 模型并行性库提供了 Python 函数,用于对使用张量并行性的训练作业保存部分或完整检查点。以下过程演示如何使用 smp.save()
smp.load()
注意
如果您使用版本在 v1.6.0 和 v1.9.0 之间的 PyTorch、张量并行性 和 SageMaker 模型并行库,则推荐使用这种检查点方法。
-
准备一个模型对象,并使用库的包装器函数
smp.DistributedModel()
进行包装。model = MyModel(...) model = smp.DistributedModel(model)
-
为模型准备一个优化器。一组模型参数是优化器函数所需的可迭代参数。要准备一组模型参数,您必须处理
model.parameters()
以向各个模型参数分配唯一的 ID。如果可迭代的模型参数中存在具有重复 ID 的参数,则加载具有检查点的优化器状态将失败。要为优化器创建具有唯一 ID 的可迭代模型参数,请参阅以下内容:
unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
-
使用库的包装器函数
smp.DistributedOptimizer()
包装优化器。optimizer = smp.DistributedOptimizer(optimizer)
-
使用
smp.save()
保存模型和优化器状态。根据您保存检查点的方式,选择以下两个选项之一: -
选项 1:在每个
mp_rank
上为单个MP_GROUP
保存一个部分模型。model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )
使用张量并行性,该库按以下命名格式保存检查点文件:
checkpoint.pt_{pp_rank}_{tp_rank}
。注意
对于张量并行性,请确保将 if 语句设置
if smp.rdp_rank() == 0
为而不是if smp.dp_rank() == 0
。当使用张量并行性对优化器状态进行分片时,所有缩减数据并行秩都必须保存自己的优化器状态分区。为检查点操作使用错误的 if 语句可能会导致训练作业停滞。有关在没有张量并行性的情况下使用if smp.dp_rank() == 0
的更多信息,请参阅《SageMaker Python SDK 文档》中的保存和加载的一般说明。
-
选项 2:保存完整模型。
if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
注意
对于完整检查点操作,请考虑以下内容:
-
如果您设置
gather_to_rank0=True
,则除0
之外的所有秩返回空字典。 -
对于完整检查点,您只能对模型执行检查点操作。目前不支持对优化器状态执行完整检查点操作。
-
只需将完整模型保存在
smp.rank() == 0
即可。
-
-
-
使用
smp.load()
加载检查点。根据您在上一步中的检查点操作方式,选择以下两个选项之一: -
选项 1:加载部分检查点。
checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
如果您知道分区不会更改,则可以在
model.load_state_dict()
中设置same_partition_load=True
以实现更快的加载速度。 -
选项 2:加载完整检查点。
if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])
if smp.rdp_rank() == 0
条件并非必需,但它可以帮助避免不同MP_GROUP
之间的冗余加载。使用张量并行性时,目前不支持完整检查点优化器状态字典。
-
对分布式 TensorFlow 模型执行检查点操作
要在使用模型并行性训练时保存 TensorFlow 模型,请使用 SageMaker 模型并行性库提供的以下函数。
微调分布式模型
需要在训练脚本中配置微调。以下代码片段显示了一个训练脚本的示例结构,脚本使用 Hugging Face Transformers 的 AutoModelForCausalLMsmdistributed.model.parallel.torch
模块和设置用于微调。
注意
在激活 smp.delayed_param_initializationsmp.DistributedModel()
包装的转换器模型),需要使用 FSx for Lustre 文件系统配置微调作业。如果要使用延迟参数初始化选项对大规模的模型进行微调,您应设置 FSx for Lustre 文件系统。
import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)
有关训练脚本和 Jupyter 笔记本的完整示例,请参阅 SageMaker 示例 GitHub 存储库中的适用于 PyTorch 的 GPT-2 示例