SageMaker AI Spark for Scala examples
Amazon SageMaker AI provides an Apache Spark library (SageMaker AI Spark) that you can use to integrate your Apache Spark applications with SageMaker AI.
Download Spark for Scala
You can download the source code and examples for both the Python Spark (PySpark) and Scala libraries from the SageMaker AI Spark GitHub repository.
For detailed instructions on installing the SageMaker AI Spark library, see SageMaker AI Spark.
The SageMaker AI Spark SDK for Scala is available in the Maven Central repository. To add the Spark library to your project, add one of the following dependencies to your pom.xml file:
- If your project depends on Spark 2.2, add the following to your pom.xml file:

    <dependency>
      <groupId>com.amazonaws</groupId>
      <artifactId>sagemaker-spark_2.11</artifactId>
      <version>spark_2.2.0-1.0</version>
    </dependency>

- If your project depends on Spark 2.1, add the following to your pom.xml file:

    <dependency>
      <groupId>com.amazonaws</groupId>
      <artifactId>sagemaker-spark_2.11</artifactId>
      <version>spark_2.1.1-1.0</version>
    </dependency>
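If you build with sbt rather than Maven, the same coordinates can go in build.sbt. The following is a minimal sketch that only restates the Maven coordinates above; the scalaVersion value is an assumption (pick a Scala 2.11 release to match the _2.11 artifact suffix), and % is used instead of %% because the artifact ID already carries that suffix.

```scala
// build.sbt (sketch): same coordinates as the Spark 2.2 pom.xml entry above.
// Assumption: any Scala 2.11.x release matches the _2.11 artifact suffix.
scalaVersion := "2.11.12"

// For Spark 2.1, use the version "spark_2.1.1-1.0" instead.
libraryDependencies += "com.amazonaws" % "sagemaker-spark_2.11" % "spark_2.2.0-1.0"
```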
Spark for Scala example
This section provides example code that uses the Apache Spark Scala library provided by SageMaker AI to train a model in SageMaker AI using DataFrames in your Spark cluster. It is followed by examples that show how to Use Custom Algorithms for Model Training and Hosting on Amazon SageMaker AI with Apache Spark and how to Use the SageMakerEstimator in a Spark Pipeline.
The following example hosts the resulting model artifacts using SageMaker AI hosting services. For more details on this example, see Getting Started: K-Means Clustering on SageMaker AI with SageMaker AI Spark SDK. Specifically, this example does the following:
- Uses the KMeansSageMakerEstimator to fit (or train) a model on data

  Because the example uses the k-means algorithm provided by SageMaker AI to train a model, you use the KMeansSageMakerEstimator. You train the model using images of handwritten single-digit numbers (from the MNIST dataset), which you provide as an input DataFrame. For your convenience, SageMaker AI provides this dataset in an Amazon S3 bucket.

  In response, the estimator returns a SageMakerModel object.

- Obtains inferences using the trained SageMakerModel

  To get inferences from a model hosted in SageMaker AI, you call the SageMakerModel.transform method and pass a DataFrame as input. The method transforms the input DataFrame into another DataFrame that contains the inferences obtained from the model.

  For a given input image of a handwritten single-digit number, the inference identifies the cluster that the image belongs to. For more information, see K-Means Algorithm.
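The cluster assignment described in the second item above can be illustrated with a small, self-contained sketch. It does not call SageMaker AI; it assumes that inference assigns each input vector to the centroid with the smallest Euclidean distance, which is the conventional k-means definition of cluster membership. The object name NearestCentroid, its methods, and the toy two-dimensional centroids are hypothetical. The real model in this example has 10 centroids of 784 values each, and the returned pair corresponds only conceptually to the closest_cluster and distance_to_cluster values that transform produces.

```scala
// Sketch of k-means cluster assignment, assuming Euclidean distance.
// Not the SageMaker AI implementation; for illustration only.
object NearestCentroid {

  // Euclidean distance between two vectors of equal length.
  def euclidean(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "vectors must have the same dimension")
    val squaredDiffs = a.indices.map { i =>
      val d = a(i) - b(i)
      d * d
    }
    math.sqrt(squaredDiffs.sum)
  }

  // Returns (index of the closest centroid, distance to that centroid).
  // On ties, the centroid with the smaller index wins.
  def assign(point: Array[Double], centroids: Seq[Array[Double]]): (Int, Double) = {
    require(centroids.nonEmpty, "at least one centroid is required")
    centroids.zipWithIndex
      .map { case (centroid, idx) => (idx, euclidean(point, centroid)) }
      .reduce((best, next) => if (next._2 < best._2) next else best)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical 2-D centroids; the MNIST model would have 10 centroids of 784 values.
    val centroids = Seq(Array(0.0, 0.0), Array(10.0, 10.0), Array(0.0, 10.0))
    val (cluster, distance) = assign(Array(1.0, 9.0), centroids)
    println(f"closest_cluster=$cluster distance_to_cluster=$distance%.4f")
  }
}
```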
    import org.apache.spark.sql.SparkSession
    import com.amazonaws.services.sagemaker.sparksdk.IAMRole
    import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator

    // Get a Spark session.
    val spark = SparkSession.builder.getOrCreate()

    // Load the MNIST training and test data as DataFrames from libsvm files.
    val region = "us-east-1"
    val trainingData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
    val testData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

    // Replace with the ARN of an IAM role that SageMaker AI can assume.
    val roleArn = "arn:aws:iam::account-id:role/rolename"

    // Configure the k-means estimator: 10 clusters, 784 features per image.
    val estimator = new KMeansSageMakerEstimator(
      sagemakerRole = IAMRole(roleArn),
      trainingInstanceType = "ml.p2.xlarge",
      trainingInstanceCount = 1,
      endpointInstanceType = "ml.c4.xlarge",
      endpointInitialInstanceCount = 1)
      .setK(10)
      .setFeatureDim(784)

    // Train the model and deploy it to a SageMaker AI endpoint.
    val model = estimator.fit(trainingData)

    // Get inferences from the endpoint for the test data and display them.
    val transformedData = model.transform(testData)
    transformedData.show()
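The code above reads the MNIST data with Spark's libsvm data source. In the libsvm text format, each line holds a label followed by index:value pairs for the nonzero features, using one-based indices; Spark converts them to zero-based indices in a sparse vector whose size is numFeatures, which is why the features column is displayed in the (784,[indices],[values]) form. The following plain-Scala sketch mimics that conversion for a single line. The LibsvmLine object, the SparseRow type, and the sample line with its values are hypothetical, and the pixelIndex helper assumes each 28 x 28 image is flattened row by row.

```scala
// Plain-Scala sketch of how one libsvm line maps to a (label, sparse vector) row.
// Hypothetical helper for illustration only; Spark's libsvm reader does this for you.
object LibsvmLine {

  // Sparse representation: vector size, zero-based indices, and matching values.
  final case class SparseRow(label: Double, size: Int, indices: Seq[Int], values: Seq[Double])

  // Parses "label idx:value idx:value ..." where each idx is one-based.
  def parse(line: String, numFeatures: Int): SparseRow = {
    val tokens = line.trim.split("\\s+").toSeq
    val label = tokens.head.toDouble
    val pairs = tokens.tail.map { token =>
      val parts = token.split(":")
      // Convert the one-based libsvm index to a zero-based vector index.
      (parts(0).toInt - 1, parts(1).toDouble)
    }
    require(pairs.forall { case (i, _) => i >= 0 && i < numFeatures },
      s"feature index outside [0, $numFeatures)")
    SparseRow(label, numFeatures, pairs.map(_._1), pairs.map(_._2))
  }

  // Assumes row-major flattening of a width x width image (28 x 28 for MNIST).
  def pixelIndex(row: Int, col: Int, width: Int = 28): Int = row * width + col

  def main(args: Array[String]): Unit = {
    // Made-up line: label 5 with three nonzero pixels (values are illustrative only).
    val row = parse("5 153:0.5 154:0.75 155:0.25", numFeatures = 28 * 28)
    println(s"${row.label} -> (${row.size},${row.indices.mkString("[", ",", "]")},${row.values.mkString("[", ",", "]")})")
    println(s"last pixel index: ${pixelIndex(27, 27)}") // 27 * 28 + 27 = 783
  }
}
```

With numFeatures set to 28 * 28, the largest valid zero-based index is 783, which is the last pixel of the image under the row-major assumption.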
The example code does the following:
- Loads the MNIST dataset from the S3 bucket that SageMaker AI provides (sagemaker-sample-data-region, which is sagemaker-sample-data-us-east-1 in this example) into the Spark DataFrames trainingData and testData:

    // Get a Spark session.
    val spark = SparkSession.builder.getOrCreate()

    // Load the MNIST training and test data as DataFrames from libsvm files.
    val region = "us-east-1"
    val trainingData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
    val testData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

    // IAM role used by the estimator in the next step (replace with your own role ARN).
    val roleArn = "arn:aws:iam::account-id:role/rolename"

    // Display the first rows of the training data.
    trainingData.show()

  The show method displays the first 20 rows in the DataFrame:

    +-----+--------------------+
    |label|            features|
    +-----+--------------------+
    |  5.0|(784,[152,153,154...|
    |  0.0|(784,[127,128,129...|
    |  4.0|(784,[160,161,162...|
    |  1.0|(784,[158,159,160...|
    |  9.0|(784,[208,209,210...|
    |  2.0|(784,[155,156,157...|
    |  1.0|(784,[124,125,126...|
    |  3.0|(784,[151,152,153...|
    |  1.0|(784,[152,153,154...|
    |  4.0|(784,[134,135,161...|
    |  3.0|(784,[123,124,125...|
    |  5.0|(784,[216,217,218...|
    |  3.0|(784,[143,144,145...|
    |  6.0|(784,[72,73,74,99...|
    |  1.0|(784,[151,152,153...|
    |  7.0|(784,[211,212,213...|
    |  2.0|(784,[151,152,153...|
    |  8.0|(784,[159,160,161...|
    |  6.0|(784,[100,101,102...|
    |  9.0|(784,[209,210,211...|
    +-----+--------------------+
    only showing top 20 rows

  In each row:
  - The label column identifies the image's label. For example, if the image of the handwritten number is the digit 5, the label value is 5.
  - The features column stores a vector (org.apache.spark.ml.linalg.Vector) of Double values. These are the 784 features of the handwritten number. (Each handwritten number is a 28 x 28-pixel image, making 784 features.)
- Creates a SageMaker AI estimator (KMeansSageMakerEstimator)

  The fit method of this estimator uses the k-means algorithm provided by SageMaker AI to train models using an input DataFrame. In response, it returns a SageMakerModel object that you can use to get inferences.

  Note: The KMeansSageMakerEstimator extends the SageMaker AI SageMakerEstimator, which extends the Apache Spark Estimator.

    val estimator = new KMeansSageMakerEstimator(
      sagemakerRole = IAMRole(roleArn),
      trainingInstanceType = "ml.p2.xlarge",
      trainingInstanceCount = 1,
      endpointInstanceType = "ml.c4.xlarge",
      endpointInitialInstanceCount = 1)
      .setK(10)
      .setFeatureDim(784)

  The constructor parameters provide information that is used for training a model and deploying it on SageMaker AI:
  - trainingInstanceType and trainingInstanceCount: Identify the type and number of ML compute instances to use for model training.
  - endpointInstanceType: Identifies the ML compute instance type to use when hosting the model in SageMaker AI.
  - endpointInitialInstanceCount: Identifies the number of ML compute instances initially backing the endpoint that hosts the model in SageMaker AI. By default, one ML compute instance is assumed.
  - sagemakerRole: SageMaker AI assumes this IAM role to perform tasks on your behalf. For example, for model training, it reads data from S3 and writes training results (model artifacts) to S3.

  Note: This example implicitly creates a SageMaker AI client. To create this client, you must provide your credentials. The API uses these credentials to authenticate requests to SageMaker AI, for example, requests to create a training job and the API calls that deploy the model using SageMaker AI hosting services.
  After the KMeansSageMakerEstimator object has been created, you set the following parameters, which are used in model training:
  - setK(10): The number of clusters that the k-means algorithm should create during model training. You specify 10 clusters, one for each digit, 0 through 9.
  - setFeatureDim(784): Identifies that each input image has 784 features (each handwritten number is a 28 x 28-pixel image, making 784 features).
- Calls the estimator fit method

    // Train the model and deploy it to SageMaker AI.
    val model = estimator.fit(trainingData)

  You pass the input DataFrame as a parameter. The fit method does all the work of training the model and deploying it to SageMaker AI. For more information, see Integrate your Apache Spark application with SageMaker AI. In response, you get a SageMakerModel object, which you can use to get inferences from your model deployed in SageMaker AI.

  You provide only the input DataFrame. You don't need to specify the registry path to the k-means algorithm used for model training because the KMeansSageMakerEstimator knows it.
- Calls the SageMakerModel.transform method to get inferences from the model deployed in SageMaker AI

  The transform method takes a DataFrame as input, transforms it, and returns another DataFrame containing inferences obtained from the model.

    val transformedData = model.transform(testData)
    transformedData.show()

  This example passes the test data (testData), not the training data, to the transform method. Note that the sample output that follows shows the same rows as the earlier trainingData output, so it reflects transforming the training data. The transform method does the following:
  - Serializes the features column in the input DataFrame to protobuf and sends it to the SageMaker AI endpoint for inference.
  - Deserializes the protobuf response into two additional columns (distance_to_cluster and closest_cluster) in the transformed DataFrame.
  The show method displays the first 20 rows of the transformed DataFrame, including the inferences:

    +-----+--------------------+-------------------+---------------+
    |label|            features|distance_to_cluster|closest_cluster|
    +-----+--------------------+-------------------+---------------+
    |  5.0|(784,[152,153,154...|  1767.897705078125|            4.0|
    |  0.0|(784,[127,128,129...|  1392.157470703125|            5.0|
    |  4.0|(784,[160,161,162...| 1671.5711669921875|            9.0|
    |  1.0|(784,[158,159,160...| 1182.6082763671875|            6.0|
    |  9.0|(784,[208,209,210...| 1390.4002685546875|            0.0|
    |  2.0|(784,[155,156,157...|  1713.988037109375|            1.0|
    |  1.0|(784,[124,125,126...| 1246.3016357421875|            2.0|
    |  3.0|(784,[151,152,153...|  1753.229248046875|            4.0|
    |  1.0|(784,[152,153,154...|  978.8394165039062|            2.0|
    |  4.0|(784,[134,135,161...|  1623.176513671875|            3.0|
    |  3.0|(784,[123,124,125...|  1533.863525390625|            4.0|
    |  5.0|(784,[216,217,218...|  1469.357177734375|            6.0|
    |  3.0|(784,[143,144,145...|  1736.765869140625|            4.0|
    |  6.0|(784,[72,73,74,99...|   1473.69384765625|            8.0|
    |  1.0|(784,[151,152,153...|    944.88720703125|            2.0|
    |  7.0|(784,[211,212,213...| 1285.9071044921875|            3.0|
    |  2.0|(784,[151,152,153...| 1635.0125732421875|            1.0|
    |  8.0|(784,[159,160,161...| 1436.3162841796875|            6.0|
    |  6.0|(784,[100,101,102...| 1499.7366943359375|            7.0|
    |  9.0|(784,[209,210,211...| 1364.6319580078125|            6.0|
    +-----+--------------------+-------------------+---------------+

  You can interpret the data as follows:
  - A handwritten number with the label 5 belongs to cluster 4 (closest_cluster).
  - A handwritten number with the label 0 belongs to cluster 5.
  - A handwritten number with the label 4 belongs to cluster 9.
  - A handwritten number with the label 1 belongs to cluster 6.
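Each interpretation above reads a single row. Because k-means is unsupervised, cluster IDs are arbitrary and need not match digit values, so it can help to tally (label, closest_cluster) pairs and see which cluster each label lands in most often. The following plain-Scala sketch does this for the 20 sample rows shown above, with both columns converted to Int. The ClusterTally object and its methods are hypothetical, and with so few rows (and several ties, broken here in favor of the smaller cluster ID) the result is illustrative only. On a full DataFrame you would more likely aggregate in Spark, for example with transformedData.groupBy("label", "closest_cluster").count().

```scala
// Hypothetical helper that tallies (label, closest_cluster) pairs in plain Scala.
object ClusterTally {

  // The 20 sample rows from the transformed DataFrame above, as (label, closest_cluster).
  val sample: Seq[(Int, Int)] = Seq(
    5 -> 4, 0 -> 5, 4 -> 9, 1 -> 6, 9 -> 0,
    2 -> 1, 1 -> 2, 3 -> 4, 1 -> 2, 4 -> 3,
    3 -> 4, 5 -> 6, 3 -> 4, 6 -> 8, 1 -> 2,
    7 -> 3, 2 -> 1, 8 -> 6, 6 -> 7, 9 -> 6)

  // For each label, count how many rows were assigned to each cluster.
  def counts(rows: Seq[(Int, Int)]): Map[Int, Map[Int, Int]] =
    rows.groupBy(_._1).map { case (label, labelRows) =>
      label -> labelRows.groupBy(_._2).map { case (cluster, hits) => cluster -> hits.size }
    }

  // Most frequent cluster per label; ties go to the smaller cluster ID.
  def majorityCluster(rows: Seq[(Int, Int)]): Map[Int, Int] =
    counts(rows).map { case (label, byCluster) =>
      label -> byCluster.toSeq.minBy { case (cluster, n) => (-n, cluster) }._1
    }

  def main(args: Array[String]): Unit =
    majorityCluster(sample).toSeq.sorted.foreach { case (label, cluster) =>
      println(s"label $label -> cluster $cluster")
    }
}
```

In these 20 rows, label 1 is assigned to cluster 2 three times out of four, so the single-row reading "label 1 belongs to cluster 6" is not necessarily the typical assignment for that digit.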