使用 SageMaker AI 估算器来运行训练作业 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 SageMaker AI 估算器来运行训练作业

您也可以使用 SageMaker Python SDK 中的估算器来处理 SageMaker 训练作业的配置和运行。以下代码示例显示如何使用私有 Docker 注册表中的映像配置和运行估算器。

  1. 导入所需的库和依赖项,如以下代码示例中所示。

    import boto3 import sagemaker from sagemaker.estimator import Estimator session = sagemaker.Session() role = sagemaker.get_execution_role()
  2. 向您的训练映像、安全组和子网提供统一资源标识符 (URI),用于您的训练作业 VPC 配置,如以下代码示例所示。

    image_uri = "myteam.myorg.com/docker-local/my-training-image:<IMAGE-TAG>" security_groups = ["sg-0123456789abcdef0"] subnets = ["subnet-0123456789abcdef0", "subnet-0123456789abcdef0"]

    有关 security_group_idssubnets 的更多信息,请参阅 SageMaker Python SDK 的估算器部分中有关相应参数的描述。

    注意

    SageMaker AI 使用 VPC 内的网络连接来访问 Docker 注册表中的映像。要将您 Docker 注册表中的映像用于训练,注册表必须可以从您账户中的 Amazon VPC 访问。

  3. (可选)如果您的 Docker 注册表要求进行身份验证,则还必须指定向 SageMaker AI 提供访问凭证的 Amazon Lambda 函数的 Amazon 资源名称(ARN)。以下示例演示了如何指定 ARN。

    training_repository_credentials_provider_arn = "arn:aws:lambda:us-west-2:1234567890:function:test"

    有关使用需要身份验证的 Docker 注册表中的映像的更多信息,请参阅下文中的使用需要身份验证的 Docker 注册表进行训练

  4. 使用前面步骤中的代码示例来配置估算器,如以下代码示例所示。

    # The training repository access mode must be 'Vpc' for private docker registry jobs training_repository_access_mode = "Vpc" # Specify the instance type, instance count you want to use instance_type="ml.m5.xlarge" instance_count=1 # Specify the maximum number of seconds that a model training job can run max_run_time = 1800 # Specify the output path for the model artifacts output_path = "s3://your-output-bucket/your-output-path" estimator = Estimator( image_uri=image_uri, role=role, subnets=subnets, security_group_ids=security_groups, training_repository_access_mode=training_repository_access_mode, training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # remove this line if auth is not needed instance_type=instance_type, instance_count=instance_count, output_path=output_path, max_run=max_run_time )
  5. 使用您的作业名称和输入路径作为参数来调用 estimator.fit,以启动训练作业,如以下代码示例所示。

    input_path = "s3://your-input-bucket/your-input-path" job_name = "your-job-name" estimator.fit( inputs=input_path, job_name=job_name )