将 Amazon SageMaker 上的模型训练和托管的自定义算法与 Apache Spark 结合使用 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

将 Amazon SageMaker 上的模型训练和托管的自定义算法与 Apache Spark 结合使用

示例 1:将针对训练和推理的 Amazon SageMaker 与 Apache Spark 结合使用 中,您使用 kMeansSageMakerEstimator,因为该示例使用 Amazon SageMaker 提供的 k-means 算法进行模型训练。您可以选择使用自己的自定义算法进行模型训练。假设您已经创建了 Docker 映像,则可以创建自己的 SageMakerEstimator 并为您的自定义映像指定 Amazon Elastic Container Registry 路径。

以下示例显示如何从 SageMakerEstimator 中创建 KMeansSageMakerEstimator。在新的评估程序中,您可以显式指定训练和推理代码图像的 Docker 注册表路径。

import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer val estimator = new SageMakerEstimator( trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", requestRowSerializer = new ProtobufRequestRowSerializer(), responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(), hyperParameters = Map("k" -> "10", "feature_dim" -> "784"), sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1, trainingSparkDataFormat = "sagemaker")

在该代码中,SageMakerEstimator 构造函数中的参数包括:

  • trainingImage - 标识包含自定义代码的训练映像的 Docker 注册表路径。

  • modelImage - 标识包含推理代码的映像的 Docker 注册表路径。

  • requestRowSerializer - 实施 com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer

    此参数序列化输入 DataFrame 中的行以将其发送到在 SageMaker 中托管的模型以用于推理。

  • responseRowDeserializer - 实施

    com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer.

    此参数将在 SageMaker 中托管的响应反序列化回 DataFrame

  • trainingSparkDataFormat - 指定 Spark 在将训练数据从 DataFrame 上传到 S3 时使用的数据格式。例如,"sagemaker" 用于 protobuf 格式,"csv" 用于逗号分隔值,"libsvm" 用于 LibSVM 格式。

您可以实施自己的 RequestRowSerializerResponseRowDeserializer 以序列化或反序列化您的推理代码支持的数据格式中的行,例如 .libsvm 或 .csv。