SageMaker examples using SDK for Java 2.x - Amazon SDK for Java 2.x
Services or capabilities described in Amazon Web Services documentation might vary by Region. To see the differences applicable to the China Regions, see Getting Started with Amazon Web Services in China (PDF).

SageMaker examples using SDK for Java 2.x

The following code examples show you how to perform actions and implement common scenarios by using the Amazon SDK for Java 2.x with SageMaker.

Actions are code excerpts from larger programs and must be run in context. While actions show you how to call individual service functions, you can see actions in context in their related scenarios and cross-service examples.

Scenarios are code examples that show you how to accomplish a specific task by calling multiple functions within the same service.

Each example includes a link to GitHub, where you can find instructions on how to set up and run the code in context.

Get started

The following code examples show how to get started using SageMaker.

SDK for Java 2.x
Note

There's more on GitHub. Find the complete example and learn how to set up and run in the Amazon Code Examples Repository.

/** * Before running this Java V2 code example, set up your development * environment, including your credentials. * * For more information, see the following documentation topic: * * https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/get-started.html */ public class HelloSageMaker { public static void main(String[] args) { Region region = Region.US_WEST_2; SageMakerClient sageMakerClient = SageMakerClient.builder() .region(region) .build(); listBooks(sageMakerClient); sageMakerClient.close(); } public static void listBooks(SageMakerClient sageMakerClient) { try { ListNotebookInstancesResponse notebookInstancesResponse = sageMakerClient.listNotebookInstances(); List<NotebookInstanceSummary> items = notebookInstancesResponse.notebookInstances(); for (NotebookInstanceSummary item : items) { System.out.println("The notebook name is: " + item.notebookInstanceName()); } } catch (SageMakerException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } }

Actions

The following code example shows how to use CreatePipeline.

SDK for Java 2.x
Note

There's more on GitHub. Find the complete example and learn how to set up and run in the Amazon Code Examples Repository.

// Create a pipeline from the example pipeline JSON. public static void setupPipeline(SageMakerClient sageMakerClient, String filePath, String roleArn, String functionArn, String pipelineName) { System.out.println("Setting up the pipeline."); JSONParser parser = new JSONParser(); // Read JSON and get pipeline definition. try (FileReader reader = new FileReader(filePath)) { Object obj = parser.parse(reader); JSONObject jsonObject = (JSONObject) obj; JSONArray stepsArray = (JSONArray) jsonObject.get("Steps"); for (Object stepObj : stepsArray) { JSONObject step = (JSONObject) stepObj; if (step.containsKey("FunctionArn")) { step.put("FunctionArn", functionArn); } } System.out.println(jsonObject); // Create the pipeline. CreatePipelineRequest pipelineRequest = CreatePipelineRequest.builder() .pipelineDescription("Java SDK example pipeline") .roleArn(roleArn) .pipelineName(pipelineName) .pipelineDefinition(jsonObject.toString()) .build(); sageMakerClient.createPipeline(pipelineRequest); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } catch (IOException | ParseException e) { throw new RuntimeException(e); } }
  • For API details, see CreatePipeline in Amazon SDK for Java 2.x API Reference.

The following code example shows how to use DeletePipeline.

SDK for Java 2.x
Note

There's more on GitHub. Find the complete example and learn how to set up and run in the Amazon Code Examples Repository.

// Delete a SageMaker pipeline by name. public static void deletePipeline(SageMakerClient sageMakerClient, String pipelineName) { DeletePipelineRequest pipelineRequest = DeletePipelineRequest.builder() .pipelineName(pipelineName) .build(); sageMakerClient.deletePipeline(pipelineRequest); System.out.println("*** Successfully deleted " + pipelineName); }
  • For API details, see DeletePipeline in Amazon SDK for Java 2.x API Reference.

The following code example shows how to use DescribePipelineExecution.

SDK for Java 2.x
Note

There's more on GitHub. Find the complete example and learn how to set up and run in the Amazon Code Examples Repository.

// Check the status of a pipeline execution. public static void waitForPipelineExecution(SageMakerClient sageMakerClient, String executionArn) throws InterruptedException { String status; int index = 0; do { DescribePipelineExecutionRequest pipelineExecutionRequest = DescribePipelineExecutionRequest.builder() .pipelineExecutionArn(executionArn) .build(); DescribePipelineExecutionResponse response = sageMakerClient .describePipelineExecution(pipelineExecutionRequest); status = response.pipelineExecutionStatusAsString(); System.out.println(index + ". The Status of the pipeline is " + status); TimeUnit.SECONDS.sleep(4); index++; } while ("Executing".equals(status)); System.out.println("Pipeline finished with status " + status); }

The following code example shows how to use StartPipelineExecution.

SDK for Java 2.x
Note

There's more on GitHub. Find the complete example and learn how to set up and run in the Amazon Code Examples Repository.

// Start a pipeline run with job configurations. public static String executePipeline(SageMakerClient sageMakerClient, String bucketName, String queueUrl, String roleArn, String pipelineName) { System.out.println("Starting pipeline execution."); String inputBucketLocation = "s3://" + bucketName + "/samplefiles/latlongtest.csv"; String output = "s3://" + bucketName + "/outputfiles/"; Gson gson = new GsonBuilder() .setFieldNamingPolicy(FieldNamingPolicy.UPPER_CAMEL_CASE) .setPrettyPrinting().create(); // Set up all parameters required to start the pipeline. List<Parameter> parameters = new ArrayList<>(); Parameter para1 = Parameter.builder() .name("parameter_execution_role") .value(roleArn) .build(); Parameter para2 = Parameter.builder() .name("parameter_queue_url") .value(queueUrl) .build(); String inputJSON = "{\n" + " \"DataSourceConfig\": {\n" + " \"S3Data\": {\n" + " \"S3Uri\": \"s3://" + bucketName + "/samplefiles/latlongtest.csv\"\n" + " },\n" + " \"Type\": \"S3_DATA\"\n" + " },\n" + " \"DocumentType\": \"CSV\"\n" + "}"; System.out.println(inputJSON); Parameter para3 = Parameter.builder() .name("parameter_vej_input_config") .value(inputJSON) .build(); // Create an ExportVectorEnrichmentJobOutputConfig object. VectorEnrichmentJobS3Data jobS3Data = VectorEnrichmentJobS3Data.builder() .s3Uri(output) .build(); ExportVectorEnrichmentJobOutputConfig outputConfig = ExportVectorEnrichmentJobOutputConfig.builder() .s3Data(jobS3Data) .build(); String gson4 = gson.toJson(outputConfig); Parameter para4 = Parameter.builder() .name("parameter_vej_export_config") .value(gson4) .build(); System.out.println("parameter_vej_export_config:" + gson.toJson(outputConfig)); // Create a VectorEnrichmentJobConfig object. ReverseGeocodingConfig reverseGeocodingConfig = ReverseGeocodingConfig.builder() .xAttributeName("Longitude") .yAttributeName("Latitude") .build(); VectorEnrichmentJobConfig jobConfig = VectorEnrichmentJobConfig.builder() .reverseGeocodingConfig(reverseGeocodingConfig) .build(); String para5JSON = "{\"MapMatchingConfig\":null,\"ReverseGeocodingConfig\":{\"XAttributeName\":\"Longitude\",\"YAttributeName\":\"Latitude\"}}"; Parameter para5 = Parameter.builder() .name("parameter_step_1_vej_config") .value(para5JSON) .build(); System.out.println("parameter_step_1_vej_config:" + gson.toJson(jobConfig)); parameters.add(para1); parameters.add(para2); parameters.add(para3); parameters.add(para4); parameters.add(para5); StartPipelineExecutionRequest pipelineExecutionRequest = StartPipelineExecutionRequest.builder() .pipelineExecutionDescription("Created using Java SDK") .pipelineExecutionDisplayName(pipelineName + "-example-execution") .pipelineParameters(parameters) .pipelineName(pipelineName) .build(); StartPipelineExecutionResponse response = sageMakerClient.startPipelineExecution(pipelineExecutionRequest); return response.pipelineExecutionArn(); }

Scenarios

The following code example shows how to:

  • Set up resources for a pipeline.

  • Set up a pipeline that executes a geospatial job.

  • Start a pipeline execution.

  • Monitor the status of the execution.

  • View the output of the pipeline.

  • Clean up resources.

For more information, see Create and run SageMaker pipelines using Amazon SDKs on Community.aws.

SDK for Java 2.x
Note

There's more on GitHub. Find the complete example and learn how to set up and run in the Amazon Code Examples Repository.

public class SagemakerWorkflow { public static final String DASHES = new String(new char[80]).replace("\0", "-"); private static String eventSourceMapping = ""; public static void main(String[] args) throws InterruptedException { final String usage = "\n" + "Usage:\n" + " <sageMakerRoleName> <lambdaRoleName> <functionFileLocation> <functionName> <queueName> <bucketName> <lnglatData> <spatialPipelinePath> <pipelineName>\n\n" + "Where:\n" + " sageMakerRoleName - The name of the Amazon SageMaker role.\n\n" + " lambdaRoleName - The name of the AWS Lambda role.\n\n" + " functionFileLocation - The file location where the JAR file that represents the AWS Lambda function is located.\n\n" + " functionName - The name of the AWS Lambda function (for example,SageMakerExampleFunction).\n\n" + " queueName - The name of the Amazon Simple Queue Service (Amazon SQS) queue.\n\n" + " bucketName - The name of the Amazon Simple Storage Service (Amazon S3) bucket.\n\n" + " lnglatData - The file location of the latlongtest.csv file required for this use case.\n\n" + " spatialPipelinePath - The file location of the GeoSpatialPipeline.json file required for this use case.\n\n" + " pipelineName - The name of the pipeline to create (for example, sagemaker-sdk-example-pipeline).\n\n"; if (args.length != 9) { System.out.println(usage); System.exit(1); } String sageMakerRoleName = args[0]; String lambdaRoleName = args[1]; String functionFileLocation = args[2]; String functionName = args[3]; String queueName = args[4]; String bucketName = args[5]; String lnglatData = args[6]; String spatialPipelinePath = args[7]; String pipelineName = args[8]; String handlerName = "org.example.SageMakerLambdaFunction::handleRequest"; Region region = Region.US_WEST_2; SageMakerClient sageMakerClient = SageMakerClient.builder() .region(region) .build(); IamClient iam = IamClient.builder() .region(region) .build(); LambdaClient lambdaClient = LambdaClient.builder() .region(region) .build(); SqsClient sqsClient = SqsClient.builder() .region(region) .build(); S3Client s3Client = S3Client.builder() .region(region) .build(); System.out.println(DASHES); System.out.println("Welcome to the Amazon SageMaker pipeline example scenario."); System.out.println( "\nThis example workflow will guide you through setting up and running an" + "\nAmazon SageMaker pipeline. The pipeline uses an AWS Lambda function and an" + "\nAmazon SQS Queue. It runs a vector enrichment reverse geocode job to" + "\nreverse geocode addresses in an input file and store the results in an export file."); System.out.println(DASHES); System.out.println(DASHES); System.out.println("First, we will set up the roles, functions, and queue needed by the SageMaker pipeline."); String lambdaRoleArn = checkLambdaRole(iam, lambdaRoleName); String sageMakerRoleArn = checkSageMakerRole(iam, sageMakerRoleName); String functionArn = checkFunction(lambdaClient, functionName, functionFileLocation, lambdaRoleArn, handlerName); String queueUrl = checkQueue(sqsClient, lambdaClient, queueName, functionName); System.out.println("The queue URL is " + queueUrl); System.out.println(DASHES); System.out.println(DASHES); System.out.println("Setting up bucket " + bucketName); if (!checkBucket(s3Client, bucketName)) { setupBucket(s3Client, bucketName); System.out.println("Put " + lnglatData + " into " + bucketName); putS3Object(s3Client, bucketName, "latlongtest.csv", lnglatData); } System.out.println(DASHES); System.out.println(DASHES); System.out.println("Now we can create and run our pipeline."); setupPipeline(sageMakerClient, spatialPipelinePath, sageMakerRoleArn, functionArn, pipelineName); String pipelineExecutionARN = executePipeline(sageMakerClient, bucketName, queueUrl, sageMakerRoleArn, pipelineName); System.out.println("The pipeline execution ARN value is " + pipelineExecutionARN); waitForPipelineExecution(sageMakerClient, pipelineExecutionARN); System.out.println("Getting output results " + bucketName); getOutputResults(s3Client, bucketName); System.out.println(DASHES); System.out.println(DASHES); System.out.println("The pipeline has completed. To view the pipeline and runs " + "in SageMaker Studio, follow these instructions:" + "\nhttps://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-studio.html"); System.out.println(DASHES); System.out.println(DASHES); System.out.println("Do you want to delete the AWS resources used in this Workflow? (y/n)"); Scanner in = new Scanner(System.in); String delResources = in.nextLine(); if (delResources.compareTo("y") == 0) { System.out.println("Lets clean up the AWS resources. Wait 30 seconds"); TimeUnit.SECONDS.sleep(30); deleteEventSourceMapping(lambdaClient); deleteSQSQueue(sqsClient, queueName); listBucketObjects(s3Client, bucketName); deleteBucket(s3Client, bucketName); deleteLambdaFunction(lambdaClient, functionName); deleteLambdaRole(iam, lambdaRoleName); deleteSagemakerRole(iam, sageMakerRoleName); deletePipeline(sageMakerClient, pipelineName); } else { System.out.println("The AWS Resources were not deleted!"); } System.out.println(DASHES); System.out.println(DASHES); System.out.println("SageMaker pipeline scenario is complete."); System.out.println(DASHES); } private static void readObject(S3Client s3Client, String bucketName, String key) { System.out.println("Output file contents: \n"); GetObjectRequest objectRequest = GetObjectRequest.builder() .bucket(bucketName) .key(key) .build(); ResponseBytes<GetObjectResponse> objectBytes = s3Client.getObjectAsBytes(objectRequest); byte[] byteArray = objectBytes.asByteArray(); String text = new String(byteArray, StandardCharsets.UTF_8); System.out.println("Text output: " + text); } // Display some results from the output directory. public static void getOutputResults(S3Client s3Client, String bucketName) { System.out.println("Getting output results {bucketName}."); ListObjectsRequest listObjectsRequest = ListObjectsRequest.builder() .bucket(bucketName) .prefix("outputfiles/") .build(); ListObjectsResponse response = s3Client.listObjects(listObjectsRequest); List<S3Object> s3Objects = response.contents(); for (S3Object object : s3Objects) { readObject(s3Client, bucketName, object.key()); } } // Check the status of a pipeline execution. public static void waitForPipelineExecution(SageMakerClient sageMakerClient, String executionArn) throws InterruptedException { String status; int index = 0; do { DescribePipelineExecutionRequest pipelineExecutionRequest = DescribePipelineExecutionRequest.builder() .pipelineExecutionArn(executionArn) .build(); DescribePipelineExecutionResponse response = sageMakerClient .describePipelineExecution(pipelineExecutionRequest); status = response.pipelineExecutionStatusAsString(); System.out.println(index + ". The Status of the pipeline is " + status); TimeUnit.SECONDS.sleep(4); index++; } while ("Executing".equals(status)); System.out.println("Pipeline finished with status " + status); } // Delete a SageMaker pipeline by name. public static void deletePipeline(SageMakerClient sageMakerClient, String pipelineName) { DeletePipelineRequest pipelineRequest = DeletePipelineRequest.builder() .pipelineName(pipelineName) .build(); sageMakerClient.deletePipeline(pipelineRequest); System.out.println("*** Successfully deleted " + pipelineName); } // Create a pipeline from the example pipeline JSON. public static void setupPipeline(SageMakerClient sageMakerClient, String filePath, String roleArn, String functionArn, String pipelineName) { System.out.println("Setting up the pipeline."); JSONParser parser = new JSONParser(); // Read JSON and get pipeline definition. try (FileReader reader = new FileReader(filePath)) { Object obj = parser.parse(reader); JSONObject jsonObject = (JSONObject) obj; JSONArray stepsArray = (JSONArray) jsonObject.get("Steps"); for (Object stepObj : stepsArray) { JSONObject step = (JSONObject) stepObj; if (step.containsKey("FunctionArn")) { step.put("FunctionArn", functionArn); } } System.out.println(jsonObject); // Create the pipeline. CreatePipelineRequest pipelineRequest = CreatePipelineRequest.builder() .pipelineDescription("Java SDK example pipeline") .roleArn(roleArn) .pipelineName(pipelineName) .pipelineDefinition(jsonObject.toString()) .build(); sageMakerClient.createPipeline(pipelineRequest); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } catch (IOException | ParseException e) { throw new RuntimeException(e); } } // Start a pipeline run with job configurations. public static String executePipeline(SageMakerClient sageMakerClient, String bucketName, String queueUrl, String roleArn, String pipelineName) { System.out.println("Starting pipeline execution."); String inputBucketLocation = "s3://" + bucketName + "/samplefiles/latlongtest.csv"; String output = "s3://" + bucketName + "/outputfiles/"; Gson gson = new GsonBuilder() .setFieldNamingPolicy(FieldNamingPolicy.UPPER_CAMEL_CASE) .setPrettyPrinting().create(); // Set up all parameters required to start the pipeline. List<Parameter> parameters = new ArrayList<>(); Parameter para1 = Parameter.builder() .name("parameter_execution_role") .value(roleArn) .build(); Parameter para2 = Parameter.builder() .name("parameter_queue_url") .value(queueUrl) .build(); String inputJSON = "{\n" + " \"DataSourceConfig\": {\n" + " \"S3Data\": {\n" + " \"S3Uri\": \"s3://" + bucketName + "/samplefiles/latlongtest.csv\"\n" + " },\n" + " \"Type\": \"S3_DATA\"\n" + " },\n" + " \"DocumentType\": \"CSV\"\n" + "}"; System.out.println(inputJSON); Parameter para3 = Parameter.builder() .name("parameter_vej_input_config") .value(inputJSON) .build(); // Create an ExportVectorEnrichmentJobOutputConfig object. VectorEnrichmentJobS3Data jobS3Data = VectorEnrichmentJobS3Data.builder() .s3Uri(output) .build(); ExportVectorEnrichmentJobOutputConfig outputConfig = ExportVectorEnrichmentJobOutputConfig.builder() .s3Data(jobS3Data) .build(); String gson4 = gson.toJson(outputConfig); Parameter para4 = Parameter.builder() .name("parameter_vej_export_config") .value(gson4) .build(); System.out.println("parameter_vej_export_config:" + gson.toJson(outputConfig)); // Create a VectorEnrichmentJobConfig object. ReverseGeocodingConfig reverseGeocodingConfig = ReverseGeocodingConfig.builder() .xAttributeName("Longitude") .yAttributeName("Latitude") .build(); VectorEnrichmentJobConfig jobConfig = VectorEnrichmentJobConfig.builder() .reverseGeocodingConfig(reverseGeocodingConfig) .build(); String para5JSON = "{\"MapMatchingConfig\":null,\"ReverseGeocodingConfig\":{\"XAttributeName\":\"Longitude\",\"YAttributeName\":\"Latitude\"}}"; Parameter para5 = Parameter.builder() .name("parameter_step_1_vej_config") .value(para5JSON) .build(); System.out.println("parameter_step_1_vej_config:" + gson.toJson(jobConfig)); parameters.add(para1); parameters.add(para2); parameters.add(para3); parameters.add(para4); parameters.add(para5); StartPipelineExecutionRequest pipelineExecutionRequest = StartPipelineExecutionRequest.builder() .pipelineExecutionDescription("Created using Java SDK") .pipelineExecutionDisplayName(pipelineName + "-example-execution") .pipelineParameters(parameters) .pipelineName(pipelineName) .build(); StartPipelineExecutionResponse response = sageMakerClient.startPipelineExecution(pipelineExecutionRequest); return response.pipelineExecutionArn(); } public static void deleteEventSourceMapping(LambdaClient lambdaClient) { DeleteEventSourceMappingRequest eventSourceMappingRequest = DeleteEventSourceMappingRequest.builder() .uuid(eventSourceMapping) .build(); lambdaClient.deleteEventSourceMapping(eventSourceMappingRequest); } public static void deleteSagemakerRole(IamClient iam, String roleName) { String[] sageMakerRolePolicies = getSageMakerRolePolicies(); try { for (String policy : sageMakerRolePolicies) { // First the policy needs to be detached. DetachRolePolicyRequest rolePolicyRequest = DetachRolePolicyRequest.builder() .policyArn(policy) .roleName(roleName) .build(); iam.detachRolePolicy(rolePolicyRequest); } // Delete the role. DeleteRoleRequest roleRequest = DeleteRoleRequest.builder() .roleName(roleName) .build(); iam.deleteRole(roleRequest); System.out.println("*** Successfully deleted " + roleName); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } public static void deleteLambdaRole(IamClient iam, String roleName) { String[] lambdaRolePolicies = getLambdaRolePolicies(); try { for (String policy : lambdaRolePolicies) { // First the policy needs to be detached. DetachRolePolicyRequest rolePolicyRequest = DetachRolePolicyRequest.builder() .policyArn(policy) .roleName(roleName) .build(); iam.detachRolePolicy(rolePolicyRequest); } // Delete the role. DeleteRoleRequest roleRequest = DeleteRoleRequest.builder() .roleName(roleName) .build(); iam.deleteRole(roleRequest); System.out.println("*** Successfully deleted " + roleName); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } // Delete the specific AWS Lambda function. public static void deleteLambdaFunction(LambdaClient awsLambda, String functionName) { try { DeleteFunctionRequest request = DeleteFunctionRequest.builder() .functionName(functionName) .build(); awsLambda.deleteFunction(request); System.out.println("*** " + functionName + " was deleted"); } catch (LambdaException e) { System.err.println(e.getMessage()); System.exit(1); } } // Delete the specific S3 bucket. public static void deleteBucket(S3Client s3Client, String bucketName) { DeleteBucketRequest deleteBucketRequest = DeleteBucketRequest.builder() .bucket(bucketName) .build(); s3Client.deleteBucket(deleteBucketRequest); System.out.println("*** " + bucketName + " was deleted."); } public static void listBucketObjects(S3Client s3, String bucketName) { try { ListObjectsRequest listObjects = ListObjectsRequest .builder() .bucket(bucketName) .build(); ListObjectsResponse res = s3.listObjects(listObjects); List<S3Object> objects = res.contents(); for (S3Object myValue : objects) { System.out.print("\n The name of the key is " + myValue.key()); deleteBucketObjects(s3, bucketName, myValue.key()); } } catch (S3Exception e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } public static void deleteBucketObjects(S3Client s3, String bucketName, String objectName) { ArrayList<ObjectIdentifier> toDelete = new ArrayList<>(); toDelete.add(ObjectIdentifier.builder() .key(objectName) .build()); try { DeleteObjectsRequest dor = DeleteObjectsRequest.builder() .bucket(bucketName) .delete(Delete.builder() .objects(toDelete).build()) .build(); s3.deleteObjects(dor); System.out.println("*** " + bucketName + " objects were deleted."); } catch (S3Exception e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } // Delete the specific Amazon SQS queue. public static void deleteSQSQueue(SqsClient sqsClient, String queueName) { try { GetQueueUrlRequest getQueueRequest = GetQueueUrlRequest.builder() .queueName(queueName) .build(); String queueUrl = sqsClient.getQueueUrl(getQueueRequest).queueUrl(); DeleteQueueRequest deleteQueueRequest = DeleteQueueRequest.builder() .queueUrl(queueUrl) .build(); sqsClient.deleteQueue(deleteQueueRequest); } catch (SqsException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } public static void putS3Object(S3Client s3, String bucketName, String objectKey, String objectPath) { try { Map<String, String> metadata = new HashMap<>(); metadata.put("x-amz-meta-myVal", "test"); PutObjectRequest putOb = PutObjectRequest.builder() .bucket(bucketName) .key("samplefiles/" + objectKey) .metadata(metadata) .build(); s3.putObject(putOb, RequestBody.fromFile(new File(objectPath))); System.out.println("Successfully placed " + objectKey + " into bucket " + bucketName); } catch (S3Exception e) { System.err.println(e.getMessage()); System.exit(1); } } public static void setupBucket(S3Client s3Client, String bucketName) { try { S3Waiter s3Waiter = s3Client.waiter(); CreateBucketRequest bucketRequest = CreateBucketRequest.builder() .bucket(bucketName) .build(); s3Client.createBucket(bucketRequest); HeadBucketRequest bucketRequestWait = HeadBucketRequest.builder() .bucket(bucketName) .build(); // Wait until the bucket is created and print out the response. WaiterResponse<HeadBucketResponse> waiterResponse = s3Waiter.waitUntilBucketExists(bucketRequestWait); waiterResponse.matched().response().ifPresent(System.out::println); System.out.println(bucketName + " is ready"); } catch (S3Exception e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } } // Set up the SQS queue to use with the pipeline. public static String setupQueue(SqsClient sqsClient, LambdaClient lambdaClient, String queueName, String lambdaName) { System.out.println("Setting up queue named " + queueName); try { Map<QueueAttributeName, String> queueAtt = new HashMap<>(); queueAtt.put(QueueAttributeName.DELAY_SECONDS, "5"); queueAtt.put(QueueAttributeName.RECEIVE_MESSAGE_WAIT_TIME_SECONDS, "5"); queueAtt.put(QueueAttributeName.VISIBILITY_TIMEOUT, "300"); CreateQueueRequest createQueueRequest = CreateQueueRequest.builder() .queueName(queueName) .attributes(queueAtt) .build(); sqsClient.createQueue(createQueueRequest); System.out.println("\nGet queue url"); GetQueueUrlResponse getQueueUrlResponse = sqsClient .getQueueUrl(GetQueueUrlRequest.builder().queueName(queueName).build()); TimeUnit.SECONDS.sleep(15); connectLambda(sqsClient, lambdaClient, getQueueUrlResponse.queueUrl(), lambdaName); System.out.println("Queue ready with Url " + getQueueUrlResponse.queueUrl()); return getQueueUrlResponse.queueUrl(); } catch (SqsException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } catch (InterruptedException e) { throw new RuntimeException(e); } return ""; } // Connect the queue to the Lambda function as an event source. public static void connectLambda(SqsClient sqsClient, LambdaClient lambdaClient, String queueUrl, String lambdaName) { System.out.println("Connecting the Lambda function and queue for the pipeline."); String queueArn = ""; // Specify the attributes to retrieve. List<QueueAttributeName> atts = new ArrayList<>(); atts.add(QueueAttributeName.QUEUE_ARN); GetQueueAttributesRequest attributesRequest = GetQueueAttributesRequest.builder() .queueUrl(queueUrl) .attributeNames(atts) .build(); GetQueueAttributesResponse response = sqsClient.getQueueAttributes(attributesRequest); Map<String, String> queueAtts = response.attributesAsStrings(); for (Map.Entry<String, String> queueAtt : queueAtts.entrySet()) { System.out.println("Key = " + queueAtt.getKey() + ", Value = " + queueAtt.getValue()); queueArn = queueAtt.getValue(); } CreateEventSourceMappingRequest eventSourceMappingRequest = CreateEventSourceMappingRequest.builder() .eventSourceArn(queueArn) .functionName(lambdaName) .build(); CreateEventSourceMappingResponse response1 = lambdaClient.createEventSourceMapping(eventSourceMappingRequest); eventSourceMapping = response1.uuid(); System.out.println("The mapping between the event source and Lambda function was successful"); } // Create an AWS Lambda function. public static String createLambdaFunction(LambdaClient awsLambda, String functionName, String filePath, String role, String handler) { try { LambdaWaiter waiter = awsLambda.waiter(); InputStream is = new FileInputStream(filePath); SdkBytes fileToUpload = SdkBytes.fromInputStream(is); FunctionCode code = FunctionCode.builder() .zipFile(fileToUpload) .build(); CreateFunctionRequest functionRequest = CreateFunctionRequest.builder() .functionName(functionName) .description("SageMaker example function.") .code(code) .handler(handler) .runtime(Runtime.JAVA11) .timeout(200) .memorySize(1024) .role(role) .build(); // Create a Lambda function using a waiter. CreateFunctionResponse functionResponse = awsLambda.createFunction(functionRequest); GetFunctionRequest getFunctionRequest = GetFunctionRequest.builder() .functionName(functionName) .build(); WaiterResponse<GetFunctionResponse> waiterResponse = waiter.waitUntilFunctionExists(getFunctionRequest); waiterResponse.matched().response().ifPresent(System.out::println); System.out.println("The function ARN is " + functionResponse.functionArn()); return functionResponse.functionArn(); } catch (LambdaException | FileNotFoundException e) { System.err.println(e.getMessage()); System.exit(1); } return ""; } public static String createSageMakerRole(IamClient iam, String roleName) { String[] sageMakerRolePolicies = getSageMakerRolePolicies(); System.out.println("Creating a role to use with SageMaker."); String assumeRolePolicy = "{" + "\"Version\": \"2012-10-17\"," + "\"Statement\": [{" + "\"Effect\": \"Allow\"," + "\"Principal\": {" + "\"Service\": [" + "\"sagemaker.amazonaws.com\"," + "\"sagemaker-geospatial.amazonaws.com\"," + "\"lambda.amazonaws.com\"," + "\"s3.amazonaws.com\"" + "]" + "}," + "\"Action\": \"sts:AssumeRole\"" + "}]" + "}"; try { CreateRoleRequest request = CreateRoleRequest.builder() .roleName(roleName) .assumeRolePolicyDocument(assumeRolePolicy) .description("Created using the AWS SDK for Java") .build(); CreateRoleResponse roleResult = iam.createRole(request); // Attach the policies to the role. for (String policy : sageMakerRolePolicies) { AttachRolePolicyRequest attachRequest = AttachRolePolicyRequest.builder() .roleName(roleName) .policyArn(policy) .build(); iam.attachRolePolicy(attachRequest); } // Allow time for the role to be ready. TimeUnit.SECONDS.sleep(15); System.out.println("Role ready with ARN " + roleResult.role().arn()); return roleResult.role().arn(); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); System.exit(1); } catch (InterruptedException e) { throw new RuntimeException(e); } return ""; } private static String createLambdaRole(IamClient iam, String roleName) { String[] lambdaRolePolicies = getLambdaRolePolicies(); String assumeRolePolicy = "{" + "\"Version\": \"2012-10-17\"," + "\"Statement\": [{" + "\"Effect\": \"Allow\"," + "\"Principal\": {" + "\"Service\": [" + "\"sagemaker.amazonaws.com\"," + "\"sagemaker-geospatial.amazonaws.com\"," + "\"lambda.amazonaws.com\"," + "\"s3.amazonaws.com\"" + "]" + "}," + "\"Action\": \"sts:AssumeRole\"" + "}]" + "}"; try { CreateRoleRequest request = CreateRoleRequest.builder() .roleName(roleName) .assumeRolePolicyDocument(assumeRolePolicy) .description("Created using the AWS SDK for Java") .build(); CreateRoleResponse roleResult = iam.createRole(request); // Attach the policies to the role. for (String policy : lambdaRolePolicies) { AttachRolePolicyRequest attachRequest = AttachRolePolicyRequest.builder() .roleName(roleName) .policyArn(policy) .build(); iam.attachRolePolicy(attachRequest); } // Allow time for the role to be ready. TimeUnit.SECONDS.sleep(15); System.out.println("Role ready with ARN " + roleResult.role().arn()); return roleResult.role().arn(); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); } catch (InterruptedException e) { throw new RuntimeException(e); } return ""; } public static String checkFunction(LambdaClient lambdaClient, String functionName, String filePath, String role, String handler) { System.out.println("Create an AWS Lambda function used in this workflow."); String functionArn; try { // Does this function already exist. GetFunctionRequest functionRequest = GetFunctionRequest.builder() .functionName(functionName) .build(); GetFunctionResponse response = lambdaClient.getFunction(functionRequest); functionArn = response.configuration().functionArn(); } catch (LambdaException e) { System.err.println(e.awsErrorDetails().errorMessage()); functionArn = createLambdaFunction(lambdaClient, functionName, filePath, role, handler); } return functionArn; } // Check to see if the specific S3 bucket exists. If the S3 bucket exists, this // method returns true. public static boolean checkBucket(S3Client s3, String bucketName) { try { HeadBucketRequest headBucketRequest = HeadBucketRequest.builder() .bucket(bucketName) .build(); s3.headBucket(headBucketRequest); System.out.println(bucketName + " exists"); return true; } catch (S3Exception e) { System.err.println(e.awsErrorDetails().errorMessage()); } return false; } // Checks to see if the Amazon SQS queue exists. If not, this method creates a // new queue // and returns the ARN value. public static String checkQueue(SqsClient sqsClient, LambdaClient lambdaClient, String queueName, String lambdaName) { System.out.println("Creating a queue for this use case."); String queueUrl; try { GetQueueUrlRequest request = GetQueueUrlRequest.builder() .queueName(queueName) .build(); GetQueueUrlResponse response = sqsClient.getQueueUrl(request); queueUrl = response.queueUrl(); System.out.println(queueUrl); } catch (SqsException e) { System.err.println(e.awsErrorDetails().errorMessage()); queueUrl = setupQueue(sqsClient, lambdaClient, queueName, lambdaName); } return queueUrl; } // Checks to see if the Lambda role exists. If not, this method creates it. public static String checkLambdaRole(IamClient iam, String roleName) { System.out.println("Creating a role to for AWS Lambda to use."); String roleArn; try { GetRoleRequest roleRequest = GetRoleRequest.builder() .roleName(roleName) .build(); GetRoleResponse response = iam.getRole(roleRequest); roleArn = response.role().arn(); System.out.println(roleArn); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); roleArn = createLambdaRole(iam, roleName); } return roleArn; } // Checks to see if the SageMaker role exists. If not, this method creates it. public static String checkSageMakerRole(IamClient iam, String roleName) { System.out.println("Creating a role to for AWS SageMaker to use."); String roleArn; try { GetRoleRequest roleRequest = GetRoleRequest.builder() .roleName(roleName) .build(); GetRoleResponse response = iam.getRole(roleRequest); roleArn = response.role().arn(); System.out.println(roleArn); } catch (IamException e) { System.err.println(e.awsErrorDetails().errorMessage()); roleArn = createSageMakerRole(iam, roleName); } return roleArn; } private static String[] getSageMakerRolePolicies() { String[] sageMakerRolePolicies = new String[3]; sageMakerRolePolicies[0] = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"; sageMakerRolePolicies[1] = "arn:aws:iam::aws:policy/" + "AmazonSageMakerGeospatialFullAccess"; sageMakerRolePolicies[2] = "arn:aws:iam::aws:policy/AmazonSQSFullAccess"; return sageMakerRolePolicies; } private static String[] getLambdaRolePolicies() { String[] lambdaRolePolicies = new String[5]; lambdaRolePolicies[0] = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"; lambdaRolePolicies[1] = "arn:aws:iam::aws:policy/AmazonSQSFullAccess"; lambdaRolePolicies[2] = "arn:aws:iam::aws:policy/service-role/" + "AmazonSageMakerGeospatialFullAccess"; lambdaRolePolicies[3] = "arn:aws:iam::aws:policy/service-role/" + "AmazonSageMakerServiceCatalogProductsLambdaServiceRolePolicy"; lambdaRolePolicies[4] = "arn:aws:iam::aws:policy/service-role/" + "AWSLambdaSQSQueueExecutionRole"; return lambdaRolePolicies; } }