使用 SMP 时保存和加载检查点 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 SMP 时保存和加载检查点

SMP 库支持用于检查点 PyTorch 的 API,并提供可在使用 SMP 库时正确帮助检查点的 API。

PyTorch FSDP 支持三种类型的检查点:完整检查点、分片检查点和本地检查点。它们有不同的用途。理想情况下,只有在训练完成后导出模型时才应使用完整检查点,因为生成完整检查点的成本很高。在训练期间,建议使用分片检查点来保存和加载检查点。使用分片检查点,您还可以在恢复训练时更改集群大小。本地检查点的限制性更强。使用本地检查点,你需要使用相同数量的 GPU 恢复训练,目前在 SMP 中使用张量并行性时不支持这种方法。请注意,FSDP 的检查点需要写入共享网络文件系统,例如 FSx。

分片检查点

以下过程重点介绍在使用或不使用 SMP 张量并行度功能的情况下调整训练脚本以保存和加载分片检查点所需的操作。

  1. 导入 SMP torch.sagemaker 软件包。

    import torch.sagemaker as tsm
  2. 设置辅助变量以保存和加载检查点。

    1. 设置协调员等级,以执行交际集体操作,AllReduce例如.

      coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
    2. 使用torch.sagemaker.state枚举设置操作等级,以确定是否让等级参与检查点检查。并添加一个 if 语句以保存检查点,具体取决于 SMP v2 张量并行度的使用情况。

      action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = ""
  3. 继续按原样使用 PyTorch FSDP 检查点 API。

以下代码示例显示了包含 PyTorch FSDP 检查点 API 的完整 FSDP 训练脚本。

import torch.distributed as dist from torch.distributed.checkpoint.optimizer import ( load_sharded_optimizer_state_dict ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, StateDictType ) import torch.sagemaker as tsm sharding_strategy, state_dict_type = ..., ... global_rank = dist.get_rank() # 0. Auxiliary variables to save and load checkpoints. # Used when performing comm collectives such as allreduce. coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group)) # To determine whether to take part in checkpointing. action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = "" # 1. Save checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # Save from one single replication group. if action_rank: dist.checkpoint.save_state_dict( state_dict=state_dict, storage_writer=dist.checkpoint.FileSystemWriter(os.path.join(save_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) # 2. Load checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 2.1 Load model and everything else except the optimizer. state_dict = { # All states except optimizer state can be passed here. "model": model.state_dict() } dist.checkpoint.load_state_dict( state_dict=state_dict, storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # Potentially process more customized and non-optimizer dict states. # 2.2 Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group, ) optimizer.load_state_dict(flattened_optimizer_state)

完整模型的检查点

训练结束时,您可以保存一个完整的检查点,该检查点将模型的所有分片合并到一个模型检查点文件中。SMP 库完全支持 PyTorch 完整的模型检查点 API,因此您无需进行任何更改。

请注意,如果您使用 SMP张量并行性,则 SMP 库会转换模型。在这种情况下,当对完整模型进行检查点时,默认情况下,SMP 库会将模型转换回 Hugging Face Transformers 检查点格式。

如果您使用 SMP 张量并行度进行训练并关闭 SMP 翻译过程,则可以使用 PyTorch FullStateDictConfig API 的translate_on_save参数根据需要打开或关闭 SMP 自动翻译。例如,如果您专注于训练模型,则无需添加会增加开销的翻译过程。在这种情况下,我们建议您进行设置translate_on_save=False。此外,如果您计划在将来继续使用模型的 SMP 转换进行进一步训练,则可以将其关闭以保存模型的 SMP 转换以备日后使用。当你总结模型的训练并使用它进行推理时,需要将模型转换回 Hugging Face Transformers 模型检查点格式。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

请注意,在训练大型模型时,可以选择FullStateDictConfig(rank0_only=True, offload_to_cpu=True)将模型收集到第 0 级设备的 CPU 上,以节省内存。

要重新加载模型以进行推理,您可以按照以下代码示例所示进行操作。请注意,该类AutoModelForCausalLM可能会更改为 Hugging Face Transformers 中的其他因子生成器类,AutoModelForSeq2SeqLM例如,具体取决于您的模型。有关更多信息,请参阅 Hugging Face 变形金刚文档

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)