将针对上的模型训练和托管的自定义算法与 Apache Spark 结合使用 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

将针对上的模型训练和托管的自定义算法与 Apache Spark 结合使用

In示例 1:将针对训练和推理的 Amazon SageMaker 与 Apache Spark 结合使用,您可以使用kMeansSageMakerEstimator因为该示例使用 Amazon SageMaker 提供的 k-means 算法进行模型训练。您可以选择使用自己的自定义算法进行模型训练。假设您已经创建了 Docker 图像,您可以创建自己的SageMakerEstimator并为您的自定义图片指定亚马逊弹性容器注册路径。

以下示例显示如何从 SageMakerEstimator 中创建 KMeansSageMakerEstimator。在新的评估程序中,您可以显式指定训练和推理代码图像的 Docker 注册表路径。

import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer val estimator = new SageMakerEstimator( trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", requestRowSerializer = new ProtobufRequestRowSerializer(), responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(), hyperParameters = Map("k" -> "10", "feature_dim" -> "784"), sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1, trainingSparkDataFormat = "sagemaker")

在该代码中,SageMakerEstimator 构造函数中的参数包括:

  • trainingImage-标识包含自定义代码的训练图像的 Docker 注册表路径。

  • modelImage-标识包含推理代码的图像的 Docker 注册表路径。

  • requestRowSerializer— 实现com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer.

    此参数序列化输入DataFrame将其发送到 SageMaker 中托管的模型以获取推理。

  • responseRowDeserializer— 实现

    com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer.

    此参数将在 SageMaker 中托管的响应反序列化回。DataFrame.

  • trainingSparkDataFormat-指定 Spark 在从上载训练数据时使用的数据格式。DataFrame添加到 S3。例如,"sagemaker" 用于 protobuf 格式,"csv" 用于逗号分隔值,"libsvm" 用于 LibSVM 格式。

您可以实施自己的 RequestRowSerializerResponseRowDeserializer 以序列化或反序列化您的推理代码支持的数据格式中的行,例如 .libsvm 或 .csv。