使用 SMP 的检查点 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 SMP 的检查点

SageMaker 模型并行度 (SMP) 库支持 PyTorch APIs 检查点,并提供了 APIs 在使用 SMP 库时正确进行检查点的帮助。

PyTorch FSDP(完全分片数据并行度)支持三种类型的检查点:完整检查点、分片检查点和本地检查点,每种检查点都有不同的用途。在训练完成后导出模型时使用完全检查点,因为生成完全检查点是一个计算成本很高的过程。分片检查点有助于保存和加载每个等级的分片模型状态。使用分片检查点,您可以使用不同的硬件配置(例如不同数量的)恢复训练。 GPUs但是,由于需要在多个设备之间进行通信,加载分片检查点的速度可能会很慢。SMP 库提供本地检查点功能,可以在不增加通信开销的情况下更快地检索模型状态。请注意,FSDP 创建的检查点需要写入共享网络文件系统,例如 Amazon FSx。

异步本地检查点

在训练机器学习模型时,后续迭代无需等待检查点文件保存到磁盘。随着 SMP v2.5 版本的发布,此库支持异步保存检查点文件。这意味着后续的训练迭代可以与用于创建检查点的输入和输出 (I/O) 操作同时运行,而不会因为这些 I/O 操作而减慢速度或停滞不前。此外,由于跨等级交换分布式张量元数据需要额外的集体通信, PyTorch 因此在中检索分片模型和优化器参数的过程可能很耗时。即使使用StateDictType.LOCAL_STATE_DICT保存每个等级的本地检查点, PyTorch 仍会调用执行集体通信的挂钩。为了缓解这一问题并减少检查点检索所需的时间,SMP 引入了 SMStateDictType.SM_LOCAL_STATE_DICT,通过绕过集体通信开销,可以更快地检索模型和优化器检查点。

注意

保持 FSDP SHARD_DEGREE 的一致性是使用 SMStateDictType.SM_LOCAL_STATE_DICT 的必要条件。确保 SHARD_DEGREE 保持不变。虽然模型复制的数量可能有所不同,但从检查点恢复时,模型分片度必须与之前的训练设置完全相同。

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

下面的代码片段显示了如何利用 SMStateDictType.SM_LOCAL_STATE_DICT 加载检查点。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

存储大型语言模型的检查点 (LLMs) 可能很昂贵,因为它通常需要创建较大的文件系统容量。为了降低成本,您可以选择将检查点直接保存到 Amazon S3,而无需其他文件系统服务,例如 Amazon FSx。您可以利用前面的示例和下面的代码片段,通过指定 S3 URL 作为目标,将检查点保存到 S3。

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

异步分片检查点

在某些情况下,您可能需要继续使用不同的硬件配置进行训练,例如更改硬件配置的数量 GPUs。在这种情况下,您的训练过程必须在重新分片的同时加载检查点,这意味着要使用不同数量的 SHARD_DEGREE 重新开始后续训练。为了解决需要用不同数量的 SHARD_DEGREE 恢复训练的情况,您必须使用分片状态字典类型保存模型检查点,StateDictType.SHARDED_STATE_DICT 表示分片状态字典类型。以这种格式保存检查点可以让您在使用修改后的硬件配置继续训练时正确处理重新分片过程。所提供的代码片段说明了如何使用 tsm API 异步保存分片检查点,从而实现更高效、更简化的训练过程。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

加载共享检查点的过程与上一节类似,但需要使用 torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader 及其 load 方法。通过此类的 load 方法,您可以加载共享的检查点数据,加载过程与前面描述的类似。

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

完全模型检查点

在训练结束时,您可以保存一个完整的检查点,将模型的所有分片合并到一个模型检查点文件中。SMP 库完全支持 PyTorch 完整的模型检查点 API,因此您无需进行任何更改。

请注意,如果您使用 SMP 张量并行性,SMP 库会转换模型。在这种情况下对完整模型进行检查点时,SMP 库默认会将模型转换回 Hugging Face 转换器检查点格式。

如果您使用 SMP 张量并行度进行训练并关闭 SMP 翻译过程,则可以使用 PyTorch FullStateDictConfig API 的translate_on_save参数根据需要打开或关闭 SMP 自动翻译。例如,如果您专注于训练模型,就不需要添加会增加开销的转换过程。在这种情况下,我们建议您设置 translate_on_save=False。此外,如果您计划今后继续使用模型的 SMP 转换进行进一步训练,则可以将其关闭,以保存模型的 SMP 转换供以后使用。在结束模型训练并将其用于推理时,需要将模型转换回 Hugging Face 转换器模型检查点格式。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

请注意,选项 FullStateDictConfig(rank0_only=True, offload_to_cpu=True) 是在第 0 级设备的 CPU 上收集模型,以便在训练大型模型时节省内存。

要加载模型进行推理,可以按照下面的代码示例进行操作。请注意,根据您的模型,AutoModelForCausalLM 类可能会变为 Hugging Face 转换器中的其他因子构建器类,例如 AutoModelForSeq2SeqLM。有关更多信息,请参阅 Hugging Face 转换器文档

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)