
Save and load checkpoints while using SMP

The SMP library supports the PyTorch checkpoint APIs and provides APIs that help you save and load checkpoints correctly while using the SMP library.

PyTorch FSDP supports three types of checkpoints: full, sharded, and local, each serving a different purpose. Full checkpoints should ideally be used only when exporting the model after training finishes, because generating a full checkpoint is expensive. Sharded checkpoints are the recommended approach for saving and loading checkpoints during training; they also let you change the cluster size when resuming training. Local checkpoints are more restrictive: you must resume training with the same number of GPUs, and they are currently not supported when using tensor parallelism with SMP. Note that FSDP checkpoints require writing to a shared network file system, such as Amazon FSx.
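For orientation, the following minimal sketch (not part of the SMP library itself) shows how these three checkpoint types correspond to the PyTorch FSDP StateDictType values used later in this topic; the model variable is assumed to be an FSDP-wrapped module.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# model is assumed to be an FSDP-wrapped module.

# Full checkpoint: a single consolidated checkpoint, best for exporting after training.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    full_state = model.state_dict()

# Sharded checkpoint: recommended for saving and resuming during training.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_state = model.state_dict()

# Local checkpoint: per-rank shards; resuming requires the same number of GPUs.
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    local_state = model.state_dict()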

Sharded checkpoints

The following procedure highlights what you need to do to adapt your training script to save and load sharded checkpoints with or without the SMP tensor parallelism feature.

  1. Import the SMP torch.sagemaker package.

    import torch.sagemaker as tsm
  2. Set up auxiliary variables to save and load checkpoints.

    1. Set up a coordinator rank for performing collective communication operations such as AllReduce.

      coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
    2. Using the torch.sagemaker.state enumerations, set up the action rank to determine whether each rank takes part in checkpointing, and add an if statement that sets a checkpoint subdirectory depending on whether SMP v2 tensor parallelism is in use.

      action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size)

      if tsm.state.tp_size > 1:
          # Tensor parallel groups will have their own sub directories.
          sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}"
      else:
          sub_dir = ""
  3. Keep using the PyTorch FSDP checkpoint APIs as is.

The following code example shows how to put these steps together and use the PyTorch FSDP checkpoint APIs to save and load sharded checkpoints.

import os

import torch.distributed as dist
from torch.distributed.checkpoint.optimizer import (
    load_sharded_optimizer_state_dict
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType
)
import torch.sagemaker as tsm

sharding_strategy, state_dict_type = ..., ...
global_rank = dist.get_rank()

# 0. Auxiliary variables to save and load checkpoints.
# Used when performing comm collectives such as allreduce.
coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
# To determine whether to take part in checkpointing.
action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size)
if tsm.state.tp_size > 1:
    # Tensor parallel groups will have their own sub directories.
    sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}"
else:
    sub_dir = ""

# 1. Save checkpoints.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }

    # Save from one single replication group.
    if action_rank:
        dist.checkpoint.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist.checkpoint.FileSystemWriter(os.path.join(save_dir, sub_dir)),
            process_group=model.process_group,
            coordinator_rank=coordinator_rank,
        )

# 2. Load checkpoints.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # 2.1 Load model and everything else except the optimizer.
    state_dict = {
        # All states except optimizer state can be passed here.
        "model": model.state_dict()
    }

    dist.checkpoint.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)),
        process_group=model.process_group,
        coordinator_rank=coordinator_rank,
    )
    model.load_state_dict(state_dict["model"])
    # Potentially process more customized and non-optimizer dict states.

    # 2.2 Load optimizer.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)),
        process_group=model.process_group,
    )

    flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer, group=model.process_group,
    )

    optimizer.load_state_dict(flattened_optimizer_state)

Full model checkpoints

At the end of training, you can save a full checkpoint that combines all shards of a model into a single model checkpoint file. The SMP library fully supports the PyTorch full model checkpoint API, so you don't need to make any changes.

Note that if you use SMP tensor parallelism, the SMP library transforms the model. When checkpointing the full model in this case, the SMP library translates the model back to the Hugging Face Transformers checkpoint format by default.

If you train with SMP tensor parallelism, you can use the translate_on_save argument of the PyTorch FullStateDictConfig API to switch the SMP auto-translation on or off as needed. For example, if you are focusing on training a model, you don't need the translation process, which adds overhead; in that case, we recommend setting translate_on_save=False. Likewise, if you plan to use the SMP-transformed model for further training in the future, you can switch translation off to save the model in the SMP format for later use. Translating the model back to the Hugging Face Transformers checkpoint format is needed when you wrap up training and use the model for inference. The following code example shows how to save a full model checkpoint.

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(
        rank0_only=True,
        offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs(save_dir, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))
        hf_model_config.save_pretrained(save_dir)
    dist.barrier()

Note that the option FullStateDictConfig(rank0_only=True, offload_to_cpu=True) gathers the model to the CPU of the rank 0 device to save memory when training large models.
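If you save intermediate full checkpoints only to resume training with SMP later, a minimal variant of the preceding block (a sketch assuming the same model and save_dir, and relying on the SMP translate_on_save extension described above) turns the Hugging Face translation off:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Skip the Hugging Face translation to avoid its overhead when the checkpoint
# is only used to resume training with SMP, not for inference.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(
        rank0_only=True,
        offload_to_cpu=True,
        translate_on_save=False,  # keep the SMP (tensor-parallel) model format
    ),
):
    state_dict = model.state_dict()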

To load the model back for inference, use the Hugging Face Transformers from_pretrained API as shown in the following code example. Note that the class AutoModelForCausalLM might change to other auto model classes in Hugging Face Transformers, such as AutoModelForSeq2SeqLM, depending on your model. For more information, see the Hugging Face Transformers documentation.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(save_dir)