
Set up managed tiered checkpointing

This section covers the setup process for Amazon SageMaker HyperPod managed tiered checkpointing. You'll learn how to enable the feature on your cluster and how to implement checkpointing in your training code.

Prerequisites

Before you set up managed tiered checkpointing, make sure you have:

  • An Amazon EKS HyperPod cluster with sufficient CPU memory available to allocate for checkpoints

  • PyTorch training workloads and DCP jobs (both are supported)

  • The appropriate IAM permissions to manage the cluster, including:

    • Amazon CloudWatch and Amazon S3 write permissions for the training containers, so they can read/write checkpoints and push metrics (see the example policy after this list)

    • These permissions can be configured through the EKS OIDC setup
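
The exact policy depends on your checkpoint bucket and metric setup. A minimal sketch of the S3 and CloudWatch permissions described above, with my-checkpoint-bucket as a placeholder bucket name, might look like this:

{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket"
            ],
            "Resource": [
                "arn:aws:s3:::my-checkpoint-bucket",
                "arn:aws:s3:::my-checkpoint-bucket/*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": "cloudwatch:PutMetricData",
            "Resource": "*"
        }
    ]
}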

Step 1: Enable managed tiered checkpointing for your cluster

Important

You must opt in to use the managed tiered checkpointing feature.

Enable managed tiered checkpointing through the HyperPod APIs when you create or update a cluster. When you specify the TieredStorageConfig parameter, the service automatically installs the memory management system.

For new clusters, you can use the create-cluster Amazon CLI command.

aws sagemaker create-cluster \
    --cluster-name cluster-name \
    --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \
    --instance-groups '{
        "InstanceGroupName": "instance-group-name",
        "InstanceType": "instance-type",
        "InstanceCount": instance-count,
        "LifeCycleConfig": {
            "SourceS3Uri": "s3-path-to-lifecycle-scripts",
            "OnCreate": "lifecycle-script-name"
        },
        "ExecutionRole": "instance-group-iam-role",
        "ThreadsPerCore": threads-per-core,
        "InstanceStorageConfigs": [
            {
                "EbsVolumeConfig": {"VolumeSizeInGB": volume-size}
            }
        ]
    }' \
    --vpc-config '{
        "SecurityGroupIds": ["security-group-ids"],
        "Subnets": ["subnets"]
    }' \
    --tiered-storage-config '{
        "Mode": "Enable"
    }'

The InstanceMemoryAllocationPercentage parameter specifies the percentage (integer) of cluster memory to allocate for checkpoints. The valid range is 20-100.
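
For example, to allocate half of each instance's memory to the checkpoint tier, you can include the parameter alongside Mode. The sketch below uses update-cluster to enable the feature on an existing cluster; the value 50 is illustrative, and it assumes your CLI version exposes --tiered-storage-config on update-cluster, as the create-or-update text above implies:

aws sagemaker update-cluster \
    --cluster-name cluster-name \
    --tiered-storage-config '{
        "Mode": "Enable",
        "InstanceMemoryAllocationPercentage": 50
    }'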

Step 2: Install the Python libraries in your training image

Add the Amazon SageMaker checkpointing library and its dependencies to your Dockerfile to install them into your training image:

# Add this line to your training image Dockerfile
RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3
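
As an optional sanity check (not required by the service), you can make the image build fail early if the libraries did not install correctly:

# Optional: verify the libraries import cleanly during the image build
RUN python -c "import amzn_sagemaker_checkpointing, s3torchconnector, tenacity, boto3"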

Step 3: Save checkpoints in your training loop

In your training loop, you can use PyTorch DCP to save checkpoints asynchronously. In the following example, checkpoints are written to the in-memory tier every 10 steps and to Amazon S3 every 50 steps, and async_save returns a future so training continues while the save completes in the background.

import os
import time

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import async_save, load
from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import (
    SageMakerTieredStorageWriter,
    SageMakerTieredStorageReader
)
# Import path for SageMakerCheckpointConfig assumed here; verify it against
# your installed version of amzn-sagemaker-checkpointing.
from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import (
    SageMakerCheckpointConfig
)

# Initialize distributed training
dist.init_process_group(backend="nccl")

# Configure checkpointing
checkpoint_config = SageMakerCheckpointConfig(
    # Unique ID for your training job
    # Allowed characters in ID include: alphanumeric, hyphens, and underscores
    namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'),
    # Number of distributed processes/available GPUs
    world_size=dist.get_world_size(),
    # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks
    # Required for SageMakerTieredStorageWriter when save_to_s3 is True
    s3_tier_base_path="s3://my-bucket/checkpoints"
)

# Your model and optimizer (MyModel is a placeholder)
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters())

# Training loop
future = None
in_memory_ckpt_freq = 10
s3_ckpt_freq = 50
epoch = 0  # Replace with your own epoch tracking

for training_step in range(1000):
    # ... training code ...

    # Save checkpoint
    if (training_step % in_memory_ckpt_freq == 0 or
            training_step % s3_ckpt_freq == 0):
        # Create state dictionary
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": training_step,
            "epoch": epoch
        }

        # Create storage writer for the current step
        checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0
        storage_writer = SageMakerTieredStorageWriter(
            checkpoint_config=checkpoint_config,
            step=training_step
        )

        # Wait for the previous checkpoint save to complete
        if future is not None:
            exc = future.exception()
            if exc:
                print(f"Failure in saving previous checkpoint: {str(exc)}")
                # Handle failures as required
            else:
                result = future.result()
                # Process results from the save, if required

        # Async save checkpoint using PyTorch DCP
        future = async_save(state_dict=state_dict, storage_writer=storage_writer)

        # Continue training while the checkpoint saves in the background

Step 4: Load checkpoints for recovery

The following is an example of loading a checkpoint.

# Create state dictionary template
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 0,
    "epoch": 0
}

# Load the latest checkpoint
storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config)
load(state_dict, storage_reader=storage_reader)

# Load a specific checkpoint step
storage_reader = SageMakerTieredStorageReader(
    checkpoint_config=checkpoint_config,
    step=500  # Omit step to load the latest available checkpoint
)
try:
    load(state_dict, storage_reader=storage_reader)
except BaseException as e:
    print(f"Checkpoint load failed: {str(e)}")
    # Add additional exception handling
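
Note that load populates the tensors in state_dict in place, so applying the restored values back to your model and optimizer is still required after the call returns. A minimal sketch of that final step, following the usual PyTorch DCP pattern:

# Apply the restored values back to the model and optimizer
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])

# Resume the training loop from the restored position
start_step = state_dict["step"] + 1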

Validate your managed tiered checkpointing operations

You can use logs to validate your managed tiered checkpointing operations.

Custom logging (optional)

You can integrate checkpointing logs with your other logs by passing a custom logger to the library. For example, you can add a custom logger to your training code so that all logs from the library are also collected by your training logger.
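
The library's exact interface for accepting a custom logger is not shown here. As one standard-library approach, you can attach your training logger's handlers to the library's logger; the logger name amzn_sagemaker_checkpointing below is an assumption based on the package name:

import logging

# Assumption: the library emits records under its package name;
# adjust the logger name if your installed version differs.
checkpoint_logger = logging.getLogger("amzn_sagemaker_checkpointing")
checkpoint_logger.setLevel(logging.INFO)

# Route the library's records into your existing training handlers
training_logger = logging.getLogger("training")  # your configured logger
for handler in training_logger.handlers:
    checkpoint_logger.addHandler(handler)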

Enhanced service logging (optional)

To improve debugging and service visibility, you can mount the checkpoint log path /var/log/sagemaker_checkpointing from the pod to the path /var/logs/sagemaker_checkpointing on the host. This keeps the library-specific logs collected separately and gives service teams better visibility for debugging and support.
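
As an illustration, a hostPath volume in the pod specification can expose the log directory on the host. All names below are placeholders, and your training operator may declare volumes differently:

apiVersion: v1
kind: Pod
metadata:
  name: training-pod                  # placeholder
spec:
  containers:
    - name: training-container        # placeholder
      image: my-training-image        # placeholder
      volumeMounts:
        - name: checkpoint-logs
          # Path the library writes to inside the container
          mountPath: /var/log/sagemaker_checkpointing
  volumes:
    - name: checkpoint-logs
      hostPath:
        # Host path where the logs are collected
        path: /var/logs/sagemaker_checkpointing
        type: DirectoryOrCreate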