本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
启用检查点
启用检查点后,SageMaker AI 将检查点保存到 Amazon S3,并将训练作业与 S3 存储桶中的检查点同步。检查点 S3 存储桶可以使用 S3 通用存储桶或 S3 目录存储桶。
以下示例显示在构造 SageMaker AI 估算器时如何配置检查点路径。要启用检查点,将 checkpoint_s3_uri 和 checkpoint_local_path 参数添加到估算器。
以下示例模板显示如何创建通用的 SageMaker AI 估算器和启用检查点。通过指定 image_uri 参数,可以将此模板用于支持的算法。要为 SageMaker AI 支持检查点的算法查找 Docker 映像 URI,请参阅 Docker 注册表路径和示例代码。您也可以将 estimator 和 Estimator 替换为其他 SageMaker AI 框架的估算器父类和估算器类,例如,TensorFlow、PyTorch、MXNet、HuggingFace 和 XGBoost。
import sagemaker from sagemaker.estimatorimportEstimatorbucket=sagemaker.Session().default_bucket() base_job_name="sagemaker-checkpoint-test" checkpoint_in_bucket="checkpoints" # The S3 URI to store the checkpoints checkpoint_s3_bucket="s3://{}/{}/{}".format(bucket, base_job_name, checkpoint_in_bucket) # The local path where the model will save its checkpoints in the training container checkpoint_local_path="/opt/ml/checkpoints" estimator =Estimator( ... image_uri="<ecr_path>/<algorithm-name>:<tag>" # Specify to use built-in algorithms output_path=bucket, base_job_name=base_job_name, # Parameters required to enable checkpointing checkpoint_s3_uri=checkpoint_s3_bucket, checkpoint_local_path=checkpoint_local_path )
以下两个参数指定检查点的路径:
-
checkpoint_local_path– 指定模型定期在训练容器中保存检查点的本地路径。默认路径设置为'/opt/ml/checkpoints'。如果您使用的是其他框架或自带训练容器,请确保训练脚本的检查点配置指定'/opt/ml/checkpoints'路径。注意
我们建议将本地路径指定为
'/opt/ml/checkpoints',从而与默认的 SageMaker AI 检查点设置保持一致。如果您更希望指定自己的本地路径,请确保与训练脚本中的检查点保存路径以及 SageMaker AI 估算器的checkpoint_local_path参数相匹配。 -
checkpoint_s3_uri– 实时存储检查点的 S3 存储桶的 URI。您可以指定 S3 通用存储桶或 S3 目录存储桶来存储检查点。有关 S3 目录存储桶的更多信息,请参阅《Amazon Simple Storage Service 用户指南》中的目录存储桶。
要查找 SageMaker AI 估算器参数的完整列表,请参阅 Amazon SageMaker Python SDK