检查预测结果 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

检查预测结果

您可以使用多种方法检查异步终端节点的预测结果。一些选项包括:

  1. Amazon SNS 主题。

  2. 检查 Amazon S3 存储桶中的输出。

Amazon SNS 主题

Amazon SNS 是面向消息的应用程序的通知服务,多个订阅者通过各种传输协议(包括 HTTP、Amazon SQS 和电子邮件)请求和接收时间关键型消息的 “推送” 通知。当您使用创建终端节点时,Amazon SageMaker 异步推理会发布通知CreateEndpointConfig并指定 Amazon SNS 主题。

注意

要接收 Amazon SNS 通知,您的 IAM 角色必须具有sns:Publish权限。请参阅先决条件有关使用异步推理必须满足的要求的信息。

要使用 Amazon SNS 检查异步终端节点的预测结果,您首先需要创建主题、订阅该主题、确认订阅主题,然后记下该主题的亚马逊资源名称 (ARN)。有关如何创建、订阅和查找 Amazon SNS 主题的 Amazon ARN 的详细信息,请参阅配置 Amazon SNS.

在中提供 Amazon SNS 主题 ARNAsyncInferenceConfig使用创建终端配置时使用的终端配置CreateEndpointConfig. 您可以同时指定 Amazon SNSErrorTopicSuccessTopic.

import boto3 sagemaker_client = boto3.client('sagemaker', region_name=<aws_region>) sagemaker_client.create_endpoint_config( EndpointConfigName=<endpoint_config_name>, # You specify this name in a CreateEndpoint request. # List of ProductionVariant objects, one for each model that you want to host at this endpoint. ProductionVariants=[ { "VariantName": "variant1", # The name of the production variant. "ModelName": "model_name", "InstanceType": "ml.m5.xlarge", # Specify the compute instance type. "InitialInstanceCount": 1 # Number of instances to launch initially. } ], AsyncInferenceConfig={ "OutputConfig": { # Location to upload response outputs when no location is provided in the request. "S3OutputPath": "s3://<bucket>/<output_directory>" "NotificationConfig": { "SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name", "ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name", } } } )

创建终端节点并调用终端节点后,您会收到来自 Amazon SNS 主题的通知。例如,如果您订阅了接收来自主题的电子邮件通知,则每次调用终端节点时都会收到一封电子邮件通知。以下示例显示成功调用电子邮件通知的 JSON 内容。

{ "awsRegion":"us-east-1", "eventTime":"2022-01-25T22:46:00.608Z", "receivedTime":"2022-01-25T22:46:00.455Z", "invocationStatus":"Completed", "requestParameters":{ "contentType":"text/csv", "endpointName":"<example-endpoint>", "inputLocation":"s3://<bucket>/<input-directory>/input-data.csv" }, "responseParameters":{ "contentType":"text/csv; charset=utf-8", "outputLocation":"s3://<bucket>/<output_directory>/prediction.out" }, "inferenceId":"11111111-2222-3333-4444-555555555555", "eventVersion":"1.0", "eventSource":"aws:sagemaker", "eventName":"InferenceResult" }

检查你的 S3 存储桶

当你调用终端节点时InvokeEndpointAsync,它返回响应对象。您可以使用响应对象获取存储输出的 Amazon S3 URI。使用输出位置,您可以使用 SageMaker Python SDK SageMaker 会话类以编程方式检查输出。

下面存储的输出字典InvokeEndpointAsync作为名为响应的变量。然后使用响应变量,您可以获取 Amazon S3 输出 URI 并将其存储为名为的字符串变量output_location.

import uuid import boto3 sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=<aws_region>) # Specify the S3 URI of the input. Here, a single SVM sample input_location = "s3://bucket-name/test_point_0.libsvm" response = sagemaker_runtime.invoke_endpoint_async( EndpointName='<endpoint-name>', InputLocation=input_location, InferenceId=str(uuid.uuid4()), ContentType="text/libsvm" #Specify the content type of your data ) output_location = response['OutputLocation'] print(f"OutputLocation: {output_location}")

有关支持的内容类型的信息,请参阅用于推理的常见数据格式.

然后,您可以使用 Amazon S3 输出位置SageMaker Python SDK SageMaker 会话类读取 Amazon S3 文件。以下代码示例演示如何创建函数 (get_ouput) 它反复尝试从 Amazon S3 输出位置读取文件:

import sagemaker import urllib, time from botocore.exceptions import ClientError sagemaker_session = sagemaker.session.Session() def get_output(output_location): output_url = urllib.parse.urlparse(output_location) bucket = output_url.netloc key = output_url.path[1:] while True: try: return sagemaker_session.read_s3_file( bucket=output_url.netloc, key_prefix=output_url.path[1:]) except ClientError as e: if e.response['Error']['Code'] == 'NoSuchKey': print("waiting for output...") time.sleep(2) continue raise output = get_output(output_location) print(f"Output: {output}")