RDD broadcast joins · coder-RT/spark-optimization@7ccb412

Commit 7ccb412

RDD broadcast joins
1 parent b5b3f1f commit 7ccb412

1 file changed: 64 additions, 0 deletions

@@ -0,0 +1,64 @@
package part4rddjoins

import org.apache.spark.sql.SparkSession

import scala.util.Random

/**
  * Shown on camera in the Spark Shell.
  */
object RDDBroadcastJoins {

  val spark = SparkSession.builder()
    .appName("Broadcast Joins")
    .master("local[*]")
    .getOrCreate()

  val sc = spark.sparkContext

  val random = new Random()

  /*
    Scenario: assign prizes in a large-scale competition (10M+ people).
    Goal: find out who won what.
   */

  // small lookup table
  val prizes = sc.parallelize(List(
    (1, "gold"),
    (2, "silver"),
    (3, "bronze")
  ))

  // the competition has ended - the leaderboard is known
  val leaderboard = sc.parallelize(1 to 10000000).map((_, random.alphanumeric.take(8).mkString))
  val medalists = leaderboard.join(prizes)
  medalists.foreach(println) // 38s for 10M elements!
38+
/*
39+
We know from SQL joins that the small RDD can be broadcast so that we can avoid the shuffle on the big RDD.
40+
However, for the RDD API, we'll have to do this manually.
41+
This lesson is more about how to actually implement the broadcasting technique on RDDs.
42+
*/

  // collect the small RDD to the driver so that it can be broadcast to the executors
  val medalsMap = prizes.collectAsMap()
  // keep the Broadcast handle: each executor fetches the map once and reads it locally through .value
  val medalsBroadcast = sc.broadcast(medalsMap)
  // avoid the shuffle by walking the partitions of the big RDD and looking up each key locally
  val improvedMedalists = leaderboard.mapPartitions { iterator =>
    // iterator over all the tuples in this partition; they are all local to this executor
    iterator.flatMap { record =>
      val (index, name) = record
      // read the map through the broadcast handle; referring to medalsMap directly would
      // serialize the map into the task closure instead of using the broadcast copy
      medalsBroadcast.value.get(index) match {
        case None => Seq.empty
        case Some(medal) => Seq((name, medal))
      }
    }
  }

  improvedMedalists.foreach(println) // 2s, blazing fast: no shuffle at all

  def main(args: Array[String]): Unit = {
    // keep the JVM alive after the jobs above have run, e.g. to inspect them in the Spark UI
    Thread.sleep(1000000)
  }
}
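
The per-partition lookup above can be modelled without a cluster. The following is a minimal sketch in plain Scala, not part of the commit: `MapSideJoinSketch` and `joinPartition` are hypothetical names, the sample names are made up, and the in-memory `Map` stands in for the value of a broadcast variable. It emits pairs in the `(key, (left, right))` shape that `RDD.join` produces, whereas the code above emits `(name, medal)`.

```scala
object MapSideJoinSketch {

  // Hypothetical helper: joins the records of one partition against an
  // in-memory lookup table, dropping records whose key has no match
  // (inner-join semantics). No shuffle is involved: every record is
  // resolved where it already lives.
  def joinPartition[K, V, W](
      records: Iterator[(K, V)],
      lookup: collection.Map[K, W]
  ): Iterator[(K, (V, W))] =
    records.flatMap { case (key, value) =>
      // Option acts as a zero-or-one element collection here
      lookup.get(key).map(matched => (key, (value, matched)))
    }

  def main(args: Array[String]): Unit = {
    // stands in for the small RDD after collectAsMap()
    val prizes = Map(1 -> "gold", 2 -> "silver", 3 -> "bronze")
    // stands in for one partition of the big RDD (illustrative data)
    val partition = Iterator((1, "alice"), (2, "bob"), (3, "carol"), (4, "dave"))

    joinPartition(partition, prizes).foreach(println)
    // prints:
    // (1,(alice,gold))
    // (2,(bob,silver))
    // (3,(carol,bronze))
  }
}
```

Assuming a `Broadcast` handle named `medalsBroadcast` obtained from `sc.broadcast(medalsMap)`, the same helper could be handed to Spark as `leaderboard.mapPartitions(joinPartition(_, medalsBroadcast.value))`. Keeping the lookup as a pure function on iterators makes the join logic checkable on a few records before running it on 10M.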
