第 1 步:编译模型 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

第 1 步:编译模型

一旦您满意了先决条件,您可以使用 Amazon SageMaker Neo 编译模型。你可以使用Amazon CLI、控制台或适用于 Python (Boto3) 的 Amazon Web Services 软件开发工具包,请参阅使用 Neo 编译模型. 在此示例中,您将使用 Boto3 编译模型。

要编译模型, SageMaker Neo 需要以下信息:

  1. 存储训练模型的 Amazon S3 存储桶 URI。

    如果您遵循了先决条件,则存储桶的名称将存储在名为的变量中bucket. 以下代码段说明了如何使用Amazon CLI:

    aws s3 ls

    例如:

    $ aws s3 ls 2020-11-02 17:08:50 bucket
  2. 您希望将编译后的模型保存到的 Amazon S3 存储桶 URI。

    以下代码段将您的 Amazon S3 存储桶 URI 与名为的输出目录的名称连接起来。output

    s3_output_location = f's3://{bucket}/output'
  3. 你用来训练模型的机器学习框架。

    定义您用于训练模型的框架。

    framework = 'framework-name'

    例如,如果你想编译使用 TensorFlow 训练的模型,你可以使用tflite要么tensorflow. 使用tflite如果您希望使用更轻的版本 TensorFlow 这使用的存储内存更少。

    framework = 'tflite'

    有关 Neo 支持框架的完整列表,请参阅支持的框架、设备、系统和架构.

  4. 模型输入的形状。

    Neo 需要输入张量的名称和形状。名称和形状以键值对的形式传递到。value是输入张量的整数维度的列表,key是模型中输入张量的确切名称。

    data_shape = '{"name": [tensor-shape]}'

    例如:

    data_shape = '{"normalized_input_image_tensor":[1, 300, 300, 3]}'
    注意

    确保模型格式正确,具体取决于您使用的框架。请参阅输入数据形状有什么作用 SageMaker Neo 期望? 本字典中的键必须更改为新的输入张量的名称。

  5. 要编译的目标设备的名称或硬件平台的一般详细信息

    target_device = 'target-device-name'

    例如,如果要部署到树莓派 3,请使用:

    target_device = 'rasp3b'

    您可以在中找到支持的边缘设备的完整列表支持的框架、设备、系统和架构.

现在你已经完成了前面的步骤,你可以向 Neo 提交编译作业了。

# Create a SageMaker client so you can submit a compilation job sagemaker_client = boto3.client('sagemaker', region_name=AWS_REGION) # Give your compilation job a name compilation_job_name = 'getting-started-demo' print(f'Compilation job for {compilation_job_name} started') response = sagemaker_client.create_compilation_job( CompilationJobName=compilation_job_name, RoleArn=role_arn, InputConfig={ 'S3Uri': s3_input_location, 'DataInputConfig': data_shape, 'Framework': framework.upper() }, OutputConfig={ 'S3OutputLocation': s3_output_location, 'TargetDevice': target_device }, StoppingCondition={ 'MaxRuntimeInSeconds': 900 } ) # Optional - Poll every 30 sec to check completion status import time while True: response = sagemaker_client.describe_compilation_job(CompilationJobName=compilation_job_name) if response['CompilationJobStatus'] == 'COMPLETED': break elif response['CompilationJobStatus'] == 'FAILED': raise RuntimeError('Compilation failed') print('Compiling ...') time.sleep(30) print('Done!')

如果您想要进行调试的其他信息,请包括以下 print 语句:

print(response)

如果编译作业成功,您的编译模型将存储在您之前指定的输出 Amazon S3 存储桶中(s3_output_location)。在本地下载编译后的模型:

object_path = f'output/{model}-{target_device}.tar.gz' neo_compiled_model = f'compiled-{model}.tar.gz' s3_client.download_file(bucket, object_path, neo_compiled_model)