将 Amazon SageMaker 上的模型训练和托管的自定义算法与 Apache Spark 结合使用 - Amazon SageMaker
AWS 文档中描述的 AWS 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 AWS 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

将 Amazon SageMaker 上的模型训练和托管的自定义算法与 Apache Spark 结合使用

示例 1:将 Amazon SageMaker 用于训练和推理与 Apache Spark 结合使用 中,您使用 kMeansSageMakerEstimator,因为该示例使用 Amazon SageMaker 提供的 k-means 算法进行模型训练。您可以选择使用自己的自定义算法进行模型训练。假设您已经创建了 Docker 镜像,您可以创建自己的 SageMakerEstimator 并为您的自定义图像指定 Amazon Elastic Container Registry 路径。

以下示例说明如何从 KMeansSageMakerEstimator 创建 SageMakerEstimator。 在新的评估程序中,您可明确指定训练和推理代码映像的 Docker 注册表路径。

import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer val estimator = new SageMakerEstimator( trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", requestRowSerializer = new ProtobufRequestRowSerializer(), responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(), hyperParameters = Map("k" -> "10", "feature_dim" -> "784"), sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1, trainingSparkDataFormat = "sagemaker")

在该代码中,SageMakerEstimator 构造函数中的参数包括:

  • trainingImage — 标识包含自定义代码的训练图像的 Docker 注册表路径。

  • modelImage — 标识包含推理代码的图像的 Docker 注册表路径。

  • requestRowSerializer —实施 com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer

    此参数序列化输入 DataFrame 中的行以将其发送到在 SageMaker 中托管的模型以用于推理。

  • responseRowDeserializer — 实施

    com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer.

    此参数将在 SageMaker 中托管的响应反序列化回 DataFrame

  • trainingSparkDataFormat — 指定 Spark 在将训练数据从 DataFrame 上传到 S3 时使用的数据格式。例如,"sagemaker" 表示 protobuf 格式,"csv" 表示逗号分隔值,"libsvm" 表示 LibSVM 格式。

您可以实施自己的 RequestRowSerializerResponseRowDeserializer 以序列化或反序列化您的推理代码支持的数据格式中的行,例如 .libsvm 或 .csv。