启用检查点 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

启用检查点

启用检查点后,SageMaker AI 将检查点保存到 Amazon S3,并将训练作业与 S3 存储桶中的检查点同步。检查点 S3 存储桶可以使用 S3 通用存储桶或 S3 目录存储桶。

训练期间写入检查点的架构图。

以下示例显示在构造 SageMaker AI 估算器时如何配置检查点路径。要启用检查点,将 checkpoint_s3_uricheckpoint_local_path 参数添加到估算器。

以下示例模板显示如何创建通用的 SageMaker AI 估算器和启用检查点。通过指定 image_uri 参数,可以将此模板用于支持的算法。要为 SageMaker AI 支持检查点的算法查找 Docker 映像 URI,请参阅 Docker 注册表路径和示例代码。您也可以将 estimatorEstimator 替换为其他 SageMaker AI 框架的估算器父类和估算器类,例如,TensorFlowPyTorchMXNetHuggingFaceXGBoost

import sagemaker from sagemaker.estimator import Estimator bucket=sagemaker.Session().default_bucket() base_job_name="sagemaker-checkpoint-test" checkpoint_in_bucket="checkpoints" # The S3 URI to store the checkpoints checkpoint_s3_bucket="s3://{}/{}/{}".format(bucket, base_job_name, checkpoint_in_bucket) # The local path where the model will save its checkpoints in the training container checkpoint_local_path="/opt/ml/checkpoints" estimator = Estimator( ... image_uri="<ecr_path>/<algorithm-name>:<tag>" # Specify to use built-in algorithms output_path=bucket, base_job_name=base_job_name, # Parameters required to enable checkpointing checkpoint_s3_uri=checkpoint_s3_bucket, checkpoint_local_path=checkpoint_local_path )

以下两个参数指定检查点的路径:

  • checkpoint_local_path – 指定模型定期在训练容器中保存检查点的本地路径。默认路径设置为 '/opt/ml/checkpoints'。如果您使用的是其他框架或自带训练容器,请确保训练脚本的检查点配置指定 '/opt/ml/checkpoints' 路径。

    注意

    我们建议将本地路径指定为 '/opt/ml/checkpoints',从而与默认的 SageMaker AI 检查点设置保持一致。如果您更希望指定自己的本地路径,请确保与训练脚本中的检查点保存路径以及 SageMaker AI 估算器的 checkpoint_local_path 参数相匹配。

  • checkpoint_s3_uri – 实时存储检查点的 S3 存储桶的 URI。您可以指定 S3 通用存储桶或 S3 目录存储桶来存储检查点。有关 S3 目录存储桶的更多信息,请参阅《Amazon Simple Storage Service 用户指南》中的目录存储桶

要查找 SageMaker AI 估算器参数的完整列表,请参阅 Amazon SageMaker Python SDK 文档中的估算器 API