步骤 1:编译模型 - Amazon SageMaker
AWS 文档中描述的 AWS 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 AWS 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

步骤 1:编译模型

在满足先决条件后,您可以使用 Amazon SageMaker Neo 编译模型。您可以使用 AWS CLI、 控制台或适用于 Python 的 Amazon Web Services 开发工具包 (Boto3 编译模型,请参阅使用 Neo 编译模型。在本示例中,您将使用 Boto3 编译模型。

要编译模型SageMaker Neo 需要以下信息:

  1. 存储训练后的模型的Amazon S3存储桶 URI。

    如果您遵循先决条件,则存储桶的名称将存储在名为 的变量中bucket。以下代码段说明如何使用 列出您的所有存储桶AWS CLI:

    $ aws s3 ls

    例如:

    $ aws s3 ls 2020-11-02 17:08:50 bucket
  2. 您希望 保存已编译模型的Amazon S3存储桶 URI。

    以下代码段将您的Amazon S3存储桶 URI 与名为 的输出目录的名称连接output

    s3_output_location = f's3://{bucket}/output'
  3. 用于训练模型的机器学习框架。

    定义用于训练模型的框架。

    framework = 'framework-name'

    例如,如果您要编译使用 TensorFlow训练的模型,则可以使用 或 tflite tensorflowtflite 如果要使用使用较少存储内存的更轻型 TensorFlow 版本,请使用 。

    framework = 'tflite'

    有关 Neo 支持的框架的完整列表,请参阅支持的框架、设备、系统和架构

  4. 模型的输入的形状。

    Neo 需要输入张量的名称和形状。名称和形状作为键值对传入。value 是输入张量的整数维度列表, key 是模型中输入张量的确切名称。

    data_shape = '{"name": [tensor-shape]}'

    例如:

    data_shape = '{"normalized_input_image_tensor":[1, 300, 300, 3]}'
    注意

    确保根据您使用的框架正确格式化模型。请参阅 SageMaker Neo 需要哪些输入数据形状? 此字典中的键必须更改为新输入张量的名称。

  5. 要编译的目标设备的名称或硬件平台的一般详细信息

    target_device = 'target-device-name'

    例如,如果您要部署到 Raspberry Pi 3,请使用:

    target_device = 'rasp3b'

    您可以在 支持的框架、设备、系统和架构中找到支持的边缘设备的完整列表。

现在,您已完成前面的步骤,可以向 Neo 提交编译作业。

# Create a SageMaker client so you can submit a compilation job sagemaker_client = boto3.client('sagemaker', region_name=AWS_REGION) # Give your compilation job a name compilation_job_name = 'getting-started-demo' print(f'Compilation job for {compilation_job_name} started') response = sagemaker_client.create_compilation_job( CompilationJobName=compilation_job_name, RoleArn=role_arn, InputConfig={ 'S3Uri': s3_input_location, 'DataInputConfig': data_shape, 'Framework': framework.upper() }, OutputConfig={ 'S3OutputLocation': s3_output_location, 'TargetDevice': target_device }, StoppingCondition={ 'MaxRuntimeInSeconds': 900 } ) # Optional - Poll every 30 sec to check completion status import time while True: response = sagemaker_client.describe_compilation_job(CompilationJobName=compilation_job_name) if response['CompilationJobStatus'] == 'COMPLETED': break elif response['CompilationJobStatus'] == 'FAILED': raise RuntimeError('Compilation failed') print('Compiling ...') time.sleep(30) print('Done!')

如果您需要其他信息以进行调试,请包含以下打印语句:

print(response)

如果编译作业成功,编译的模型将存储在您之前指定的输出Amazon S3存储桶 () s3_output_location中。在本地下载已编译的模型:

object_path = f'output/{model}-{target_device}.tar.gz' neo_compiled_model = f'compiled-{model}.tar.gz' s3_client.download_file(bucket, object_path, neo_compiled_model)