本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
SageMaker 斯卡拉的 Spark 示例
亚马逊 SageMaker 提供了一个 Apache Spark 库 (SageMakerSpark
下载 Spark for Scala
你可以从 Spark GitHub 存储库中下载 Python Spark (PySpark) 和 Scala 库的源代码和示例。SageMaker
有关安装 SageMaker Spark 库的详细说明,请参阅 SageMakerSpark
SageMaker Spark SDK for Scala 已在 Maven 中央存储库中找到。通过向 pom.xml
文件添加以下依赖项,将 Spark 库添加到项目中:
-
如果您的项目是使用 Maven 构建的,请将以下内容添加到您的 pom.xml 文件中:
<dependency> <groupId>com.amazonaws</groupId> <artifactId>sagemaker-spark_2.11</artifactId> <version>spark_2.2.0-1.0</version> </dependency>
-
如果你的项目依赖于 Spark 2.1,请将以下内容添加到你的 pom.xml 文件中:
<dependency> <groupId>com.amazonaws</groupId> <artifactId>sagemaker-spark_2.11</artifactId> <version>spark_2.1.1-1.0</version> </dependency>
Scala 的火花示例
本节提供的示例代码使用提供的 Apache Spark Scala 库 SageMaker 来训练模型在 Spark SageMaker 集DataFrame
群中使用 s。接下来是关于如何使用 Apache Spark 在亚马逊上使用自定义算法 SageMaker 进行模型训练和托管和的示例 SageMakerEstimator在 Spark 管道中使用。
以下示例使用 SageMaker 托管服务托管生成的模型工件。有关此示例的更多详细信息,请参阅入门:使用 SageMaker Spark 开启 K-Mean SageMaker s 聚类SDK
-
对数据使用
KMeansSageMakerEstimator
调整 (或训练) 模型由于该示例使用提供的 k-means 算法 SageMaker 来训练模型,因此您可以使用。
KMeansSageMakerEstimator
您可以使用手写的个位数的图像(来自MNIST数据集)来训练模型。您提供图像作为输入DataFrame
。为方便起见,请在 Amazon S3 存储桶中 SageMaker 提供此数据集。作为响应,评估程序返回一个
SageMakerModel
对象。 -
使用经过训练的
SageMakerModel
获取推理要从中托管的模型中获取推论 SageMaker,可以调用
SageMakerModel.transform
方法。您传递一个DataFrame
作为输入。该方法将输入DataFrame
转换为另一个包含从模型中获得的推理的DataFrame
。对于手写体单个位数数字的给定输入图像,该推理标识图像所属的聚类。有关更多信息,请参阅 K-Means 算法。
import org.apache.spark.sql.SparkSession import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.algorithms import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::
account-id
:role/rolename
" val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784) // train val model = estimator.fit(trainingData) val transformedData = model.transform(testData) transformedData.show
该示例代码执行以下操作:
-
将MNIST数据集从 SageMaker (
awsai-sparksdk-dataset
) 提供的 S3 存储桶加载到 SparkDataFrame
(mnistTrainingDataFrame
) 中:// Get a Spark session. val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::
account-id
:role/rolename
" trainingData.show()show
方法显示数据帧中的前 20 行:+-----+--------------------+ |label| features| +-----+--------------------+ | 5.0|(784,[152,153,154...| | 0.0|(784,[127,128,129...| | 4.0|(784,[160,161,162...| | 1.0|(784,[158,159,160...| | 9.0|(784,[208,209,210...| | 2.0|(784,[155,156,157...| | 1.0|(784,[124,125,126...| | 3.0|(784,[151,152,153...| | 1.0|(784,[152,153,154...| | 4.0|(784,[134,135,161...| | 3.0|(784,[123,124,125...| | 5.0|(784,[216,217,218...| | 3.0|(784,[143,144,145...| | 6.0|(784,[72,73,74,99...| | 1.0|(784,[151,152,153...| | 7.0|(784,[211,212,213...| | 2.0|(784,[151,152,153...| | 8.0|(784,[159,160,161...| | 6.0|(784,[100,101,102...| | 9.0|(784,[209,210,211...| +-----+--------------------+ only showing top 20 rows
在每行中:
-
label
列标识图像的标签。例如,如果手写数字的图像是数字 5,则标签值为 5。 -
features
列存储org.apache.spark.ml.linalg.Vector
值的向量 (Double
)。这是手写数字的 784 个特征。(每个手写数字都是一个 28x28 像素的图像,从而形成 784 个特征。)
-
-
创建 SageMaker 估算器 ()
KMeansSageMakerEstimator
该估计器的
fit
方法使用提供的 k 均值算法,使用输入 SageMaker 来训练模型。DataFrame
作为响应,它会返回一个SageMakerModel
对象,此对象可用于获取推理。注意
扩展了 SageMaker
SageMakerEstimator
,它KMeansSageMakerEstimator
扩展了 Apache SparkEstimator
。val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784)
构造函数参数提供了用于训练模型并将其部署到以下位置的信息 SageMaker:
-
trainingInstanceType
和trainingInstanceCount
- 标识要用于模型训练的机器学习计算实例的类型和数量。 -
endpointInstanceType
— 标识托管模型时要使用的 ML 计算实例类型。 SageMaker默认情况下,采用一个机器学习计算实例。 -
endpointInitialInstanceCount
— 标识最初支持托管模型的终端节点的 ML 计算实例的数量。 SageMaker -
sagemakerRole
— SageMaker 担任此IAM角色代表您执行任务。例如,对于模型训练,它从 S3 读取数据并将训练结果(模型构件)写入到 S3 中。注意
此示例隐式创建了一个 SageMaker 客户端。要创建此客户端,您必须提供凭证。API使用这些凭证对发出的请求进行身份验证 SageMaker。例如,它使用凭证来验证创建训练作业的请求和使用 SageMaker 托管服务部署模型的API呼叫。
-
创建
KMeansSageMakerEstimator
对象后,您可以设置以下参数,以便用在模型训练中:-
k-means 算法在模型训练过程中应创建的集群数量。您可以指定 10 个聚类,每个数字(0 到 9)一个。
-
标识每个输入图像都有 784 个特征 (每个手写数字都是 28x28 像素的图像,从而形成 784 个特征)。
-
-
-
调用评估程序
fit
方法// train val model = estimator.fit(trainingData)
您将输入
DataFrame
作为一个参数传递。模型完成了训练模型并将其部署到其中的所有工作 SageMaker。有关更多信息,请参阅将你的 Apache Spark 应用程序与 SageMaker。作为响应,你会得到一个SageMakerModel
对象,你可以用它从部署在中的模型中 SageMaker获得推论。您只提供输入
DataFrame
。您不需要指定用于模型训练的 k-means 算法的注册表路径,因为KMeansSageMakerEstimator
知道它。 -
调用
SageMakerModel.transform
方法以从中部署的模型中获取推论。 SageMakertransform
方法采用DataFrame
作为输入,传输它,并返回另一个包含从模型获得的推理的DataFrame
。val transformedData = model.transform(testData) transformedData.show
为简单起见,我们使用在本例中用于模型训练的相同
DataFrame
作为transform
方法的输入。transform
方法执行以下操作:-
将输入
DataFrame
中的features
列序列化为 protobuf 并将其发送到 SageMaker 端点进行推理。 -
将 protobuf 响应反序列化成转换的
distance_to_cluster
中的两个附加列 (closest_cluster
和DataFrame
)。
show
方法获取对输入DataFrame
中前 20 行的推理:+-----+--------------------+-------------------+---------------+ |label| features|distance_to_cluster|closest_cluster| +-----+--------------------+-------------------+---------------+ | 5.0|(784,[152,153,154...| 1767.897705078125| 4.0| | 0.0|(784,[127,128,129...| 1392.157470703125| 5.0| | 4.0|(784,[160,161,162...| 1671.5711669921875| 9.0| | 1.0|(784,[158,159,160...| 1182.6082763671875| 6.0| | 9.0|(784,[208,209,210...| 1390.4002685546875| 0.0| | 2.0|(784,[155,156,157...| 1713.988037109375| 1.0| | 1.0|(784,[124,125,126...| 1246.3016357421875| 2.0| | 3.0|(784,[151,152,153...| 1753.229248046875| 4.0| | 1.0|(784,[152,153,154...| 978.8394165039062| 2.0| | 4.0|(784,[134,135,161...| 1623.176513671875| 3.0| | 3.0|(784,[123,124,125...| 1533.863525390625| 4.0| | 5.0|(784,[216,217,218...| 1469.357177734375| 6.0| | 3.0|(784,[143,144,145...| 1736.765869140625| 4.0| | 6.0|(784,[72,73,74,99...| 1473.69384765625| 8.0| | 1.0|(784,[151,152,153...| 944.88720703125| 2.0| | 7.0|(784,[211,212,213...| 1285.9071044921875| 3.0| | 2.0|(784,[151,152,153...| 1635.0125732421875| 1.0| | 8.0|(784,[159,160,161...| 1436.3162841796875| 6.0| | 6.0|(784,[100,101,102...| 1499.7366943359375| 7.0| | 9.0|(784,[209,210,211...| 1364.6319580078125| 6.0| +-----+--------------------+-------------------+---------------+
您可以如下所示解释数据:
-
具有
label
5 的手写数字属于集群 4 (closest_cluster
)。 -
具有
label
0 的手写数字属于集群 5。 -
具有
label
4 的手写数字属于集群 9。 -
具有
label
1 的手写数字属于集群 6。
-