

Write a Spark application

You can write Spark applications in Scala, Java, or Python. The Spark examples topic in the Apache Spark documentation contains several example Spark applications. The Estimating Pi example is shown below in each of the three natively supported languages. You can also view complete examples in $SPARK_HOME/examples and on GitHub. For more information about how to build JARs for Spark, see the Quick Start topic in the Apache Spark documentation.

Scala

To avoid Scala compatibility issues, we recommend that you use Spark dependencies for the correct Scala version when you compile a Spark application for an Amazon EMR cluster. The Scala version to use depends on the Spark version installed on your cluster. For example, Amazon EMR release 5.30.1 uses Spark 2.4.5, which is built with Scala 2.11. If your cluster uses Amazon EMR release 5.30.1, use Spark dependencies for Scala 2.11. For more information about the Scala versions used by Spark, see the Apache Spark documentation.
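As an illustration only, a minimal sbt build definition for the Scala example below might pin the versions as follows. This is a sketch for Amazon EMR release 5.30.1 (Spark 2.4.5 built with Scala 2.11); the project name and exact patch version are assumptions, not part of the official example.

// build.sbt (sketch) - match the Scala version to the Spark build on your EMR release
name := "spark-pi"            // hypothetical project name
version := "1.0"
scalaVersion := "2.11.12"     // Scala 2.11.x to match Spark 2.4.5 on Amazon EMR 5.30.1

// "provided" because the cluster supplies Spark at runtime, so it is not bundled into your JAR
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.5" % "provided"

With a build definition along these lines, sbt package produces a JAR compiled against Scala 2.11, which the cluster's Spark installation can load without binary-compatibility errors.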

package org.apache.spark.examples

import scala.math.random

import org.apache.spark._

/** Computes an approximation to pi */
object SparkPi {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Spark Pi")
    val spark = new SparkContext(conf)
    val slices = if (args.length > 0) args(0).toInt else 2
    val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
    val count = spark.parallelize(1 until n, slices).map { i =>
      val x = random * 2 - 1
      val y = random * 2 - 1
      if (x*x + y*y < 1) 1 else 0
    }.reduce(_ + _)
    println("Pi is roughly " + 4.0 * count / n)
    spark.stop()
  }
}
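After packaging the example into a JAR (for instance with sbt package or mvn package), you could submit it with spark-submit from the cluster's master node, along the lines of the following sketch. The JAR file name is a placeholder; the class name and the trailing slices argument come from the example above.

spark-submit --class org.apache.spark.examples.SparkPi \
    --master yarn \
    --deploy-mode cluster \
    spark-pi_2.11-1.0.jar 10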

Java

package org.apache.spark.examples;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import java.util.ArrayList;
import java.util.List;

/**
 * Computes an approximation to pi
 * Usage: JavaSparkPi [slices]
 */
public final class JavaSparkPi {

  public static void main(String[] args) throws Exception {
    SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);

    int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2;
    int n = 100000 * slices;
    List<Integer> l = new ArrayList<Integer>(n);
    for (int i = 0; i < n; i++) {
      l.add(i);
    }

    JavaRDD<Integer> dataSet = jsc.parallelize(l, slices);

    int count = dataSet.map(new Function<Integer, Integer>() {
      @Override
      public Integer call(Integer integer) {
        double x = Math.random() * 2 - 1;
        double y = Math.random() * 2 - 1;
        return (x * x + y * y < 1) ? 1 : 0;
      }
    }).reduce(new Function2<Integer, Integer, Integer>() {
      @Override
      public Integer call(Integer integer, Integer integer2) {
        return integer + integer2;
      }
    });

    System.out.println("Pi is roughly " + 4.0 * count / n);

    jsc.stop();
  }
}

Python

import argparse
import logging
from operator import add
from random import random

from pyspark.sql import SparkSession

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


def calculate_pi(partitions, output_uri):
    """
    Calculates pi by testing a large number of random numbers against a unit circle
    inscribed inside a square. The trials are partitioned so they can be run in
    parallel on cluster instances.

    :param partitions: The number of partitions to use for the calculation.
    :param output_uri: The URI where the output is written, typically an Amazon S3
                       bucket, such as 's3://example-bucket/pi-calc'.
    """
    def calculate_hit(_):
        x = random() * 2 - 1
        y = random() * 2 - 1
        return 1 if x**2 + y**2 < 1 else 0

    tries = 100000 * partitions
    logger.info(
        "Calculating pi with a total of %s tries in %s partitions.", tries, partitions
    )

    with SparkSession.builder.appName("My PyPi").getOrCreate() as spark:
        hits = (
            spark.sparkContext.parallelize(range(tries), partitions)
            .map(calculate_hit)
            .reduce(add)
        )
        pi = 4.0 * hits / tries
        logger.info("%s tries and %s hits gives pi estimate of %s.", tries, hits, pi)

        if output_uri is not None:
            df = spark.createDataFrame([(tries, hits, pi)], ["tries", "hits", "pi"])
            df.write.mode("overwrite").json(output_uri)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--partitions",
        default=2,
        type=int,
        help="The number of parallel partitions to use when calculating pi.",
    )
    parser.add_argument(
        "--output_uri", help="The URI where output is saved, typically an S3 bucket."
    )
    args = parser.parse_args()

    calculate_pi(args.partitions, args.output_uri)
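The PySpark version can be run with spark-submit as well, for example from the cluster's master node. The script file name and the S3 bucket below are placeholders; the --partitions and --output_uri arguments are the ones defined by the example's argument parser.

spark-submit pi_calc.py --partitions 4 --output_uri s3://example-bucket/pi-calc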