注册模型版本 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

注册模型版本

您可以通过创建一个模型版本(用于指定模型所属的模型组)来注册 Amazon SageMaker 模型。模型版本必须同时包含模型构件(模型的训练权重)和模型的推理代码。

推理管道 是一个 SageMaker 模型,由 2 到 15 个容器的线性序列组成,用于处理推理请求。您可以通过指定容器和关联的环境变量来注册推理管道。有关推理管道的更多信息,请参阅托管模型以及预处理逻辑,作为端点后面的串行推理管道

您可以通过指定容器和关联的环境变量来注册带有推理管道的模型。要通过使用 Amazon SDK for Python (Boto3) 或在 SageMaker 建模管线中创建步骤来创建带有推理管道的模型版本,请使用以下步骤。

注册模型版本 (SageMaker Pipelines)

要使用 SageMaker 建模管线注册模型版本,请在管道中创建 RegisterModel 步骤。有关作为管道的一部分创建 RegisterModel 步骤的信息,请参阅第 8 步:定义 RegisterModel 步骤以创建模型包

注册模型版本 (Boto3)

要使用 Boto3 注册模型版本,请调用 create_model_package 方法。

首先,设置要传递给 create_model_package 方法的参数字典。

# Specify the model source model_url = "s3://your-bucket-name/model.tar.gz" modelpackage_inference_specification = { "InferenceSpecification": { "Containers": [ { "Image": '257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:1.2-1', "ModelDataUrl": model_url } ], "SupportedContentTypes": [ "text/csv" ], "SupportedResponseMIMETypes": [ "text/csv" ], } } # Alternatively, you can specify the model source like this: # modelpackage_inference_specification["InferenceSpecification"]["Containers"][0]["ModelDataUrl"]=model_url create_model_package_input_dict = { "ModelPackageGroupName" : model_package_group_name, "ModelPackageDescription" : "Model to detect 3 different types of irises (Setosa, Versicolour, and Virginica)", "ModelApprovalStatus" : "PendingManualApproval" } create_model_package_input_dict.update(modelpackage_inference_specification)

然后调用 create_model_package 方法,传入刚才设置的参数字典。

create_model_package_response = sm_client.create_model_package(**create_model_package_input_dict) model_package_arn = create_model_package_response["ModelPackageArn"] print('ModelPackage Version ARN : {}'.format(model_package_arn))

注册模型版本 (Amazon SageMaker Studio)

要在 Amazon SageMaker Studio 中创建模型版本,请完成以下步骤。

  1. 登录 Amazon SageMaker Studio。有关更多信息,请参阅 加入 Amazon SageMaker 域

  2. 在左侧导航窗格中,选择主页图标 ( )。

  3. 选择模型,然后选择模型注册表

  4. 打开注册版本表单。您可以通过两种方式之一来执行此操作:

    • 选择操作,然后选择创建模型版本

    • 选择要为其创建模型版本的模型组的名称,然后选择创建模型版本

  5. 注册模型版本表单中,输入以下信息:

    • 模型包组名称下拉列表中,选择模型组名称。

    • (可选)为模型版本输入描述。

    • 模型批准状态下拉列表中,选择版本批准状态。

    • (可选)在自定义元数据字段中,以键值对形式添加自定义标签。

  6. 选择下一步

  7. 推理规范表单中,输入以下信息:

    • 输入您的推理映像位置。

    • 输入您的模型数据构件位置。

    • (可选)输入有关转换作业和实时推理作业的首选映像以及支持的输入和输出 MIME 类型的信息。

  8. 选择下一步

  9. (可选)提供详细信息以帮助推荐端点。

  10. 选择下一步

  11. (可选)选择要包含的模型指标。

  12. 选择下一步

  13. 确保显示的设置正确无误,然后选择注册模型版本。如果您随后看到带有错误消息的模型窗口,请选择查看(消息旁边)以查看错误的来源。

  14. 确认您的新模型版本出现在父模型组页面中。

从其他账户注册模型版本

要使用不同 Amazon 账户创建的模型组注册模型版本,必须添加跨账户 Amazon Identity and Access Management 资源策略以启用该账户。例如,您组织中的一个 Amazon 账户负责训练模型,另一个账户负责管理、部署和更新模型。您可以创建 IAM 资源策略,并将这些策略应用于您要针对此使用案例授予访问权限的特定账户资源。有关 Amazon 中跨账户资源策略的更多信息,请参阅《Amazon Identity and Access Management 用户指南》中的跨账户策略评估逻辑

注意

在跨账户模型部署训练期间,您还必须使用 KMS 密钥对输出数据配置操作进行加密。

要在 SageMaker 中启用跨账户模型注册表,必须为包含模型版本的模型组提供跨账户资源策略。以下是为模型组创建跨账户策略并将这些策略应用于该特定资源的示例。

必须在源账户中设置以下配置,该账户在模型组中跨账户注册模型。在此示例中,源账户是模型训练账户,它将训练模型跨账户,然后跨账户将模型注册到模型注册表账户的模型注册表中。

该示例假设您之前定义了以下变量:

  • sm_client - SageMaker Boto3 客户端。

  • model_package_group_name - 要授予访问权限的模型组。

  • model_package_group_arn - 要授予跨账户访问权限的模型组 ARN。

  • bucket - 用于存储模型训练构件的 S3 存储桶。

用户必须拥有一个可访问 SageMaker 操作的角色(例如具有 AmazonSageMakerFullAccess 托管策略的角色),才能部署在不同账户中创建的模型。有关 SageMaker 托管策略的信息,请参阅适用于 Amazon SageMaker 的 Amazon 托管式策略

必需的 IAM 资源策略

下图显示了允许跨账户模型注册所需的策略。如图所示,这些策略需要在模型训练期间处于活动状态,才能将模型正确注册到模型注册表账户中。

以下代码示例演示了 Amazon ECR、Amazon S3 和 Amazon KMS 策略。

Amazon ECR 策略示例

{ "Version": "2012-10-17", "Statement": [ { "Sid": "AddPerm", "Effect": "Allow", "Principal": { "AWS": "arn:aws:iam::{model_registry_account}:root" }, "Action": [ "ecr:BatchGetImage", "ecr:Describe*" ] } ] }

Amazon S3 策略示例

{ "Version": "2012-10-17", "Statement": [ { "Sid": "AddPerm", "Effect": "Allow", "Principal": { "AWS": "arn:aws:iam::{model_registry_account}:root" }, "Action": [ "s3:GetObject", "s3:GetBucketAcl", "s3:GetObjectAcl" ], "Resource": "arn:aws:s3:::{bucket}/*" } ] }

Amazon KMS 策略示例

{ "Version": "2012-10-17", "Statement": [ { "Sid": "AddPerm", "Effect": "Allow", "Principal": { "AWS": "arn:aws:iam::{model_registry_account}:root" }, "Action": [ "kms:Decrypt", "kms:GenerateDataKey*" ], "Resource": "*" } ] }

将资源策略应用于账户

以下策略配置应用了上一节中讨论的策略,必须放入模型训练账户。

import json # The Model Registry account id of the Model Group model_registry_account = "111111111111" # The model training account id where training happens model_training_account = "222222222222" # 1. Create a policy for access to the ECR repository # in the model training account for the Model Registry account Model Group ecr_repository_policy = {"Version": "2012-10-17", "Statement": [{"Sid": "AddPerm", "Effect": "Allow", "Principal": { "AWS": f"arn:aws:iam::{model_registry_account}:root" }, "Action": [ "ecr:BatchGetImage", "ecr:Describe*" ] }] } # Convert the ECR policy from JSON dict to string ecr_repository_policy = json.dumps(ecr_repository_policy) # Set the new ECR policy ecr = boto3.client('ecr') response = ecr.set_repository_policy( registryId = model_training_account, repositoryName = "decision-trees-sample", policyText = ecr_repository_policy ) # 2. Create a policy in the model training account for access to the S3 bucket # where the model is present in the Model Registry account Model Group bucket_policy = {"Version": "2012-10-17", "Statement": [{"Sid": "AddPerm", "Effect": "Allow", "Principal": {"AWS": f"arn:aws:iam::{model_registry_account}:root" }, "Action": [ "s3:GetObject", "s3:GetBucketAcl", "s3:GetObjectAcl" ], "Resource": "arn:aws:s3:::{bucket}/*" }] } # Convert the S3 policy from JSON dict to string bucket_policy = json.dumps(bucket_policy) # Set the new bucket policy s3 = boto3.client("s3") response = s3.put_bucket_policy( Bucket = bucket, Policy = bucket_policy) # 3. Create the KMS grant for the key used during training for encryption # in the model training account to the Model Registry account Model Group client = boto3.client("kms") response = client.create_grant( GranteePrincipal=model_registry_account, KeyId=kms_key_id Operations=[ "Decrypt", "GenerateDataKey", ], )

需要将以下配置放入模型组所在的模型注册表账户。

# The Model Registry account id of the Model Group model_registry_account = "111111111111" # 1. Create policy to allow the model training account to access the ModelPackageGroup model_package_group_policy = {"Version": "2012-10-17", "Statement": [ { "Sid": "AddPermModelPackageVersion", "Effect": "Allow", "Principal": {"AWS": f"arn:aws:iam::{model_training_account}:root", "Action": ["sagemaker:CreateModelPackage"], "Resource": f"arn:aws:sagemaker:{region}:{model_registry_account}:model-package/{model_package_group_name}/*" } ] } # Convert the policy from JSON dict to string model_package_group_policy = json.dumps(model_package_group_policy) # Set the new policy response = sm_client.put_model_package_group_policy( ModelPackageGroupName = model_package_group_name, ResourcePolicy = model_package_group_policy)

最后,使用模型训练账户中的 create_model_package 操作跨账户注册模型包。

# Specify the model source model_url = "s3://{bucket}/model.tar.gz" #Set up the parameter dictionary to pass to the create_model_package method modelpackage_inference_specification = { "InferenceSpecification": { "Containers": [ { "Image": f"{model_training_account}.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", "ModelDataUrl": model_url } ], "SupportedContentTypes": [ "text/csv" ], "SupportedResponseMIMETypes": [ "text/csv" ], } } # Alternatively, you can specify the model source like this: # modelpackage_inference_specification["InferenceSpecification"]["Containers"][0]["ModelDataUrl"]=model_url create_model_package_input_dict = { "ModelPackageGroupName" : model_package_group_arn, "ModelPackageDescription" : "Model to detect 3 different types of irises (Setosa, Versicolour, and Virginica)", "ModelApprovalStatus" : "PendingManualApproval" } create_model_package_input_dict.update(modelpackage_inference_specification) # Create the model package in the Model Registry account create_model_package_response = sm_client.create_model_package(**create_model_package_input_dict) model_package_arn = create_model_package_response["ModelPackageArn"] print('ModelPackage Version ARN : {}'.format(model_package_arn))