示例 1:将针对训练和推理的 Amazon SageMaker 与 Apache Spark 结合使用 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

示例 1:将针对训练和推理的 Amazon SageMaker 与 Apache Spark 结合使用

Amazon SageMaker 提供一个 Apache Spark 库(在 Python 和 Scala 中),您可以使用它将您的 Apache Spark 应用程序与 SageMaker 集成。例如,您可以使用 Apache Spark 进行数据预处理,使用 SageMaker 进行模型训练和托管。有关更多信息,请参阅 将 Apache Spark 与 Amazon SageMaker 结合使用。本节提供的示例代码使用 SageMaker 提供的 Apache Spark Scala 库,用以在您的 Spark 集群中使用 DataFrame 来在 SageMaker 中训练模型。此示例还使用 SageMaker 托管服务来托管生成的模型构件。具体来说,此示例执行以下操作:

  • 对数据使用 KMeansSageMakerEstimator 调整 (或训练) 模型

     

    由于示例使用 SageMaker 提供的 k-means 算法来训练模型,因此您可以使用 KMeansSageMakerEstimator。您使用手写体单个位数数字的图像 (来自 MNIST 数据集) 来训练模型。您提供图像作为输入 DataFrame。为方便起见,SageMaker 在 S3 存储桶中提供此数据集。

     

    作为响应,评估程序返回一个 SageMakerModel 对象。

     

  • 使用经过训练的 SageMakerModel 获取推理

     

    要从在 SageMaker 中托管的模型获取推理,您可以调用 SageMakerModel.transform 方法。您传递一个 DataFrame 作为输入。该方法将输入 DataFrame 转换为另一个包含从模型中获得的推理的 DataFrame

     

    对于手写体单个位数数字的给定输入图像,该推理标识图像所属的聚类。有关更多信息,请参阅 K-Means 算法

这是示例代码:

import org.apache.spark.sql.SparkSession import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.algorithms import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::account-id:role/rolename" val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784) // train val model = estimator.fit(trainingData) val transformedData = model.transform(testData) transformedData.show

该代码执行以下操作:

  • 将 SageMaker 提供的 S3 存储桶中的 MNIST 数据集 (awsai-sparksdk-dataset) 加载到 Spark DataFrame (mnistTrainingDataFrame) 中:

    // Get a Spark session. val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::account-id:role/rolename" trainingData.show()

    show 方法显示数据帧中的前 20 行:

    +-----+--------------------+ |label| features| +-----+--------------------+ | 5.0|(784,[152,153,154...| | 0.0|(784,[127,128,129...| | 4.0|(784,[160,161,162...| | 1.0|(784,[158,159,160...| | 9.0|(784,[208,209,210...| | 2.0|(784,[155,156,157...| | 1.0|(784,[124,125,126...| | 3.0|(784,[151,152,153...| | 1.0|(784,[152,153,154...| | 4.0|(784,[134,135,161...| | 3.0|(784,[123,124,125...| | 5.0|(784,[216,217,218...| | 3.0|(784,[143,144,145...| | 6.0|(784,[72,73,74,99...| | 1.0|(784,[151,152,153...| | 7.0|(784,[211,212,213...| | 2.0|(784,[151,152,153...| | 8.0|(784,[159,160,161...| | 6.0|(784,[100,101,102...| | 9.0|(784,[209,210,211...| +-----+--------------------+ only showing top 20 rows

    在每行中:

    • label 列标识图像的标签。例如,如果手写数字的图像是数字 5,则标签值为 5。

    • features 列存储 org.apache.spark.ml.linalg.Vector 值的向量 (Double)。这是手写数字的 784 个特征。(每个手写数字都是一个 28x28 像素的图像,从而形成 784 个特征。)

     

  • 创建 SageMaker 估算器 (KMeansSageMakerEstimator)

    此估算器的 fit 方法通过 SageMaker 提供的 k-means 算法,使用输入 DataFrame 来训练模型。作为响应,它会返回一个 SageMakerModel 对象,此对象可用于获取推理。

    注意

    KMeansSageMakerEstimator 扩展 SageMaker SageMakerEstimator,而后者会扩展 Apache Spark Estimator

    val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784)

    构造函数参数提供用于训练模型并将其部署到 SageMaker 的信息:

    • trainingInstanceTypetrainingInstanceCount - 标识要用于模型训练的机器学习计算实例的类型和数量。

       

    • endpointInstanceType - 标识在 SageMaker 中托管模型时要使用的机器学习计算实例类型。默认情况下,采用一个机器学习计算实例。

       

    • endpointInitialInstanceCount - 标识最初支持在 SageMaker 中托管模型的端点的机器学习计算实例数量。

       

    • sagemakerRole - SageMaker 代入此 IAM 角色代表您执行任务。例如,对于模型训练,它从 S3 读取数据并将训练结果(模型构件)写入到 S3 中。

      注意

      此示例隐式创建一个 SageMaker 客户端。要创建此客户端,您必须提供凭证。API 使用这些凭证来对 SageMaker 的请求进行身份验证。例如,它使用凭证对请求进行身份验证,用于创建训练作业和 API 调用,以使用 SageMaker 托管服务部署模型。

    • 创建 KMeansSageMakerEstimator 对象后,您可以设置以下参数,以便用在模型训练中:

      • k-means 算法在模型训练过程中应创建的集群数量。您可以指定 10 个聚类,每个数字(0 到 9)一个。

      • 标识每个输入图像都有 784 个特征 (每个手写数字都是 28x28 像素的图像,从而形成 784 个特征)。

       

  • 调用评估程序 fit 方法

    // train val model = estimator.fit(trainingData)

    您将输入 DataFrame 作为一个参数传递。模型进行了训练模型的所有工作,并将其部署到 SageMaker。有关更多信息,请参阅将您的 Apache Spark 应用程序与 SageMaker 集成。作为响应,您将获得一个 SageMakerModel 对象,您可以使用它从您在 SageMaker 中部署的模型获取推理。

     

    您只提供输入 DataFrame。您不需要指定用于模型训练的 k-means 算法的注册表路径,因为 KMeansSageMakerEstimator 知道它。

     

  • 调用 SageMakerModel.transform 方法以从在 SageMaker 中部署的模型获取推理。

    transform 方法采用 DataFrame 作为输入,传输它,并返回另一个包含从模型获得的推理的 DataFrame

    val transformedData = model.transform(testData) transformedData.show

    为简单起见,我们使用在本例中用于模型训练的相同 DataFrame 作为 transform 方法的输入。transform 方法执行以下操作:

    • 将输入 DataFrame 中的 features 列序列化成 protobuf 并将其发送到 SageMaker 端点以获取推理。

    • 将 protobuf 响应反序列化成转换的 distance_to_cluster 中的两个附加列 (closest_clusterDataFrame)。

    show 方法获取对输入 DataFrame 中前 20 行的推理:

    +-----+--------------------+-------------------+---------------+ |label| features|distance_to_cluster|closest_cluster| +-----+--------------------+-------------------+---------------+ | 5.0|(784,[152,153,154...| 1767.897705078125| 4.0| | 0.0|(784,[127,128,129...| 1392.157470703125| 5.0| | 4.0|(784,[160,161,162...| 1671.5711669921875| 9.0| | 1.0|(784,[158,159,160...| 1182.6082763671875| 6.0| | 9.0|(784,[208,209,210...| 1390.4002685546875| 0.0| | 2.0|(784,[155,156,157...| 1713.988037109375| 1.0| | 1.0|(784,[124,125,126...| 1246.3016357421875| 2.0| | 3.0|(784,[151,152,153...| 1753.229248046875| 4.0| | 1.0|(784,[152,153,154...| 978.8394165039062| 2.0| | 4.0|(784,[134,135,161...| 1623.176513671875| 3.0| | 3.0|(784,[123,124,125...| 1533.863525390625| 4.0| | 5.0|(784,[216,217,218...| 1469.357177734375| 6.0| | 3.0|(784,[143,144,145...| 1736.765869140625| 4.0| | 6.0|(784,[72,73,74,99...| 1473.69384765625| 8.0| | 1.0|(784,[151,152,153...| 944.88720703125| 2.0| | 7.0|(784,[211,212,213...| 1285.9071044921875| 3.0| | 2.0|(784,[151,152,153...| 1635.0125732421875| 1.0| | 8.0|(784,[159,160,161...| 1436.3162841796875| 6.0| | 6.0|(784,[100,101,102...| 1499.7366943359375| 7.0| | 9.0|(784,[209,210,211...| 1364.6319580078125| 6.0| +-----+--------------------+-------------------+---------------+

    您可以如下所示解释数据:

    • 具有 label 5 的手写数字属于集群 4 (closest_cluster)。

    • 具有 label 0 的手写数字属于集群 5。

    • 具有 label 4 的手写数字属于集群 9。

    • 具有 label 1 的手写数字属于集群 6。

有关如何运行这些示例的更多信息,请参阅 GitHub 上的 https://github.com/aws/sagemaker-spark/blob/master/README.md