检查预测结果 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

检查预测结果

您可以通过几种方法检查异步终端节点的预测结果。一些选项是:

  1. Amazon SNS 主题。

  2. 检查 Amazon S3 存储桶中的输出。

Amazon SNS 主题

Amazon SNS 是面向消息的应用程序的通知服务,多个订阅者通过选择的传输协议(包括 HTTP、Amazon SQS 和电子邮件)请求和接收时间紧迫的消息的 “推送” 通知。Amazon SageMaker 异步推理会在您使用CreateEndpointConfig,然后指定 Amazon SNS 主题。

注意

为了接收 Amazon SNS 通知,您的 IAM 角色必须具有sns:Publish权限。有关使用异步推理所必须满足的要求的信息,请参阅链接到 PREREQS。

要使用 Amazon SNS 检查异步终端节点的预测结果,您首先需要创建主题、订阅主题、确认订阅主题,并记下该主题的 Amazon 资源名称 (ARN)。有关如何创建、订阅和查找 Amazon SNS 主题的 Amazon ARN 的详细信息,请参阅。配置 Amazon SNS 主题.

提供 Amazon SNS 主题 ARN,请在AsyncInferenceConfig字段中使用创建终端节点配置CreateEndpointConfig. 您可以同时指定 Amazon SNSErrorTopicSuccessTopic.

import boto3 sagemaker_client = boto3.client('sagemaker', region_name=<aws_region>) sagemaker_client.create_endpoint_config( EndpointConfigName=<endpoint_config_name>, # You specify this name in a CreateEndpoint request. # List of ProductionVariant objects, one for each model that you want to host at this endpoint. ProductionVariants=[ { "VariantName": "variant1", # The name of the production variant. "ModelName": "model_name", "InstanceType": "ml.m5.xlarge", # Specify the compute instance type. "InitialInstanceCount": 1 # Number of instances to launch initially. } ], AsyncInferenceConfig={ "OutputConfig": { # Location to upload response outputs when no location is provided in the request. "S3OutputPath": "s3://<bucket>/<output_directory>" "NotificationConfig": { "SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name", "ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name", } } } )

检查您的 S3 存储桶

当您调用带有InvokeEndpointAsync,它会返回一个响应对象。您可以使用响应对象获取存储输出的 Amazon S3 URI。使用输出位置,您可以使用 SageMaker Python SDK SageMaker 会话类以编程方式检查输出。

以下内容存储InvokeEndpointAsync作为一个名为响应的变量。使用响应变量,您可以获取 Amazon S3 输出 URI 并将其存储为名为output_location.

import uuid import boto3 sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=<aws_region>) # Specify the S3 URI of the input. Here, a single SVM sample input_location = "s3://bucket-name/test_point_0.libsvm" response = sagemaker_runtime.invoke_endpoint_async( EndpointName='<endpoint-name>', InputLocation=input_location, InferenceId=str(uuid.uuid4()), ContentType="text/libsvm" #Specify the content type of your data ) output_location = response['OutputLocation'] print(f"OutputLocation: {output_location}")

有关支持的内容类型的信息,请参阅用于推理的常见数据格式.

通过 Amazon S3 输出位置,您可以使用SageMaker Python 软件开发工具包 SageMaker 会话类读取 Amazon S3 文件。以下代码示例演示如何创建函数(get_ouput),它反复尝试从 Amazon S3 输出位置读取文件:

import sagemaker import urllib, time from botocore.exceptions import ClientError sagemaker_session = sagemaker.session.Session() def get_output(output_location): output_url = urllib.parse.urlparse(output_location) bucket = output_url.netloc key = output_url.path[1:] while True: try: return sagemaker_session.read_s3_file( bucket=output_url.netloc, key_prefix=output_url.path[1:]) except ClientError as e: if e.response['Error']['Code'] == 'NoSuchKey': print("waiting for output...") time.sleep(2) continue raise output = get_output(output_location) print(f"Output: {output}")