From d38bd2b67858c013bcdc5f140fd3a2732c846258 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 24 Sep 2025 12:52:09 +0200 Subject: [PATCH 01/17] better TriangleCount --- .../examples/TriangleCountExample.java | 90 +++++++++++++++ .../org/graphframes/lib/TriangleCount.scala | 108 +++++++++++------- 2 files changed, 159 insertions(+), 39 deletions(-) create mode 100644 core/src/main/java/org/graphframes/examples/TriangleCountExample.java diff --git a/core/src/main/java/org/graphframes/examples/TriangleCountExample.java b/core/src/main/java/org/graphframes/examples/TriangleCountExample.java new file mode 100644 index 000000000..ff8fc2089 --- /dev/null +++ b/core/src/main/java/org/graphframes/examples/TriangleCountExample.java @@ -0,0 +1,90 @@ +package org.graphframes.examples; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.StorageLevel; +import org.graphframes.GraphFrame; +import org.graphframes.lib.TriangleCount; + +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * The TriangleCount class demonstrates how to use the GraphFrames library in Apache Spark + * to count triangles in a graph dataset. A triangle in a graph is defined as a set of + * three interconnected vertices. + *
+ * This examples uses graphs from the LDBC Graphalytics benchmark datasets. + * The first argument is the name of the benchmark dataset, the second argument is the path where datasets are stored. + */ +public class TriangleCountExample { + public static void main(String[] args) { + String benchmarkName; + if (args.length > 0) { + benchmarkName = args[0]; + } else { + benchmarkName = "kgs"; + } + + Path resourcesPath; + if (args.length > 1) { + resourcesPath = Paths.get(args[1]); + } else { + resourcesPath = Paths.get("/tmp/ldbc_graphalitics_datesets"); + } + + Path caseRoot = resourcesPath.resolve(benchmarkName); + SparkConf sparkConf = new SparkConf() + .setAppName("TriangleCountExample") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); + SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate(); + SparkContext context = spark.sparkContext(); + context.setLogLevel("ERROR"); + context.setCheckpointDir("/tmp/graphframes-checkpoints"); + + LDBCUtils.downloadLDBCIfNotExists(resourcesPath, benchmarkName); + StructField[] edgeFields = new StructField[]{ + new StructField("src", DataTypes.LongType, true, Metadata.empty()), + new StructField("dst", DataTypes.LongType, true, Metadata.empty()) + }; + Dataset edges = spark.read() + .format("csv") + .option("header", "false") + .option("delimiter", " ") + .schema(new StructType(edgeFields)) + .load(caseRoot.resolve(benchmarkName + ".e").toString()) + .persist(StorageLevel.MEMORY_AND_DISK_SER()); + System.out.println("Edges loaded: " + edges.count()); + + StructField[] vertexFields = new StructField[]{ + new StructField("id", DataTypes.LongType, true, Metadata.empty()), + }; + Dataset vertices = spark.read() + .format("csv") + .option("header", "false") + .option("delimiter", " ") + .schema(new StructType(vertexFields)) + .load(caseRoot.resolve(benchmarkName + ".v").toString()) + .persist(StorageLevel.MEMORY_AND_DISK_SER()); + System.out.println("Vertices loaded: " + vertices.count()); + + var start = System.currentTimeMillis(); + GraphFrame graph = GraphFrame.apply(vertices, edges); + TriangleCount counter = graph.triangleCount(); + Dataset triangles = counter.run(); + + triangles.show(20, false); + long triangleCount = triangles.select(functions.sum("count")).first().getLong(0); + System.out.println("Found triangles: " + triangleCount); + var end = System.currentTimeMillis(); + System.out.println("Total running time in seconds: " + (end - start) / 1000.0); + } +} diff --git a/core/src/main/scala/org/graphframes/lib/TriangleCount.scala b/core/src/main/scala/org/graphframes/lib/TriangleCount.scala index a3bec9402..41adb729f 100644 --- a/core/src/main/scala/org/graphframes/lib/TriangleCount.scala +++ b/core/src/main/scala/org/graphframes/lib/TriangleCount.scala @@ -18,17 +18,11 @@ package org.graphframes.lib import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.array -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.explode -import org.apache.spark.sql.functions.when +import org.apache.spark.sql.functions._ +import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame -import org.graphframes.GraphFrame.DST -import org.graphframes.GraphFrame.ID -import org.graphframes.GraphFrame.LONG_DST -import org.graphframes.GraphFrame.LONG_SRC -import org.graphframes.GraphFrame.SRC -import org.graphframes.GraphFrame.quote +import org.graphframes.Logging +import org.graphframes.WithIntermediateStorageLevel /** * Computes the number of triangles passing 
through each vertex. @@ -36,48 +30,84 @@ import org.graphframes.GraphFrame.quote * This algorithm ignores edge direction; i.e., all edges are treated as undirected. In a * multigraph, duplicate edges will be counted only once. * - * Note that this provides the same algorithm as GraphX, but GraphX assumes the user provides a - * graph in the correct format. In Spark 2.0+, GraphX can automatically canonicalize the graph to - * put it in this format. + * **WARNING** This implementation is based on intersections of neighbor sets, which requires + * collecting both SRC and DST neighbors per edge! This will blow up memory in case the graph + * contains very high-degree nodes (power-law networks). Consider sampling strategies for that + * case! * * The returned DataFrame contains all the original vertex information and one additional column: * - count (`LongType`): the count of triangles */ -class TriangleCount private[graphframes] (private val graph: GraphFrame) extends Arguments { +class TriangleCount private[graphframes] (private val graph: GraphFrame) + extends Arguments + with Serializable + with WithIntermediateStorageLevel { def run(): DataFrame = { - TriangleCount.run(graph) + TriangleCount.run(graph, intermediateStorageLevel) } } -private object TriangleCount { +private object TriangleCount extends Logging { + import org.graphframes.GraphFrame.* - private def run(graph: GraphFrame): DataFrame = { - // Dedup edges by flipping them to have LONG_SRC < LONG_DST - // TODO (when we drop support for Spark 1.4): Use functions greatest, smallest instead of UDFs - val dedupedE = graph.indexedEdges - .filter(s"$LONG_SRC != $LONG_DST") - .selectExpr( - s"if($LONG_SRC < $LONG_DST, $SRC, $DST) as $SRC", - s"if($LONG_SRC < $LONG_DST, $DST, $SRC) as $DST") - .dropDuplicates(Seq(SRC, DST)) - val g2 = GraphFrame(graph.vertices, dedupedE) + private def prepareGraph(graph: GraphFrame): GraphFrame = { + // Dedup edges by flipping them to have SRC < DST + // Remove self-loops + val dedupedE = graph.edges + .filter(col(SRC) =!= col(DST)) + .select( + when(col(SRC) < col(DST), col(SRC)).otherwise(col(DST)).as(SRC), + when(col(SRC) < col(DST), col(DST)).otherwise(col(SRC)).as(DST)) + .distinct() - // Because SRC < DST, there exists only one type of triangles: - // - Non-cycle with one edge flipped. These are counted 1 time each by motif finding. - val triangles = g2.find("(a)-[]->(b); (b)-[]->(c); (a)-[]->(c)") + // Prepare the graph with no isolated vertices. 
+ GraphFrame(graph.vertices.select(ID), dedupedE).dropIsolatedVertices() + } + + private def run(graph: GraphFrame, intermediateStorageLevel: StorageLevel): DataFrame = { + val g2 = prepareGraph(graph) + + val verticesWithNeighbors = g2.aggregateMessages + .setIntermediateStorageLevel(intermediateStorageLevel) + .sendToSrc(AggregateMessages.dst(ID)) + .sendToDst(AggregateMessages.src(ID)) + .agg(collect_set(AggregateMessages.msg).alias("neighbors")) + + val triangles = verticesWithNeighbors + .select(col(ID), col("neighbors").alias("src_set")) + .join(g2.edges, col(ID) === col(SRC)) + .drop(ID) + .join( + verticesWithNeighbors.select(col(ID), col("neighbors").alias("dst_set")), + col(ID) === col(DST)) + .drop(ID) + // Count of common neighbors of SRC and DST + .withColumn("triplets", array_size(array_intersect(col("src_set"), col("dst_set")))) + .filter(col("triplets") > lit(0)) + .persist(intermediateStorageLevel) + + val srcTriangles = triangles.groupBy(SRC).agg(sum(col("triplets")).alias("src_triplets")) + val dstTriangles = triangles.groupBy(DST).agg(sum(col("triplets")).alias("dst_triplets")) - val triangleCounts = triangles - .select(explode(array(col("a.id"), col("b.id"), col("c.id"))).as(ID)) - .groupBy(ID) - .count() + val result = graph.vertices + .join(srcTriangles, col(ID) === col(SRC), "left_outer") + .join(dstTriangles, col(ID) === col(DST), "left_outer") + // Each triangle counted twice, so divide by 2. + .withColumn( + COUNT_ID, + floor( + when(col("src_triplets").isNull && col("dst_triplets").isNull, lit(0)) + .when(col("src_triplets").isNull, col("dst_triplets")) + .when(col("dst_triplets").isNull, col("src_triplets")) + .otherwise(col("src_triplets") + col("dst_triplets")) / lit(2))) - val v = graph.vertices - val countsCol = when(col("count").isNull, 0L).otherwise(col("count")) - val newV = v - .join(triangleCounts, v(ID) === triangleCounts(ID), "left_outer") - .select((countsCol.as(COUNT_ID) +: v.columns.map(quote).map(v.apply)).toSeq: _*) - newV + result.persist(intermediateStorageLevel) + result.count() + verticesWithNeighbors.unpersist() + triangles.unpersist() + resultIsPersistent() + result } private val COUNT_ID = "count" From d930b9a64781cb65429e38a8184950a8e5def1ab Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 26 Sep 2025 14:28:39 +0200 Subject: [PATCH 02/17] Update configuration files and improve docs - Added local Spark distribution to .gitignore - Updated .scalafix configuration for Scala 3 - Modified .scalafmt configuration for Scala 213 source compatibility - Changed import statements to use wildcard imports for consistency across files --- .gitignore | 3 ++ .scalafix.conf | 1 + .scalafmt.conf | 4 +- .../benchmarks/LDBCBenchmarkSuite.scala | 2 +- .../graphframes/GraphFramesConnectUtils.scala | 2 +- .../scala/org/graphframes/GraphFrame.scala | 8 +-- .../examples/BeliefPropagation.scala | 2 +- .../examples/ConnectedComponentsLDBC.scala | 2 +- .../org/graphframes/examples/Graphs.scala | 2 +- .../org/graphframes/examples/LDBCUtils.scala | 4 +- .../graphframes/lib/ConnectedComponents.scala | 2 +- .../org/graphframes/lib/DetectingCycles.scala | 2 +- .../graphframes/lib/GraphXConversions.scala | 4 +- .../graphframes/lib/LabelPropagation.scala | 2 +- .../scala/org/graphframes/lib/PageRank.scala | 2 +- .../lib/ParallelPersonalizedPageRank.scala | 2 +- .../scala/org/graphframes/lib/Pregel.scala | 2 +- .../org/graphframes/lib/SVDPlusPlus.scala | 2 +- .../org/graphframes/lib/ShortestPaths.scala | 2 +- .../lib/StronglyConnectedComponents.scala | 2 +- 
.../org/graphframes/lib/TriangleCount.scala | 2 +- .../org/graphframes/pattern/patterns.scala | 2 +- .../property/EdgePropertyGroup.scala | 2 +- .../org/graphframes/GraphFrameSuite.scala | 2 +- .../scala/org/graphframes/TestUtils.scala | 2 +- .../examples/BeliefPropagationSuite.scala | 2 +- .../org/graphframes/ldbc/TestLDBCCases.scala | 2 +- .../lib/AggregateMessagesSuite.scala | 4 +- .../lib/ConnectedComponentsSuite.scala | 4 +- .../org/graphframes/lib/PregelSuite.scala | 6 +-- .../graphframes/lib/ShortestPathsSuite.scala | 2 +- .../PropertyGraphFrameTest.scala | 2 +- docs/src/01-about/01-index.md | 50 ++++++++++++++++++- .../spark/graphframes/graphx/Graph.scala | 2 +- .../spark/graphframes/graphx/VertexRDD.scala | 4 +- .../graphx/impl/EdgePartition.scala | 2 +- .../graphx/impl/EdgePartitionBuilder.scala | 2 +- .../graphframes/graphx/impl/EdgeRDDImpl.scala | 2 +- .../graphframes/graphx/impl/GraphImpl.scala | 2 +- .../graphx/impl/ReplicatedVertexView.scala | 2 +- .../graphx/impl/RoutingTablePartition.scala | 2 +- .../impl/ShippableVertexPartition.scala | 2 +- .../graphx/impl/VertexPartition.scala | 2 +- .../graphx/impl/VertexPartitionBase.scala | 2 +- .../graphx/impl/VertexPartitionBaseOps.scala | 2 +- .../graphx/impl/VertexRDDImpl.scala | 6 +-- .../graphx/lib/ConnectedComponents.scala | 2 +- .../graphx/lib/LabelPropagation.scala | 2 +- .../graphframes/graphx/lib/PageRank.scala | 4 +- .../graphframes/graphx/lib/SVDPlusPlus.scala | 4 +- .../graphx/lib/ShortestPaths.scala | 2 +- .../lib/StronglyConnectedComponents.scala | 2 +- .../graphx/lib/TriangleCount.scala | 2 +- .../graphx/util/GraphGenerators.scala | 6 +-- .../GraphXPrimitiveKeyOpenHashMap.scala | 2 +- .../graphframes/graphx/GraphOpsSuite.scala | 2 +- .../spark/graphframes/graphx/GraphSuite.scala | 6 +-- .../graphx/impl/EdgePartitionSuite.scala | 2 +- .../graphx/impl/VertexPartitionSuite.scala | 2 +- .../graphx/lib/ConnectedComponentsSuite.scala | 4 +- .../graphx/lib/LabelPropagationSuite.scala | 2 +- .../graphx/lib/PageRankSuite.scala | 2 +- .../graphx/lib/SVDPlusPlusSuite.scala | 2 +- .../graphx/lib/ShortestPathsSuite.scala | 2 +- .../StronglyConnectedComponentsSuite.scala | 2 +- .../graphx/lib/TriangleCountSuite.scala | 2 +- 66 files changed, 136 insertions(+), 84 deletions(-) diff --git a/.gitignore b/.gitignore index 347feb3e7..7edb5dec9 100644 --- a/.gitignore +++ b/.gitignore @@ -73,3 +73,6 @@ connect/project # Auto-generated doc /docs/src/02-quick-start/01-installation.md /docs/src/05-blog/feed.xml + +# Local spark distro +spark-* diff --git a/.scalafix.conf b/.scalafix.conf index 82bd389eb..109cee66e 100644 --- a/.scalafix.conf +++ b/.scalafix.conf @@ -6,3 +6,4 @@ rules = [ OrganizeImports ExplicitResultTypes ] +OrganizeImports.targetDialect = Scala3 \ No newline at end of file diff --git a/.scalafmt.conf b/.scalafmt.conf index 347451104..94765a629 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -11,5 +11,5 @@ optIn = { danglingParentheses.preset = false docstrings.style = Asterisk maxColumn = 98 -runner.dialect = scala213 -version = 3.8.5 \ No newline at end of file +runner.dialect = Scala213Source3 +version = 3.8.5 diff --git a/benchmarks/src/main/scala/org/graphframes/benchmarks/LDBCBenchmarkSuite.scala b/benchmarks/src/main/scala/org/graphframes/benchmarks/LDBCBenchmarkSuite.scala index 524a880f2..004d8fb84 100644 --- a/benchmarks/src/main/scala/org/graphframes/benchmarks/LDBCBenchmarkSuite.scala +++ b/benchmarks/src/main/scala/org/graphframes/benchmarks/LDBCBenchmarkSuite.scala @@ -8,7 +8,7 @@ import 
org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame import org.graphframes.examples.LDBCUtils -import org.openjdk.jmh.annotations._ +import org.openjdk.jmh.annotations.* import org.openjdk.jmh.infra.Blackhole import java.io.File diff --git a/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index 561f588ee..5ce4d967d 100644 --- a/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -17,7 +17,7 @@ import org.graphframes.connect.proto.GraphFramesAPI.MethodCase import org.graphframes.connect.proto.StringOrLongID import org.graphframes.connect.proto.StringOrLongID.IdCase -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* object GraphFramesConnectUtils { private[graphframes] def parseColumnOrExpression( diff --git a/core/src/main/scala/org/graphframes/GraphFrame.scala b/core/src/main/scala/org/graphframes/GraphFrame.scala index 69e2b92ff..5a748d2f4 100644 --- a/core/src/main/scala/org/graphframes/GraphFrame.scala +++ b/core/src/main/scala/org/graphframes/GraphFrame.scala @@ -20,7 +20,7 @@ package org.graphframes import org.apache.spark.graphframes.graphx.Edge import org.apache.spark.graphframes.graphx.Graph import org.apache.spark.ml.clustering.PowerIterationClustering -import org.apache.spark.sql._ +import org.apache.spark.sql.* import org.apache.spark.sql.functions.array import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.functions.col @@ -31,10 +31,10 @@ import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.functions.struct -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.* import org.apache.spark.storage.StorageLevel -import org.graphframes.lib._ -import org.graphframes.pattern._ +import org.graphframes.lib.* +import org.graphframes.pattern.* import java.util.Random import scala.reflect.runtime.universe.TypeTag diff --git a/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala b/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala index 7c0c53784..13bdb4e7c 100644 --- a/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala +++ b/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala @@ -17,9 +17,9 @@ package org.graphframes.examples +import org.apache.spark.graphframes.graphx.Edge as GXEdge import org.apache.spark.graphframes.graphx.Graph import org.apache.spark.graphframes.graphx.VertexRDD -import org.apache.spark.graphframes.graphx.{Edge => GXEdge} import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession diff --git a/core/src/main/scala/org/graphframes/examples/ConnectedComponentsLDBC.scala b/core/src/main/scala/org/graphframes/examples/ConnectedComponentsLDBC.scala index e0606b215..5c520ef9c 100644 --- a/core/src/main/scala/org/graphframes/examples/ConnectedComponentsLDBC.scala +++ b/core/src/main/scala/org/graphframes/examples/ConnectedComponentsLDBC.scala @@ -9,7 +9,7 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame -import java.nio.file._ +import java.nio.file.* import java.util.Properties 
object ConnectedComponentsLDBC { diff --git a/core/src/main/scala/org/graphframes/examples/Graphs.scala b/core/src/main/scala/org/graphframes/examples/Graphs.scala index f442a30d1..eba5bb43a 100644 --- a/core/src/main/scala/org/graphframes/examples/Graphs.scala +++ b/core/src/main/scala/org/graphframes/examples/Graphs.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.randn import org.apache.spark.sql.functions.udf import org.graphframes.GraphFrame -import org.graphframes.GraphFrame._ +import org.graphframes.GraphFrame.* import scala.reflect.runtime.universe.TypeTag diff --git a/core/src/main/scala/org/graphframes/examples/LDBCUtils.scala b/core/src/main/scala/org/graphframes/examples/LDBCUtils.scala index d34d044fe..c73044741 100644 --- a/core/src/main/scala/org/graphframes/examples/LDBCUtils.scala +++ b/core/src/main/scala/org/graphframes/examples/LDBCUtils.scala @@ -1,8 +1,8 @@ package org.graphframes.examples import java.net.URL -import java.nio.file._ -import scala.sys.process._ +import java.nio.file.* +import scala.sys.process.* object LDBCUtils { private val LDBC_URL_PREFIX = "https://datasets.ldbcouncil.org/graphalytics/" diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 62b2b59eb..2574969a8 100644 --- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -20,7 +20,7 @@ package org.graphframes.lib import org.apache.spark.graphframes.graphx import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.apache.spark.sql.graphframes.GraphFramesConf import org.apache.spark.sql.types.DecimalType import org.apache.spark.storage.StorageLevel diff --git a/core/src/main/scala/org/graphframes/lib/DetectingCycles.scala b/core/src/main/scala/org/graphframes/lib/DetectingCycles.scala index 6a3c407dc..340ba81ba 100644 --- a/core/src/main/scala/org/graphframes/lib/DetectingCycles.scala +++ b/core/src/main/scala/org/graphframes/lib/DetectingCycles.scala @@ -1,7 +1,7 @@ package org.graphframes.lib import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.apache.spark.sql.types.ArrayType import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame diff --git a/core/src/main/scala/org/graphframes/lib/GraphXConversions.scala b/core/src/main/scala/org/graphframes/lib/GraphXConversions.scala index e356b1bbc..6efe2b7b2 100644 --- a/core/src/main/scala/org/graphframes/lib/GraphXConversions.scala +++ b/core/src/main/scala/org/graphframes/lib/GraphXConversions.scala @@ -20,13 +20,13 @@ package org.graphframes.lib import org.apache.spark.graphframes.graphx.Graph import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType import org.graphframes.GraphFrame import org.graphframes.NoSuchVertexException -import scala.reflect.runtime.universe._ +import scala.reflect.runtime.universe.* /** * Convenience functions to map GraphX graphs to GraphFrames, checking for the types expected by diff --git a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala 
b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala index f65d6d76f..a50fa1a63 100644 --- a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -20,7 +20,7 @@ package org.graphframes.lib import org.apache.spark.graphframes.graphx import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.MapType import org.apache.spark.storage.StorageLevel diff --git a/core/src/main/scala/org/graphframes/lib/PageRank.scala b/core/src/main/scala/org/graphframes/lib/PageRank.scala index 22ea38ce9..6faa11ce9 100644 --- a/core/src/main/scala/org/graphframes/lib/PageRank.scala +++ b/core/src/main/scala/org/graphframes/lib/PageRank.scala @@ -17,7 +17,7 @@ package org.graphframes.lib -import org.apache.spark.graphframes.graphx.{lib => graphxlib} +import org.apache.spark.graphframes.graphx.lib as graphxlib import org.graphframes.GraphFrame import org.graphframes.Logging diff --git a/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index 2b5db23cd..072d5b93d 100644 --- a/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -17,7 +17,7 @@ package org.graphframes.lib -import org.apache.spark.graphframes.graphx.{lib => graphxlib} +import org.apache.spark.graphframes.graphx.lib as graphxlib import org.graphframes.GraphFrame import org.graphframes.Logging import org.graphframes.WithMaxIter diff --git a/core/src/main/scala/org/graphframes/lib/Pregel.scala b/core/src/main/scala/org/graphframes/lib/Pregel.scala index 8a17037b1..0d63fe8f8 100644 --- a/core/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/core/src/main/scala/org/graphframes/lib/Pregel.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.functions.explode import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.struct import org.graphframes.GraphFrame -import org.graphframes.GraphFrame._ +import org.graphframes.GraphFrame.* import org.graphframes.Logging import org.graphframes.WithIntermediateStorageLevel import org.graphframes.WithLocalCheckpoints diff --git a/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala index 97b445678..e8457f59e 100644 --- a/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala +++ b/core/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala @@ -18,7 +18,7 @@ package org.graphframes.lib import org.apache.spark.graphframes.graphx.Edge -import org.apache.spark.graphframes.graphx.{lib => graphxlib} +import org.apache.spark.graphframes.graphx.lib as graphxlib import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.graphframes.GraphFrame diff --git a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala index 599fbb1c0..c1146015e 100644 --- a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala +++ b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala @@ -43,7 +43,7 @@ import org.graphframes.WithIntermediateStorageLevel import org.graphframes.WithLocalCheckpoints import java.util -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* /** * 
Computes shortest paths from every vertex to the given set of landmark vertices. Note that this diff --git a/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala index fbcd6242a..82f79e68e 100644 --- a/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala +++ b/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala @@ -17,7 +17,7 @@ package org.graphframes.lib -import org.apache.spark.graphframes.graphx.{lib => graphxlib} +import org.apache.spark.graphframes.graphx.lib as graphxlib import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame diff --git a/core/src/main/scala/org/graphframes/lib/TriangleCount.scala b/core/src/main/scala/org/graphframes/lib/TriangleCount.scala index 41adb729f..378b033b5 100644 --- a/core/src/main/scala/org/graphframes/lib/TriangleCount.scala +++ b/core/src/main/scala/org/graphframes/lib/TriangleCount.scala @@ -18,7 +18,7 @@ package org.graphframes.lib import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame import org.graphframes.Logging diff --git a/core/src/main/scala/org/graphframes/pattern/patterns.scala b/core/src/main/scala/org/graphframes/pattern/patterns.scala index 0d9c3f345..ddd4ba8d5 100644 --- a/core/src/main/scala/org/graphframes/pattern/patterns.scala +++ b/core/src/main/scala/org/graphframes/pattern/patterns.scala @@ -21,7 +21,7 @@ import org.graphframes.GraphFramesUnreachableException import org.graphframes.InvalidParseException import scala.collection.mutable -import scala.util.parsing.combinator._ +import scala.util.parsing.combinator.* /** * Parser for graph patterns for motif finding. Copied from GraphFrames with minor modification. 
diff --git a/core/src/main/scala/org/graphframes/propertygraph/property/EdgePropertyGroup.scala b/core/src/main/scala/org/graphframes/propertygraph/property/EdgePropertyGroup.scala index 248b35ffe..4be1b338d 100644 --- a/core/src/main/scala/org/graphframes/propertygraph/property/EdgePropertyGroup.scala +++ b/core/src/main/scala/org/graphframes/propertygraph/property/EdgePropertyGroup.scala @@ -6,7 +6,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.concat import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.sha2 -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.* import org.graphframes.GraphFrame import org.graphframes.InvalidPropertyGroupException diff --git a/core/src/test/scala/org/graphframes/GraphFrameSuite.scala b/core/src/test/scala/org/graphframes/GraphFrameSuite.scala index ff541adce..b0af81ffa 100644 --- a/core/src/test/scala/org/graphframes/GraphFrameSuite.scala +++ b/core/src/test/scala/org/graphframes/GraphFrameSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.graphframes.graphx.Graph import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StringType diff --git a/core/src/test/scala/org/graphframes/TestUtils.scala b/core/src/test/scala/org/graphframes/TestUtils.scala index 8e4f9e932..d13293d42 100644 --- a/core/src/test/scala/org/graphframes/TestUtils.scala +++ b/core/src/test/scala/org/graphframes/TestUtils.scala @@ -3,7 +3,7 @@ package org.graphframes import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.DataType import org.apache.spark.sql.types.StructType -import org.graphframes.GraphFrame._ +import org.graphframes.GraphFrame.* object TestUtils { diff --git a/core/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala b/core/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala index ff1999977..71fb15d60 100644 --- a/core/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala +++ b/core/src/test/scala/org/graphframes/examples/BeliefPropagationSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.Row import org.graphframes.GraphFrameTestSparkContext import org.graphframes.GraphFramesUnreachableException import org.graphframes.SparkFunSuite -import org.graphframes.examples.BeliefPropagation._ +import org.graphframes.examples.BeliefPropagation.* import org.graphframes.examples.Graphs.gridIsingModel class BeliefPropagationSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala b/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala index dc465ed76..60444449d 100644 --- a/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala +++ b/core/src/test/scala/org/graphframes/ldbc/TestLDBCCases.scala @@ -16,7 +16,7 @@ import org.graphframes.SparkFunSuite import org.graphframes.examples.LDBCUtils import java.io.File -import java.nio.file._ +import java.nio.file.* import java.util.Properties class TestLDBCCases extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/core/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala b/core/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala index 5cc349bf0..86fbaae6f 100644 --- 
a/core/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/AggregateMessagesSuite.scala @@ -18,8 +18,8 @@ package org.graphframes.lib import org.apache.spark.sql.Row -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.functions.* +import org.apache.spark.sql.types.* import org.graphframes.GraphFrame import org.graphframes.GraphFrameTestSparkContext import org.graphframes.GraphFramesUnreachableException diff --git a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala index dbe097cab..8bec4b061 100644 --- a/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types.DataTypes import org.apache.spark.sql.types.LongType import org.apache.spark.storage.StorageLevel -import org.graphframes.GraphFrame._ -import org.graphframes._ +import org.graphframes.* +import org.graphframes.GraphFrame.* import org.graphframes.examples.Graphs import scala.reflect.ClassTag diff --git a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala index f8fa28f4d..64d4aee0f 100644 --- a/core/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -17,9 +17,9 @@ package org.graphframes.lib -import org.apache.spark.sql.functions._ -import org.graphframes._ -import org.scalactic.Tolerance._ +import org.apache.spark.sql.functions.* +import org.graphframes.* +import org.scalactic.Tolerance.* class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/core/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala b/core/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala index 4aef63842..2ec345f12 100644 --- a/core/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala +++ b/core/src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.DataTypes +import org.graphframes.* import org.graphframes.GraphFrame.quote -import org.graphframes._ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext { diff --git a/core/src/test/scala/org/graphframes/propertygraph/PropertyGraphFrameTest.scala b/core/src/test/scala/org/graphframes/propertygraph/PropertyGraphFrameTest.scala index 748973b9c..88540ba93 100644 --- a/core/src/test/scala/org/graphframes/propertygraph/PropertyGraphFrameTest.scala +++ b/core/src/test/scala/org/graphframes/propertygraph/PropertyGraphFrameTest.scala @@ -1,7 +1,7 @@ package org.graphframes.propertygraph import org.apache.spark.sql.Column -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.* import org.graphframes.GraphFrame import org.graphframes.GraphFrameTestSparkContext import org.graphframes.SparkFunSuite diff --git a/docs/src/01-about/01-index.md b/docs/src/01-about/01-index.md index 6f50378d9..8bb083657 100644 --- a/docs/src/01-about/01-index.md +++ b/docs/src/01-about/01-index.md @@ -8,11 +8,59 @@ GraphFrames represent graphs: vertices (e.g., users) and edges (e.g., relationsh GraphFrames also provide powerful tools for running queries and 
standard graph algorithms. With GraphFrames, you can easily search for patterns within graphs, find important vertices, and more. Refer to the [User Guide](/04-user-guide/01-creating-graphframes.md) for a full list of queries and algorithms. +# Use-cases of GraphFrames + +## Ranking in search systems + +`PageRank` is a fundamental algorithm originally developed by Google for ranking web pages in search results. It works by measuring the importance of nodes in a graph based on the link structure, where links from highly-ranked pages contribute more to the rank of target pages. This principle can be extended to ranking documents in search systems, where documents are treated as nodes and hyperlinks or semantic relationships as edges. + +GraphFrames provides a fully distributed Spark-based implementation of the `PageRank` algorithm, enabling efficient computation of document rankings at scale. This implementation leverages the power of Apache Spark's distributed computing model, allowing organizations to analyze large-scale document networks without sacrificing performance. + +## Graph Clustering + +Graph clustering algorithms like `Label Propagation` and `Power Iteration Clustering` are built into GraphFrames and provide efficient ways to perform unsupervised clustering on large graphs. These algorithms leverage the distributed nature of Apache Spark to scale clustering operations across massive datasets while maintaining accuracy and performance. + +Label Propagation is a fast and efficient method for detecting communities in large graphs. It works by iteratively updating node labels based on the majority label of neighboring nodes, eventually leading to clusters where nodes within the same community share similar labels. This algorithm is particularly effective for identifying overlapping communities and is well-suited for real-time applications due to its simplicity and low computational overhead. + +Power Iteration Clustering (PIC) is another powerful clustering algorithm included in GraphFrames. PIC uses the eigenvectors of the graph's normalized adjacency matrix to assign nodes to clusters. It is especially effective for finding well-separated clusters and can handle large-scale graphs efficiently through Spark's distributed computing capabilities. The algorithm is based on the principle that nodes belonging to the same cluster tend to have similar values in the dominant eigenvector, making it a robust choice for various graph clustering tasks. + +## Anti-fraud and compliance applications + +GraphFrames provides powerful tools for analyzing complex networks, offering distributed implementations that scale seamlessly with Apache Spark. Here are several notable algorithms and techniques usable for anti-fraud and compliance analysis. + +### ShortestPaths Algorithm + +The `ShortestPaths` algorithm computes, for every vertex, the shortest-path distance to a given set of landmark vertices. This is particularly valuable for analyzing financial networks to find the minimum distances to known suspicious nodes. Such insights can be applied to enhance compliance scoring and detect suspicious activities with greater efficiency. In GraphFrames, the `ShortestPaths` algorithm is implemented in a vertex-centric Pregel framework that distributes the work across the whole Apache Spark cluster. + +### Cycle Detection with the Rocha-Thatte Algorithm + +GraphFrames includes an implementation of the [Rocha-Thatte cycles detection algorithm](https://en.wikipedia.org/wiki/Rocha%E2%80%93Thatte_cycle_detection_algorithm). 
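+A minimal invocation sketch follows. Note that the `detectingCycles` entry point below is an assumption inferred from the `DetectingCycles` class in `org.graphframes.lib` touched elsewhere in this patch series; check the `GraphFrame` API reference for the exact method name and the schema of the result.
+
+```scala
+import org.apache.spark.sql.SparkSession
+import org.graphframes.GraphFrame
+
+val spark = SparkSession.builder().appName("cycles-sketch").getOrCreate()
+import spark.implicits._
+
+// A tiny directed graph with a single 3-cycle: 1 -> 2 -> 3 -> 1.
+val vertices = Seq(1L, 2L, 3L, 4L).toDF("id")
+val edges = Seq((1L, 2L), (2L, 3L), (3L, 1L), (3L, 4L)).toDF("src", "dst")
+val g = GraphFrame(vertices, edges)
+
+// Hypothetical entry point; the real method name and result columns may differ.
+val cycles = g.detectingCycles.run()
+cycles.show(truncate = false)
+```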
This algorithm is designed to find all cycles in large graphs, making it an essential tool for uncovering suspicious activities like circular money flows. By efficiently detecting cycles, it enables analysts to better understand the structure of data and identify potential fraud. + +### Motif finding + +Motif finding is a powerful technique for identifying recurring patterns within graphs, which proves especially useful in detecting suspicious transactions and actions in financial networks. By analyzing the structural patterns of interactions between entities such as accounts, transactions, and merchants, motif finding can reveal common fraud schemes like money laundering, identity theft, or collusion among bad actors. For instance, specific motifs might indicate unusual transaction sequences that are typical of pump-and-dump schemes or layered fraudulent transfers. This capability allows financial institutions to proactively identify and investigate potentially risky behaviors before they escalate into significant losses. + +GraphFrames provides a `find` API for matching motifs at scale with a fully distributed algorithm powered by Apache Spark. + +## Data deduplication and identity resolution + +GraphFrames provides a highly efficient distributed algorithm called ["big-star small-star"](https://dl.acm.org/doi/pdf/10.1145/2670979.2670997) for finding connected components in large graphs. This algorithm is particularly useful for data deduplication and fingerprinting of massive datasets containing billions of rows. By constructing an interaction graph where each entity is represented as a vertex and relationships between entities are represented as edges, the connected components algorithm can group together rows that refer to the same real-world entity, even if they have different IDs across various systems or sessions. + +For example, consider a scenario where a user has multiple accounts across different platforms or systems, each with its own unique identifier. By creating a graph where vertices represent these accounts and edges represent known relationships (such as shared email addresses, IP addresses, or transaction histories), the connected components algorithm can identify all the vertices that belong to the same user. These vertices are then grouped together and assigned a unified ID, effectively deduplicating the data while preserving the integrity of the underlying entity relationships. + +This approach is especially powerful in scenarios involving customer data management, fraud detection, and identity resolution, where entities may appear under different identifiers across various data sources. The distributed nature of the "big-star small-star" algorithm ensures that such operations can be performed efficiently at scale, making it possible to process and deduplicate massive datasets in a reasonable amount of time. + +## Custom graph algorithms + +GraphFrames provides two powerful APIs, `AggregateMessages` and `Pregel`, that allow users to write and run custom algorithms using a scalable and distributed vertex-centric approach on top of Apache Spark. These APIs enable developers to implement complex graph algorithms efficiently by leveraging Spark's distributed computing capabilities. The `AggregateMessages` API facilitates message passing between vertices by aggregating messages from neighboring vertices, making it ideal for implementing iterative graph algorithms. 
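+The following is a minimal sketch of the `aggregateMessages` builder pattern, the same pattern the `TriangleCount` implementation in this patch series uses: for every edge, each endpoint's `age` is sent to the other endpoint and the received values are summed per vertex. It assumes an active SparkSession (for example a `spark-shell`) and the built-in `friends` example graph from `org.graphframes.examples`.
+
+```scala
+import org.apache.spark.sql.functions.sum
+import org.graphframes.examples
+import org.graphframes.lib.AggregateMessages
+
+val g = examples.Graphs.friends // small built-in example graph
+val AM = AggregateMessages
+
+// Send the destination's age to the source and the source's age to the destination,
+// then aggregate the messages received by each vertex.
+val agesOfNeighbours = g.aggregateMessages
+  .sendToSrc(AM.dst("age"))
+  .sendToDst(AM.src("age"))
+  .agg(sum(AM.msg).alias("summedAges"))
+agesOfNeighbours.show()
+```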
Meanwhile, the `Pregel` API offers a more traditional vertex-centric programming model, allowing users to define custom computations that run in iterations until convergence. Together, these APIs provide the flexibility needed to build sophisticated graph analytics solutions that can handle large-scale data processing requirements while maintaining the performance and reliability of the Apache Spark ecosystem. + # Downloading Get GraphFrames from the [Maven Central](https://central.sonatype.com/namespace/io.graphframes). GraphFrames depends on Apache Spark, which is available for download from the [Apache Spark website](http://spark.apache.org). -GraphFrames should be compatible with any platform that runs Spark. Refer to the [Apache Spark documentation](http://spark.apache.org/docs/latest) for more information. +GraphFrames should be compatible with any platform that runs open-source Spark. Refer to the [Apache Spark documentation](http://spark.apache.org/docs/latest) for more information. + +**WARNING:** *Some vendors maintain their own internal forks of Apache Spark that may not be fully compatible with the OSS version. While the GraphFrames project tries to rely only on public and stable APIs of Apache Spark, some incompatibility is still possible. Feel free to open an issue if you face problems in modified Spark environments such as the Databricks Platform.* GraphFrames is compatible with Spark 3.4+. However, later versions of Spark include major improvements to DataFrames, so GraphFrames may be more efficient when running on more recent Spark versions. diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Graph.scala index b57b20f91..5eb5abb73 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Graph.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx -import org.apache.spark.graphframes.graphx.impl._ +import org.apache.spark.graphframes.graphx.impl.* import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/VertexRDD.scala index 6a0ad3481..263051555 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/VertexRDD.scala @@ -17,12 +17,12 @@ package org.apache.spark.graphframes.graphx -import org.apache.spark._ +import org.apache.spark.* import org.apache.spark.graphframes.graphx.impl.RoutingTablePartition import org.apache.spark.graphframes.graphx.impl.ShippableVertexPartition import org.apache.spark.graphframes.graphx.impl.VertexAttributeBlock import org.apache.spark.graphframes.graphx.impl.VertexRDDImpl -import org.apache.spark.rdd._ +import org.apache.spark.rdd.* import org.apache.spark.storage.StorageLevel import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartition.scala index 34ede5f35..113ba9e0a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartition.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.impl -import 
org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.BitSet diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionBuilder.scala index d7f1d1cef..10c3a8c9c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.PrimitiveVector import org.apache.spark.util.collection.SortDataFormat diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgeRDDImpl.scala index 214d2645b..b9a8f04e7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/EdgeRDDImpl.scala @@ -20,7 +20,7 @@ package org.apache.spark.graphframes.graphx.impl import org.apache.spark.HashPartitioner import org.apache.spark.OneToOneDependency import org.apache.spark.Partitioner -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala index 856cd112b..64b588c24 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.graphframes.graphx.impl import org.apache.spark.HashPartitioner -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ReplicatedVertexView.scala index bb237543b..635d5aebd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ReplicatedVertexView.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.rdd.RDD import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/RoutingTablePartition.scala index ac79dee73..cf480ee34 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/RoutingTablePartition.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.impl -import 
org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.BitSet import org.apache.spark.util.collection.PrimitiveVector diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ShippableVertexPartition.scala index a9f0fa5c7..5754d33d7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/ShippableVertexPartition.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.BitSet import org.apache.spark.util.collection.PrimitiveVector diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartition.scala index 80b022c4c..c37539a2e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartition.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.util.collection.BitSet import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBase.scala index 83c5e48b1..3eae350f7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBase.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.BitSet diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBaseOps.scala index 1c9927cb8..ab95986bf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionBaseOps.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.internal.Logging import org.apache.spark.util.collection.BitSet diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexRDDImpl.scala index db188e47a..e13872448 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexRDDImpl.scala +++ 
b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/VertexRDDImpl.scala @@ -17,9 +17,9 @@ package org.apache.spark.graphframes.graphx.impl -import org.apache.spark._ -import org.apache.spark.graphframes.graphx._ -import org.apache.spark.rdd._ +import org.apache.spark.* +import org.apache.spark.graphframes.graphx.* +import org.apache.spark.rdd.* import org.apache.spark.storage.StorageLevel import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponents.scala index 21f40a197..40698153b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponents.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.graphframes.graphx.lib -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import scala.annotation.nowarn import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala index 351a8a068..c807aeb10 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagation.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.graphframes.graphx.lib -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import scala.annotation.nowarn import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/PageRank.scala index 2b7815e37..38dee1d0d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/PageRank.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphframes.graphx.lib -import breeze.linalg.{Vector => BV} -import org.apache.spark.graphframes.graphx._ +import breeze.linalg.Vector as BV +import org.apache.spark.graphframes.graphx.* import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.linalg.Vectors diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlus.scala index e7a4979ad..5ea04010f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlus.scala @@ -17,9 +17,9 @@ package org.apache.spark.graphframes.graphx.lib -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.ml.linalg.BLAS -import org.apache.spark.rdd._ +import org.apache.spark.rdd.* import scala.util.Random diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ShortestPaths.scala index 58d61be9a..411fa4f10 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/ShortestPaths.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.lib -import 
org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import scala.annotation.nowarn import scala.collection.Map diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponents.scala index 227392c9b..665fa17fe 100755 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponents.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.lib -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import scala.annotation.nowarn import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/TriangleCount.scala index 5ff98bfe6..8db3d5ff5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/lib/TriangleCount.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphframes.graphx.lib -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import scala.annotation.nowarn import scala.reflect.ClassTag diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/GraphGenerators.scala index 1c1f1a4ad..e1f104c21 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/GraphGenerators.scala @@ -17,15 +17,15 @@ package org.apache.spark.graphframes.graphx.util -import org.apache.spark._ -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.* +import org.apache.spark.graphframes.graphx.* import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import scala.annotation.tailrec import scala.collection.mutable import scala.reflect.ClassTag -import scala.util._ +import scala.util.* /** A collection of graph generating functions. */ object GraphGenerators extends Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 588e667fa..e842b1ae1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -18,7 +18,7 @@ package org.apache.spark.graphframes.graphx.util.collection import org.apache.spark.util.collection.OpenHashSet -import scala.reflect._ +import scala.reflect.* /** * A fast hash map implementation for primitive, non-null keys. 
This hash map supports insertions diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphOpsSuite.scala index 70c08e47c..9a367f98a 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphOpsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.graphframes.graphx import org.apache.spark.SparkContext -import org.apache.spark.graphframes.graphx.Graph._ +import org.apache.spark.graphframes.graphx.Graph.* import org.scalatest.funsuite.AnyFunSuite class GraphOpsSuite extends AnyFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphSuite.scala index eb550a1f7..42cfd5fff 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/GraphSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.graphframes.graphx import org.apache.spark.SparkContext -import org.apache.spark.graphframes.graphx.Graph._ -import org.apache.spark.graphframes.graphx.PartitionStrategy._ -import org.apache.spark.rdd._ +import org.apache.spark.graphframes.graphx.Graph.* +import org.apache.spark.graphframes.graphx.PartitionStrategy.* +import org.apache.spark.rdd.* import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.scalatest.funsuite.AnyFunSuite diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionSuite.scala index 04da8e25c..a5aca2221 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/EdgePartitionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.graphframes.graphx.impl import org.apache.spark.SparkConf -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.scalatest.funsuite.AnyFunSuite diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionSuite.scala index 4b3a8fd73..d0e5eced3 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/impl/VertexPartitionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.graphframes.graphx.impl import org.apache.spark.SparkConf -import org.apache.spark.graphframes.graphx._ +import org.apache.spark.graphframes.graphx.* import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.scalatest.funsuite.AnyFunSuite diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponentsSuite.scala index ce567ba14..a2b5d2a33 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ConnectedComponentsSuite.scala @@ -17,10 +17,10 @@ 
package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext -import org.apache.spark.graphframes.graphx._ import org.apache.spark.graphframes.graphx.util.GraphGenerators -import org.apache.spark.rdd._ +import org.apache.spark.rdd.* import org.scalatest.funsuite.AnyFunSuite class ConnectedComponentsSuite extends AnyFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagationSuite.scala index 84f276443..fac8ecfc6 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/LabelPropagationSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext -import org.apache.spark.graphframes.graphx._ import org.scalatest.funsuite.AnyFunSuite class LabelPropagationSuite extends AnyFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/PageRankSuite.scala index 116bc1003..7e00241a9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/PageRankSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext -import org.apache.spark.graphframes.graphx._ import org.apache.spark.graphframes.graphx.util.GraphGenerators import org.scalatest.funsuite.AnyFunSuite diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlusSuite.scala index e0573edc6..58e20b381 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/SVDPlusPlusSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext -import org.apache.spark.graphframes.graphx._ import org.scalatest.funsuite.AnyFunSuite class SVDPlusPlusSuite extends AnyFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ShortestPathsSuite.scala index 5a1090142..c03d909d2 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/ShortestPathsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext -import org.apache.spark.graphframes.graphx._ import org.scalatest.funsuite.AnyFunSuite class ShortestPathsSuite extends AnyFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponentsSuite.scala 
b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponentsSuite.scala index 12530e216..7a5b3b6f7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext -import org.apache.spark.graphframes.graphx._ import org.scalatest.funsuite.AnyFunSuite class StronglyConnectedComponentsSuite extends AnyFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/TriangleCountSuite.scala index 5de15363e..2e6c2de2b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphframes/graphx/lib/TriangleCountSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.graphframes.graphx.lib +import org.apache.spark.graphframes.graphx.* import org.apache.spark.graphframes.graphx.LocalSparkContext import org.apache.spark.graphframes.graphx.PartitionStrategy.RandomVertexCut -import org.apache.spark.graphframes.graphx._ import org.scalatest.funsuite.AnyFunSuite class TriangleCountSuite extends AnyFunSuite with LocalSparkContext { From 5695170464066b3fe95c71a02ebb5c8d8fa7cdb2 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 26 Sep 2025 15:40:47 +0200 Subject: [PATCH 03/17] more docs --- docs/src/01-about/01-index.md | 55 ++++++++++++++++++------- docs/src/04-user-guide/05-traversals.md | 4 +- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/docs/src/01-about/01-index.md b/docs/src/01-about/01-index.md index 8bb083657..663cda7c7 100644 --- a/docs/src/01-about/01-index.md +++ b/docs/src/01-about/01-index.md @@ -1,15 +1,42 @@ # About -GraphFrames is a package for Apache Spark which provides DataFrame-based Graphs. It provides high-level APIs in Scala, Java, and Python. It aims to provide both the functionality of GraphX and extended functionality taking advantage of Spark DataFrames. This extended functionality includes motif finding, DataFrame-based serialization, and highly expressive graph queries. +GraphFrames is a package for Apache Spark which provides DataFrame-based Graphs. It provides high-level APIs in Scala, Java, and Python. It aims to provide both the functionality of GraphX and extended functionality taking advantage of Spark DataFrames. This extended functionality includes motif finding, DataFrame-based serialization, and highly expressive graph queries. # What are GraphFrames? -GraphFrames represent graphs: vertices (e.g., users) and edges (e.g., relationships between users). If you are familiar with [GraphX](http://spark.apache.org/docs/latest/graphx-programming-guide.html), then GraphFrames will be easy to learn. The key difference is that GraphFrames are based upon [Spark DataFrames](http://spark.apache.org/docs/latest/sql-programming-guide.html), rather than [RDDs](http://spark.apache.org/docs/latest/programming-guide.html#resilient-distributed-datasets-rdds). +GraphFrames represent graphs: vertices (e.g., users) and edges (e.g., relationships between users) in the form of Apache Spark DataFrame objects. 
On top of this, GraphFrames provides not only basic APIs like `filterVertices` or `outDegrees`, but also a set of powerful APIs for graph algorithms and complex graph processing. -GraphFrames also provide powerful tools for running queries and standard graph algorithms. With GraphFrames, you can easily search for patterns within graphs, find important vertices, and more. Refer to the [User Guide](/04-user-guide/01-creating-graphframes.md) for a full list of queries and algorithms. +## GraphFrames vs GraphX + +GraphFrames provides most of the algorithm and routines in two ways: + +- Native DataFrame based implementation; +- Wrapper over GraphX implementation. + +**NOTE:** GraphX is deprecated in the upstream Apache Spark and is not maintained anymore. GraphFrames project come with it's own fork of GraphX: `org.apache.spark.graphframes.graphx`. While we are trying do not make any breaking changes in GraphFrames' GraphX, it is still considered as a part of the internal API. The best way to use it is via GraphFrame-GraphX conversion utils, instead of directly manipulate GraphX structures. + +### Graph Representation + +- GraphX represents graphs by the pair of `RDD`: `VertexRDD` and `EdgeRDD`. +- GraphFrames represent graphs by the pair of `DataFrame`: `vertices` and `edges`. + +While `RDD` may provide slightly more flexible API and, in theory, processing of RDDs may be faster, they requires much more memory to process them. For example, `VertexRDD[Unit]` that contains de-facto only `Long` vertex IDs will require much more memory to store and process compared to the `DataFrame` of vertices with a single `Long` column. The reason is serialization of `RDD` are done by serializing the underlying JVM objects, but serialization of data in `DataFrame` rely on the `Thungsten` with it's own serialization format. On bechmarks, memory overhead of serializing Java objects may be up to five times, while the compute overhead of creating JVM objects from thungsten format is less than 10-15%. + +### Optimizations + +- GraphX rely on it's own partitioning strategy and building and maintaining partitions index. +- GraphFrames rely on the Apache Spark Catalyst optimizer and Adaptive Query Execution. + +In most of the cases that include real-world complex tranformations, especially on really big data, Catalyst + AQE will provide better results compared to manual index of partitions. + +### If DataFrames are better, why GraphFrames still provides conversion methods? + +Our [benhmarks](03-benchmarks.md) shows that on small and medium graphs GraphX may be better choice. With GraphX users can sacrifice memory consumption if favor of better running time without query optimization overhead. That may be suitable, for example, for Spark Structured Streaming scenarios. # Use-cases of GraphFrames +Refer to the [User Guide](/04-user-guide/01-creating-graphframes.md) for a full list of queries and algorithms. + ## Ranking in search systems `PageRank` is a fundamental algorithm originally developed by Google for ranking web pages in search results. It works by measuring the importance of nodes in a graph based on the link structure, where links from highly-ranked pages contribute more to the rank of target pages. This principle can be extended to ranking documents in search systems, where documents are treated as nodes and hyperlinks or semantic relationships as edges. 
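For illustration only, a minimal Scala sketch of this use case might look like the following (the `documents`/`links` data and the column names are made up for the example; only `GraphFrame` and the `pageRank` builder come from the GraphFrames API):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.graphframes.GraphFrame

val spark = SparkSession.builder().appName("DocumentRankingSketch").getOrCreate()

// Hypothetical document graph: documents are vertices, links between them are edges.
val documents = spark.createDataFrame(Seq(
  (1L, "intro.md"), (2L, "user-guide.md"), (3L, "api.md")
)).toDF("id", "title")
val links = spark.createDataFrame(Seq(
  (1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L)
)).toDF("src", "dst")

// Rank documents by link structure with PageRank.
val graph = GraphFrame(documents, links)
val ranks = graph.pageRank.resetProbability(0.15).maxIter(10).run()

// Vertices of the result carry a "pagerank" column; higher means more important.
ranks.vertices.orderBy(col("pagerank").desc).show()
```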
@@ -60,7 +87,7 @@ Get GraphFrames from the [Maven Central](https://central.sonatype.com/namespace/
 
 GraphFrames should be compatible with any platform that runs the open-source Spark. Refer to the [Apache Spark documentation](http://spark.apache.org/docs/latest) for more information.
 
-**WARNING:** *Some vendors are maintain their own internal forks of the Apache Spark that may be not fully compatible with an OSS version. While GraphFrames project is trying to rely only on public and stable APIs of the Apache Spark, some incompatibility is still possible. Fell free to open an issue in case you are facing problems in modified Spark environments like Databricks Platform.*
+**WARNING:** Some vendors maintain their own internal forks of Apache Spark that may not be fully compatible with the OSS version. While the GraphFrames project tries to rely only on public and stable APIs of Apache Spark, some incompatibility is still possible. Feel free to open an issue if you face problems in modified Spark environments like the Databricks Platform.
 
 GraphFrames is compatible with Spark 3.4+. However, later versions of Spark include major improvements to DataFrames, so GraphFrames may be more efficient when running on more recent Spark versions.
 
@@ -74,20 +101,20 @@ See the [Apache Spark User Guide](http://spark.apache.org/docs/latest/) for more
 
 **User Guides:**
 
-* [Quick Start](/02-quick-start/02-quick-start.md): a quick introduction to the GraphFrames API; start here!
-* [GraphFrames User Guide](/04-user-guide/01-creating-graphframes.md): detailed overview of GraphFrames
+- [Quick Start](/02-quick-start/02-quick-start.md): a quick introduction to the GraphFrames API; start here!
+- [GraphFrames User Guide](/04-user-guide/01-creating-graphframes.md): detailed overview of GraphFrames
   in all supported languages (Scala, Java, Python)
-* [Motif Finding Tutorial](/03-tutorials/02-motif-tutorial.md): learn to perform pattern recognition with GraphFrames using a technique called network motif finding over the knowledge graph for the `stackexchange.com` subdomain [data dump](https://archive.org/details/stackexchange)
-* [GraphFrames Configurations](/04-user-guide/13-configurations.md): detailed information about GraphFrames configurations, their descriptions, and usage examples
+- [Motif Finding Tutorial](/03-tutorials/02-motif-tutorial.md): learn to perform pattern recognition with GraphFrames using a technique called network motif finding over the knowledge graph for the `stackexchange.com` subdomain [data dump](https://archive.org/details/stackexchange)
+- [GraphFrames Configurations](/04-user-guide/13-configurations.md): detailed information about GraphFrames configurations, their descriptions, and usage examples
 
 **Community Forums:**
 
-* [GraphFrames Mailing List](https://groups.google.com/g/graphframes/): ask questions about GraphFrames here
-* [#graphframes Discord Channel on GraphGeeks](https://discord.com/channels/1162999022819225631/1326257052368113674)
+- [GraphFrames Mailing List](https://groups.google.com/g/graphframes/): ask questions about GraphFrames here
+- [#graphframes Discord Channel on GraphGeeks](https://discord.com/channels/1162999022819225631/1326257052368113674)
 
 **External Resources:**
 
-* [Apache Spark Homepage](http://spark.apache.org)
-* [Apache Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK)
-* [Apache Spark Mailing Lists](http://spark.apache.org/mailing-lists.html)
-* [GraphFrames on Stack 
Overflow](https://stackoverflow.com/questions/tagged/graphframes) +- [Apache Spark Homepage](http://spark.apache.org) +- [Apache Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) +- [Apache Spark Mailing Lists](http://spark.apache.org/mailing-lists.html) +- [GraphFrames on Stack Overflow](https://stackoverflow.com/questions/tagged/graphframes) diff --git a/docs/src/04-user-guide/05-traversals.md b/docs/src/04-user-guide/05-traversals.md index 79fd19f50..5f0cd1d8c 100644 --- a/docs/src/04-user-guide/05-traversals.md +++ b/docs/src/04-user-guide/05-traversals.md @@ -222,5 +222,5 @@ res.show(false) // +----+--------------+ ``` -**WARNING:** This algorithm returns all the cycles, and users should handle deduplication of [1, 2, 1] and [2, 1, 2] ( -that is the same cycle) \ No newline at end of file +**WARNING:** This algorithm returns all the cycles, and users should handle deduplication of \[1, 2, 1\] and \[2, 1, 2\] ( +that is the same cycle)! From eab72fd95017310523b443e409bb445197484e1c Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 27 Sep 2025 07:18:56 +0200 Subject: [PATCH 04/17] update formatting and fix java8 --- .../examples/TriangleCountExample.java | 4 ++-- .../apache/spark/sql/graphframes/SparkShims.scala | 6 +++--- .../graphframes/examples/BeliefPropagation.scala | 15 +++++++-------- .../lib/ParallelPersonalizedPageRank.scala | 4 ++-- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/graphframes/examples/TriangleCountExample.java b/core/src/main/java/org/graphframes/examples/TriangleCountExample.java index ff8fc2089..e0e2a88d3 100644 --- a/core/src/main/java/org/graphframes/examples/TriangleCountExample.java +++ b/core/src/main/java/org/graphframes/examples/TriangleCountExample.java @@ -76,7 +76,7 @@ public static void main(String[] args) { .persist(StorageLevel.MEMORY_AND_DISK_SER()); System.out.println("Vertices loaded: " + vertices.count()); - var start = System.currentTimeMillis(); + long start = System.currentTimeMillis(); GraphFrame graph = GraphFrame.apply(vertices, edges); TriangleCount counter = graph.triangleCount(); Dataset triangles = counter.run(); @@ -84,7 +84,7 @@ public static void main(String[] args) { triangles.show(20, false); long triangleCount = triangles.select(functions.sum("count")).first().getLong(0); System.out.println("Found triangles: " + triangleCount); - var end = System.currentTimeMillis(); + long end = System.currentTimeMillis(); System.out.println("Total running time in seconds: " + (end - start) / 1000.0); } } diff --git a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala index 4ea8e072b..fea8e553a 100644 --- a/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala +++ b/core/src/main/scala-spark-4/org/apache/spark/sql/graphframes/SparkShims.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.classic.ClassicConversions._ +import org.apache.spark.sql.classic.ClassicConversions.* +import org.apache.spark.sql.classic.DataFrame as ClassicDataFrame import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.classic.ExpressionUtils -import org.apache.spark.sql.classic.{DataFrame => ClassicDataFrame} -import 
org.apache.spark.sql.classic.{SparkSession => ClassicSparkSession} +import org.apache.spark.sql.classic.SparkSession as ClassicSparkSession object SparkShims { diff --git a/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala b/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala index 13bdb4e7c..5979a7a0a 100644 --- a/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala +++ b/core/src/main/scala/org/graphframes/examples/BeliefPropagation.scala @@ -17,9 +17,7 @@ package org.graphframes.examples -import org.apache.spark.graphframes.graphx.Edge as GXEdge -import org.apache.spark.graphframes.graphx.Graph -import org.apache.spark.graphframes.graphx.VertexRDD +import org.apache.spark.graphframes.graphx import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession @@ -146,22 +144,22 @@ object BeliefPropagation { val vColsMap = colorG.vertexColumnMap val eColsMap = colorG.edgeColumnMap // Convert vertex attributes to nice case classes. - val gx1: Graph[VertexAttr, Row] = gx0.mapVertices { case (_, attr) => + val gx1: graphx.Graph[VertexAttr, Row] = gx0.mapVertices { case (_, attr) => // Initialize belief at 0.0 VertexAttr(attr.getDouble(vColsMap("a")), 0.0, attr.getInt(vColsMap("color"))) } // Convert edge attributes to nice case classes. - val extractEdgeAttr: (GXEdge[Row] => EdgeAttr) = { e => + val extractEdgeAttr: (graphx.Edge[Row] => EdgeAttr) = { e => EdgeAttr(e.attr.getDouble(eColsMap("b"))) } - var gx: Graph[VertexAttr, EdgeAttr] = gx1.mapEdges(extractEdgeAttr) + var gx: graphx.Graph[VertexAttr, EdgeAttr] = gx1.mapEdges(extractEdgeAttr) // Run BP for numIter iterations. for (_ <- Range(0, numIter)) { // For each color, have that color receive messages from neighbors. for (color <- Range(0, numColors)) { // Send messages to vertices of the current color. - val msgs: VertexRDD[Double] = gx.aggregateMessages( + val msgs: graphx.VertexRDD[Double] = gx.aggregateMessages( ctx => // Can send to source or destination since edges are treated as undirected. if (ctx.dstAttr.color == color) { @@ -188,7 +186,8 @@ object BeliefPropagation { } // Convert back to GraphFrame with a new column "belief" for vertices DataFrame. 
- val gxFinal: Graph[Double, Unit] = gx.mapVertices((_, attr) => attr.belief).mapEdges(_ => ()) + val gxFinal: graphx.Graph[Double, Unit] = + gx.mapVertices((_, attr) => attr.belief).mapEdges(_ => ()) GraphFrame.fromGraphX(colorG, gxFinal, vertexNames = Seq("belief")) } diff --git a/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index 072d5b93d..32b76c47e 100644 --- a/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/core/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -17,7 +17,7 @@ package org.graphframes.lib -import org.apache.spark.graphframes.graphx.lib as graphxlib +import org.apache.spark.graphframes.graphx import org.graphframes.GraphFrame import org.graphframes.Logging import org.graphframes.WithMaxIter @@ -113,7 +113,7 @@ private object ParallelPersonalizedPageRank { resetProb: Double, sourceIds: Array[Any]): GraphFrame = { val longSrcIds = sourceIds.map(GraphXConversions.integralId(graph, _)) - val gx = graphxlib.PageRank.runParallelPersonalizedPageRank( + val gx = graphx.lib.PageRank.runParallelPersonalizedPageRank( graph.cachedTopologyGraphX, maxIter, resetProb, From 2033da624426ff4bb0bbc6dc6a3fa468f8d3b2e1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 28 Sep 2025 16:58:55 +0200 Subject: [PATCH 05/17] WIP --- buf.yaml | 2 +- connect/src/main/protobuf/graphframes.proto | 87 ++++-- .../graphframes/GraphFramesConnectUtils.scala | 277 +++++++++++++++--- docs/src/01-about/01-index.md | 16 +- .../{lib => classic}/aggregate_messages.py | 0 python/graphframes/classic/graphframe.py | 93 ++++-- python/graphframes/{lib => classic}/pregel.py | 0 python/graphframes/classic/utils.py | 16 + .../graphframes/connect/graphframe_client.py | 244 +++++++++++++-- .../connect/proto/graphframes_pb2.py | 84 +++--- .../connect/proto/graphframes_pb2.pyi | 169 ++++++++++- python/graphframes/connect/utils.py | 30 +- python/graphframes/graphframe.py | 84 +++++- python/graphframes/lib/__init__.py | 4 - 14 files changed, 907 insertions(+), 199 deletions(-) rename python/graphframes/{lib => classic}/aggregate_messages.py (100%) rename python/graphframes/{lib => classic}/pregel.py (100%) create mode 100644 python/graphframes/classic/utils.py delete mode 100644 python/graphframes/lib/__init__.py diff --git a/buf.yaml b/buf.yaml index e0cdbd729..95ef5c72f 100644 --- a/buf.yaml +++ b/buf.yaml @@ -1,3 +1,3 @@ version: v2 modules: - - path: graphframes-connect/src/main/protobuf \ No newline at end of file + - path: connect/src/main/protobuf diff --git a/connect/src/main/protobuf/graphframes.proto b/connect/src/main/protobuf/graphframes.proto index 5b9647ee6..c22223941 100644 --- a/connect/src/main/protobuf/graphframes.proto +++ b/connect/src/main/protobuf/graphframes.proto @@ -5,33 +5,55 @@ package org.graphframes.connect.proto; option java_multiple_files = true; option java_package = "org.graphframes.connect.proto"; option java_generate_equals_and_hash = true; -option optimize_for=SPEED; +option optimize_for = SPEED; +// GraphFramesAPI represents the core message type for GraphFrames operations +// containing graph data and the specific graph algorithm to be executed message GraphFramesAPI { + // Serialized vertex DataFrame containing node information bytes vertices = 1; + // Serialized edge DataFrame containing relationship information bytes edges = 2; + // Specifies which graph algorithm operation to perform oneof method { AggregateMessages 
aggregate_messages = 3; BFS bfs = 4; ConnectedComponents connected_components = 5; DropIsolatedVertices drop_isolated_vertices = 6; - FilterEdges filter_edges = 7; - FilterVertices filter_vertices = 8; - Find find = 9; - LabelPropagation label_propagation = 10; - PageRank page_rank = 11; - ParallelPersonalizedPageRank parallel_personalized_page_rank = 12; - PowerIterationClustering power_iteration_clustering = 13; - Pregel pregel = 14; - ShortestPaths shortest_paths = 15; - StronglyConnectedComponents strongly_connected_components = 16; - SVDPlusPlus svd_plus_plus = 17; - TriangleCount triangle_count = 18; - Triplets triplets = 19; + DetectingCycles detecting_cycles = 7; + FilterEdges filter_edges = 8; + FilterVertices filter_vertices = 9; + Find find = 10; + LabelPropagation label_propagation = 11; + PageRank page_rank = 12; + ParallelPersonalizedPageRank parallel_personalized_page_rank = 13; + PowerIterationClustering power_iteration_clustering = 14; + Pregel pregel = 15; + ShortestPaths shortest_paths = 16; + StronglyConnectedComponents strongly_connected_components = 17; + SVDPlusPlus svd_plus_plus = 18; + TriangleCount triangle_count = 19; + Triplets triplets = 20; } } +// Mapping follows PySpark Storage Levels! +// (not Scala-Spark Storage Levels) +message StorageLevel { + oneof storage_level { + bool disk_only = 1; + bool disk_only_2 = 2; + bool disk_only_3 = 3; + bool memory_and_disk = 4; + bool memory_and_disk_2 = 5; + bool memory_and_disk_deser = 6; + bool memory_only = 7; + bool memory_only_2 = 8; + } +} + +// String expression or serialized column message ColumnOrExpression { oneof col_or_expr { bytes col = 1; @@ -39,6 +61,7 @@ message ColumnOrExpression { } } +// Connect supports only string or long-like IDs message StringOrLongID { oneof id { int64 long_id = 1; @@ -47,9 +70,10 @@ message StringOrLongID { } message AggregateMessages { - ColumnOrExpression agg_col = 1; - optional ColumnOrExpression send_to_src = 2; - optional ColumnOrExpression send_to_dst = 3; + repeated ColumnOrExpression agg_col = 1; + repeated ColumnOrExpression send_to_src = 2; + repeated ColumnOrExpression send_to_dst = 3; + optional StorageLevel storage_level = 4; } message BFS { @@ -64,6 +88,15 @@ message ConnectedComponents { int32 checkpoint_interval = 2; int32 broadcast_threshold = 3; bool use_labels_as_components = 4; + bool use_local_checkpoints = 5; + int32 max_iter = 6; + optional StorageLevel storage_level = 7; +} + +message DetectingCycles { + bool use_local_checkpoints = 1; + int32 checkpoint_interval = 2; + optional StorageLevel storage_level = 3; } message DropIsolatedVertices {} @@ -81,7 +114,11 @@ message Find { } message LabelPropagation { - int32 max_iter = 1; + string algorithm = 1; + int32 max_iter = 2; + bool use_local_checkpoints = 3; + int32 checkpoint_interval = 4; + optional StorageLevel storage_level = 5; } message PageRank { @@ -113,10 +150,20 @@ message Pregel { ColumnOrExpression additional_col_initial = 7; ColumnOrExpression additional_col_upd = 8; optional bool early_stopping = 9; + bool use_local_checkpoints = 10; + optional StorageLevel storage_level = 11; + optional bool stop_if_all_non_active = 12; + optional ColumnOrExpression initial_active_expr = 13; + optional ColumnOrExpression update_active_expr = 14; + optional bool skip_messages_from_non_active = 15; } message ShortestPaths { repeated StringOrLongID landmarks = 1; + string algorithm = 2; + bool use_local_checkpoints = 3; + int32 checkpoint_interval = 4; + optional StorageLevel storage_level = 5; } message 
StronglyConnectedComponents { @@ -134,6 +181,8 @@ message SVDPlusPlus { double gamma7 = 8; } -message TriangleCount {} +message TriangleCount { + optional StorageLevel storage_level = 1; +} message Triplets {} diff --git a/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index 5ce4d967d..ea63ae281 100644 --- a/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -1,5 +1,5 @@ -// Because Dataset.ofRows is private[sql] we are forced to use spark package -// Same about Column helper object. +// Because Dataset.ofRows is private[sql], we are forced to use spark package; +// Same about a Column helper object. package org.apache.spark.sql.graphframes import com.google.protobuf.ByteString @@ -8,43 +8,103 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.lit +import org.apache.spark.storage.StorageLevel import org.graphframes.GraphFrame import org.graphframes.GraphFramesUnreachableException -import org.graphframes.connect.proto.ColumnOrExpression -import org.graphframes.connect.proto.ColumnOrExpression.ColOrExprCase -import org.graphframes.connect.proto.GraphFramesAPI -import org.graphframes.connect.proto.GraphFramesAPI.MethodCase -import org.graphframes.connect.proto.StringOrLongID -import org.graphframes.connect.proto.StringOrLongID.IdCase +import org.graphframes.connect.proto import scala.jdk.CollectionConverters.* +/** + * Utility object providing helper methods for parsing and transforming data structures related to + * GraphFrames and enabling interaction between GraphFrames and Spark Connect APIs. + * + * The methods in this object are intended for internal use within the GraphFrames module + * (`private[graphframes]`) to support parsing, transformation, and execution of GraphFrame API + * calls based on serialized or protocol buffer inputs. + */ object GraphFramesConnectUtils { + + /** + * Parses a protobuf StorageLevel object and converts it to a corresponding Spark StorageLevel. + * + * @param pbStorageLevel + * the protobuf StorageLevel object to be parsed + * @return + * the corresponding Spark StorageLevel + */ + private[graphframes] def parseStorageLevel(pbStorageLevel: proto.StorageLevel): StorageLevel = { + pbStorageLevel.getStorageLevelCase match { + case proto.StorageLevel.StorageLevelCase.DISK_ONLY => StorageLevel.DISK_ONLY + case proto.StorageLevel.StorageLevelCase.DISK_ONLY_2 => StorageLevel.DISK_ONLY_2 + case proto.StorageLevel.StorageLevelCase.DISK_ONLY_3 => StorageLevel.DISK_ONLY_3 + case proto.StorageLevel.StorageLevelCase.MEMORY_AND_DISK => StorageLevel.MEMORY_AND_DISK_SER + case proto.StorageLevel.StorageLevelCase.MEMORY_AND_DISK_2 => + StorageLevel.MEMORY_AND_DISK_SER_2 + case proto.StorageLevel.StorageLevelCase.MEMORY_AND_DISK_DESER => + StorageLevel.MEMORY_AND_DISK + case proto.StorageLevel.StorageLevelCase.MEMORY_ONLY => StorageLevel.MEMORY_ONLY_SER + case proto.StorageLevel.StorageLevelCase.MEMORY_ONLY_2 => StorageLevel.MEMORY_ONLY_SER_2 + case _ => throw new GraphFramesUnreachableException() + } + } + + /** + * Parses a proto.ColumnOrExpression object and converts it to a corresponding Spark Column. 
+ * + * @param colOrExpr + * the proto.ColumnOrExpression object to be parsed + * @param planner + * the SparkConnectPlanner used for transforming expressions + * @return + * the resulting Spark Column + */ private[graphframes] def parseColumnOrExpression( - colOrExpr: ColumnOrExpression, + colOrExpr: proto.ColumnOrExpression, planner: SparkConnectPlanner): Column = { colOrExpr.getColOrExprCase match { - case ColOrExprCase.COL => + case proto.ColumnOrExpression.ColOrExprCase.COL => SparkShims.createColumn( planner.transformExpression( org.apache.spark.connect.proto.Expression.parseFrom(colOrExpr.getCol.toByteArray))) - case ColOrExprCase.EXPR => expr(colOrExpr.getExpr) + case proto.ColumnOrExpression.ColOrExprCase.EXPR => expr(colOrExpr.getExpr) case _ => - throw new RuntimeException( - "INTERNAL ERROR: unreachable case in function parseColumnOrExpression") + throw new GraphFramesUnreachableException() } } - private[graphframes] def parseLongOrStringID(id: StringOrLongID): Any = { + /** + * Converts a proto.StringOrLongID object to its corresponding Scala representation. + * + * @param id + * the proto.StringOrLongID object to be parsed + * @return + * the Scala representation of the ID (String or Long) + * @throws GraphFramesUnreachableException + * if the ID case is unrecognized + */ + private[graphframes] def parseLongOrStringID(id: proto.StringOrLongID): Any = { id.getIdCase match { - case IdCase.LONG_ID => id.getLongId - case IdCase.STRING_ID => id.getStringId + case proto.StringOrLongID.IdCase.LONG_ID => id.getLongId + case proto.StringOrLongID.IdCase.STRING_ID => id.getStringId case _ => - throw new RuntimeException( - "INTERNAL ERROR: unreachable case in function parseLongOrStringID") + throw new GraphFramesUnreachableException() } } + /** + * Parses the given serialized data to construct a Spark DataFrame. + * + * @param data + * the serialized representation of the DataFrame in ByteString format. Must not be empty. + * @param planner + * the SparkConnectPlanner instance used to transform the serialized plan into a Spark + * DataFrame. + * @return + * the resulting Spark DataFrame created from the provided data. + * @throws IllegalArgumentException + * if the given data is empty. + */ private[graphframes] def parseDataFrame( data: ByteString, planner: SparkConnectPlanner): DataFrame = { @@ -58,8 +118,18 @@ object GraphFramesConnectUtils { org.apache.spark.connect.proto.Plan.parseFrom(data.toByteArray).getRoot)) } + /** + * Extracts a GraphFrame from the provided GraphFramesAPI message using the specified planner. + * + * @param apiMessage + * the GraphFramesAPI protobuf message containing serialized vertices and edges + * @param planner + * the SparkConnectPlanner used for parsing and constructing DataFrames + * @return + * the constructed GraphFrame consisting of vertices and edges + */ private[graphframes] def extractGraphFrame( - apiMessage: GraphFramesAPI, + apiMessage: proto.GraphFramesAPI, planner: SparkConnectPlanner): GraphFrame = { val vertices = parseDataFrame(apiMessage.getVertices, planner) val edges = parseDataFrame(apiMessage.getEdges, planner) @@ -67,27 +137,62 @@ object GraphFramesConnectUtils { GraphFrame(vertices, edges) } + /** + * Parses a GraphFrames API call from a protocol buffer message and executes the corresponding + * operation on the GraphFrame object obtained from the planner. + * + * @param apiMessage + * The protocol buffer message that defines the GraphFrames API operation and its parameters. 
+ * @param planner + * A SparkConnectPlanner instance used to translate protocol buffer expressions into Spark SQL + * objects (e.g., DataFrame, Column). + * @return + * A DataFrame that represents the result of the executed GraphFrame operation. + */ private[graphframes] def parseAPICall( - apiMessage: GraphFramesAPI, + apiMessage: proto.GraphFramesAPI, planner: SparkConnectPlanner): DataFrame = { val graphFrame = extractGraphFrame(apiMessage, planner) apiMessage.getMethodCase match { - case MethodCase.AGGREGATE_MESSAGES => { + case proto.GraphFramesAPI.MethodCase.AGGREGATE_MESSAGES => { val aggregateMessagesProto = apiMessage.getAggregateMessages var aggregateMessages = graphFrame.aggregateMessages - if (aggregateMessagesProto.hasSendToDst) { + if (aggregateMessagesProto.getSendToDstList.size() == 1) { aggregateMessages = aggregateMessages.sendToDst( - parseColumnOrExpression(aggregateMessagesProto.getSendToDst, planner)) + parseColumnOrExpression(aggregateMessagesProto.getSendToDst(0), planner)) + } else if (aggregateMessagesProto.getSendToDstList.size() > 1) { + val sendToDst = aggregateMessagesProto.getSendToDstList.asScala.map( + parseColumnOrExpression(_, planner)) + aggregateMessages = + aggregateMessages.sendToDst(sendToDst.head, sendToDst.tail.toSeq: _*) } - if (aggregateMessagesProto.hasSendToSrc) { + if (aggregateMessagesProto.getSendToSrcList.size() == 1) { aggregateMessages = aggregateMessages.sendToSrc( - parseColumnOrExpression(aggregateMessagesProto.getSendToSrc, planner)) + parseColumnOrExpression(aggregateMessagesProto.getSendToSrc(0), planner)) + } else if (aggregateMessagesProto.getSendToSrcList.size() > 1) { + val sendToSrc = aggregateMessagesProto.getSendToSrcList.asScala.map( + parseColumnOrExpression(_, planner)) + aggregateMessages = + aggregateMessages.sendToSrc(sendToSrc.head, sendToSrc.tail.toSeq: _*) } - aggregateMessages.agg(parseColumnOrExpression(aggregateMessagesProto.getAggCol, planner)) + if (aggregateMessagesProto.hasStorageLevel) { + aggregateMessages = aggregateMessages.setIntermediateStorageLevel( + parseStorageLevel(aggregateMessagesProto.getStorageLevel)) + } + + val aggCols = + aggregateMessagesProto.getAggColList.asScala.map(parseColumnOrExpression(_, planner)) + + // At least one agg col is required, and it is easier to check it on the client side + if (aggCols.size == 1) { + aggregateMessages.agg(aggCols.head) + } else { + aggregateMessages.agg(aggCols.head, aggCols.tail.toSeq: _*) + } } - case MethodCase.BFS => { + case proto.GraphFramesAPI.MethodCase.BFS => { val bfsProto = apiMessage.getBfs graphFrame.bfs .toExpr(parseColumnOrExpression(bfsProto.getToExpr, planner)) @@ -96,34 +201,65 @@ object GraphFramesConnectUtils { .maxPathLength(bfsProto.getMaxPathLength) .run() } - case MethodCase.CONNECTED_COMPONENTS => { + case proto.GraphFramesAPI.MethodCase.CONNECTED_COMPONENTS => { val cc = apiMessage.getConnectedComponents - graphFrame.connectedComponents + val ccBuilder = graphFrame.connectedComponents + .maxIter(cc.getMaxIter) .setAlgorithm(cc.getAlgorithm) .setCheckpointInterval(cc.getCheckpointInterval) .setBroadcastThreshold(cc.getBroadcastThreshold) + .setUseLocalCheckpoints(cc.getUseLocalCheckpoints) .setUseLabelsAsComponents(cc.getUseLabelsAsComponents) - .run() + + if (cc.hasStorageLevel) { + ccBuilder.setIntermediateStorageLevel(parseStorageLevel(cc.getStorageLevel)).run() + } else { + ccBuilder.run() + } + } + + case proto.GraphFramesAPI.MethodCase.DETECTING_CYCLES => { + val dc = apiMessage.getDetectingCycles + val dcBuilder = 
graphFrame.detectingCycles + .setCheckpointInterval(dc.getCheckpointInterval) + .setUseLocalCheckpoints(dc.getUseLocalCheckpoints) + if (dc.hasStorageLevel) { + dcBuilder.setIntermediateStorageLevel(parseStorageLevel(dc.getStorageLevel)).run() + } else { + dcBuilder.run() + } } - case MethodCase.DROP_ISOLATED_VERTICES => { + + case proto.GraphFramesAPI.MethodCase.DROP_ISOLATED_VERTICES => { graphFrame.dropIsolatedVertices().vertices } - case MethodCase.FILTER_EDGES => { + case proto.GraphFramesAPI.MethodCase.FILTER_EDGES => { val condition = parseColumnOrExpression(apiMessage.getFilterEdges.getCondition, planner) graphFrame.filterEdges(condition).edges } - case MethodCase.FILTER_VERTICES => { + case proto.GraphFramesAPI.MethodCase.FILTER_VERTICES => { val condition = parseColumnOrExpression(apiMessage.getFilterVertices.getCondition, planner) graphFrame.filterVertices(condition).vertices } - case MethodCase.FIND => { + case proto.GraphFramesAPI.MethodCase.FIND => { graphFrame.find(apiMessage.getFind.getPattern) } - case MethodCase.LABEL_PROPAGATION => { - graphFrame.labelPropagation.maxIter(apiMessage.getLabelPropagation.getMaxIter).run() + case proto.GraphFramesAPI.MethodCase.LABEL_PROPAGATION => { + val lp = apiMessage.getLabelPropagation + val lpBuilder = graphFrame.labelPropagation + .maxIter(lp.getMaxIter) + .setAlgorithm(lp.getAlgorithm) + .setCheckpointInterval(lp.getCheckpointInterval) + .setUseLocalCheckpoints(lp.getUseLocalCheckpoints) + + if (lp.hasStorageLevel) { + lpBuilder.setIntermediateStorageLevel(parseStorageLevel(lp.getStorageLevel)).run() + } else { + lpBuilder.run() + } } - case MethodCase.PAGE_RANK => { + case proto.GraphFramesAPI.MethodCase.PAGE_RANK => { val pageRankProto = apiMessage.getPageRank val pageRank = graphFrame.pageRank.resetProbability(pageRankProto.getResetProbability) @@ -142,7 +278,7 @@ object GraphFramesConnectUtils { // see comments in the Python API pageRank.run().vertices } - case MethodCase.PARALLEL_PERSONALIZED_PAGE_RANK => { + case proto.GraphFramesAPI.MethodCase.PARALLEL_PERSONALIZED_PAGE_RANK => { val pPageRankProto = apiMessage.getParallelPersonalizedPageRank val sourceIds = pPageRankProto.getSourceIdsList.asScala .map(parseLongOrStringID) @@ -155,7 +291,7 @@ object GraphFramesConnectUtils { .run() .vertices // See comment in the PageRank } - case MethodCase.POWER_ITERATION_CLUSTERING => { + case proto.GraphFramesAPI.MethodCase.POWER_ITERATION_CLUSTERING => { val pic = apiMessage.getPowerIterationClustering if (pic.hasWeightCol) { graphFrame.powerIterationClustering(pic.getK, pic.getMaxIter, Some(pic.getWeightCol)) @@ -163,7 +299,7 @@ object GraphFramesConnectUtils { graphFrame.powerIterationClustering(pic.getK, pic.getMaxIter, None) } } - case MethodCase.PREGEL => { + case proto.GraphFramesAPI.MethodCase.PREGEL => { val pregelProto = apiMessage.getPregel var pregel = graphFrame.pregel .aggMsgs(parseColumnOrExpression(pregelProto.getAggMsgs, planner)) @@ -173,6 +309,31 @@ object GraphFramesConnectUtils { parseColumnOrExpression(pregelProto.getAdditionalColInitial, planner), parseColumnOrExpression(pregelProto.getAdditionalColUpd, planner)) .setMaxIter(pregelProto.getMaxIter) + .setUseLocalCheckpoints(pregelProto.getUseLocalCheckpoints) + + if (pregelProto.hasStorageLevel) { + pregel = + pregel.setIntermediateStorageLevel(parseStorageLevel(pregelProto.getStorageLevel)) + } + + if (pregelProto.hasInitialActiveExpr) { + // We are not checking here that all the attrs are present; + // Check should be done on the client side. 
+ pregel = pregel + .setInitialActiveVertexExpression( + parseColumnOrExpression(pregelProto.getInitialActiveExpr, planner)) + .setUpdateActiveVertexExpression( + parseColumnOrExpression(pregelProto.getUpdateActiveExpr, planner)) + + if (pregelProto.hasSkipMessagesFromNonActive) { + pregel = pregel.setSkipMessagesFromNonActiveVertices( + pregelProto.getSkipMessagesFromNonActive) + } + + if (pregelProto.hasStopIfAllNonActive) { + pregel = pregel.setStopIfAllNonActiveVertices(pregelProto.getStopIfAllNonActive) + } + } pregel = pregelProto.getSendMsgToSrcList.asScala .map(parseColumnOrExpression(_, planner)) @@ -187,18 +348,29 @@ object GraphFramesConnectUtils { pregel.run() } - case MethodCase.SHORTEST_PATHS => { - graphFrame.shortestPaths + case proto.GraphFramesAPI.MethodCase.SHORTEST_PATHS => { + val spBuilder = graphFrame.shortestPaths .landmarks( apiMessage.getShortestPaths.getLandmarksList.asScala.map(parseLongOrStringID).toSeq) - .run() + .setAlgorithm(apiMessage.getShortestPaths.getAlgorithm) + .setCheckpointInterval(apiMessage.getShortestPaths.getCheckpointInterval) + .setUseLocalCheckpoints(apiMessage.getShortestPaths.getUseLocalCheckpoints) + + if (apiMessage.getShortestPaths.hasStorageLevel) { + spBuilder + .setIntermediateStorageLevel( + parseStorageLevel(apiMessage.getShortestPaths.getStorageLevel)) + .run() + } else { + spBuilder.run() + } } - case MethodCase.STRONGLY_CONNECTED_COMPONENTS => { + case proto.GraphFramesAPI.MethodCase.STRONGLY_CONNECTED_COMPONENTS => { graphFrame.stronglyConnectedComponents .maxIter(apiMessage.getStronglyConnectedComponents.getMaxIter) .run() } - case MethodCase.SVD_PLUS_PLUS => { + case proto.GraphFramesAPI.MethodCase.SVD_PLUS_PLUS => { val svdPPProto = apiMessage.getSvdPlusPlus val svd = graphFrame.svdPlusPlus .maxIter(svdPPProto.getMaxIter) @@ -212,10 +384,19 @@ object GraphFramesConnectUtils { val svdResult = svd.run() svdResult.withColumn("loss", lit(svd.loss)) } - case MethodCase.TRIANGLE_COUNT => { - graphFrame.triangleCount.run() + case proto.GraphFramesAPI.MethodCase.TRIANGLE_COUNT => { + val trCounter = graphFrame.triangleCount + + if (apiMessage.getTriangleCount.hasStorageLevel) { + trCounter + .setIntermediateStorageLevel( + parseStorageLevel(apiMessage.getTriangleCount.getStorageLevel)) + .run() + } else { + trCounter.run() + } } - case MethodCase.TRIPLETS => { + case proto.GraphFramesAPI.MethodCase.TRIPLETS => { graphFrame.triplets } case _ => throw new GraphFramesUnreachableException() // Unreachable diff --git a/docs/src/01-about/01-index.md b/docs/src/01-about/01-index.md index 663cda7c7..5073018c3 100644 --- a/docs/src/01-about/01-index.md +++ b/docs/src/01-about/01-index.md @@ -10,28 +10,28 @@ GraphFrames represent graphs: vertices (e.g., users) and edges (e.g., relationsh GraphFrames provides most of the algorithm and routines in two ways: -- Native DataFrame based implementation; +- Native DataFrame-based implementation; - Wrapper over GraphX implementation. -**NOTE:** GraphX is deprecated in the upstream Apache Spark and is not maintained anymore. GraphFrames project come with it's own fork of GraphX: `org.apache.spark.graphframes.graphx`. While we are trying do not make any breaking changes in GraphFrames' GraphX, it is still considered as a part of the internal API. The best way to use it is via GraphFrame-GraphX conversion utils, instead of directly manipulate GraphX structures. +**NOTE:** GraphX is deprecated in the upstream Apache Spark and is not maintained anymore. 
The GraphFrames project comes with its own fork of GraphX: `org.apache.spark.graphframes.graphx`. While we try not to make any breaking changes in GraphFrames' GraphX, it is still considered part of the internal API. The best way to use it is via the GraphFrame-GraphX conversion utils, instead of manipulating GraphX structures directly.
 
 ### Graph Representation
 
 - GraphX represents graphs by the pair of `RDD`: `VertexRDD` and `EdgeRDD`.
 - GraphFrames represent graphs by the pair of `DataFrame`: `vertices` and `edges`.
 
-While `RDD` may provide slightly more flexible API and, in theory, processing of RDDs may be faster, they requires much more memory to process them. For example, `VertexRDD[Unit]` that contains de-facto only `Long` vertex IDs will require much more memory to store and process compared to the `DataFrame` of vertices with a single `Long` column. The reason is serialization of `RDD` are done by serializing the underlying JVM objects, but serialization of data in `DataFrame` rely on the `Thungsten` with it's own serialization format. On bechmarks, memory overhead of serializing Java objects may be up to five times, while the compute overhead of creating JVM objects from thungsten format is less than 10-15%.
+While `RDD` may provide a slightly more flexible API and, in theory, processing of RDDs may be faster, they require much more memory to process. For example, a `VertexRDD[Unit]` that de facto contains only `Long` vertex IDs will require much more memory to store and process compared to a `DataFrame` of vertices with a single `Long` column. The reason is that serialization of an `RDD` is done by serializing the underlying JVM objects, while serialization of data in a `DataFrame` relies on Tungsten with its own serialization format. In benchmarks, the memory overhead of serializing Java objects may be up to five times, while the compute overhead of creating JVM objects from the Tungsten format is less than 10–15%.
 
 ### Optimizations
 
-- GraphX rely on it's own partitioning strategy and building and maintaining partitions index.
+- GraphX relies on its own partitioning strategy and on building and maintaining a partition index.
 - GraphFrames rely on the Apache Spark Catalyst optimizer and Adaptive Query Execution.
 
-In most of the cases that include real-world complex tranformations, especially on really big data, Catalyst + AQE will provide better results compared to manual index of partitions.
+In most cases that involve real-world complex transformations, especially on huge data, Catalyst + AQE will provide better results than a manually maintained partition index.
 
-### If DataFrames are better, why GraphFrames still provides conversion methods?
+### If DataFrames are better, why do GraphFrames still provide conversion methods?
 
-Our [benhmarks](03-benchmarks.md) shows that on small and medium graphs GraphX may be better choice. With GraphX users can sacrifice memory consumption if favor of better running time without query optimization overhead. That may be suitable, for example, for Spark Structured Streaming scenarios.
+Our [benchmarks](03-benchmarks.md) show that on small and medium graphs GraphX may be a better choice. With GraphX, users can sacrifice memory consumption in favor of better running time without query optimization overhead. That may be suitable, for example, for Spark Structured Streaming scenarios.
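To make the round trip concrete, here is a minimal sketch based on the conversion utilities (it assumes an existing `GraphFrame` named `graph`; the derived per-vertex `Double` and the column name `value` are placeholders, not part of the API):

```scala
import org.apache.spark.graphframes.graphx.Graph
import org.apache.spark.sql.Row
import org.graphframes.GraphFrame

// Convert the DataFrame-based graph to the bundled GraphX fork.
val gx: Graph[Row, Row] = graph.toGraphX

// Run any GraphX-only processing; here we only derive a dummy Double per vertex
// and drop edge attributes to keep the example small.
val gxResult: Graph[Double, Unit] =
  gx.mapVertices((_, _) => 1.0).mapEdges(_ => ())

// Convert back, exposing the GraphX vertex attribute as a new vertex column "value".
val result: GraphFrame = GraphFrame.fromGraphX(graph, gxResult, vertexNames = Seq("value"))
result.vertices.show()
```

The same convert-run-convert pattern is what the GraphX-backed algorithm wrappers in GraphFrames use internally.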
# Use-cases of GraphFrames @@ -39,7 +39,7 @@ Refer to the [User Guide](/04-user-guide/01-creating-graphframes.md) for a full ## Ranking in search systems -`PageRank` is a fundamental algorithm originally developed by Google for ranking web pages in search results. It works by measuring the importance of nodes in a graph based on the link structure, where links from highly-ranked pages contribute more to the rank of target pages. This principle can be extended to ranking documents in search systems, where documents are treated as nodes and hyperlinks or semantic relationships as edges. +`PageRank` is a fundamental algorithm originally developed by Google for ranking web pages in search results. It works by measuring the importance of nodes in a graph based on the link structure, where links from highly ranked pages contribute more to the rank of target pages. This principle can be extended to ranking documents in search systems, where documents are treated as nodes and hyperlinks or semantic relationships as edges. GraphFrames provides a fully distributed Spark-based implementation of the `PageRank` algorithm, enabling efficient computation of document rankings at scale. This implementation leverages the power of Apache Spark's distributed computing model, allowing organizations to analyze large-scale document networks without sacrificing performance. diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/classic/aggregate_messages.py similarity index 100% rename from python/graphframes/lib/aggregate_messages.py rename to python/graphframes/classic/aggregate_messages.py diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index af10a9cab..11901d010 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations import sys from typing import Any, Optional, Union @@ -21,11 +22,12 @@ if sys.version > "3": basestring = str -from pyspark import SparkContext -from pyspark.sql import Column, DataFrame, SparkSession +from pyspark.sql.classic.column import Column, _to_seq +from pyspark.sql.classic.dataframe import DataFrame, SparkContext, SparkSession from pyspark.storagelevel import StorageLevel -from graphframes.lib import Pregel +from graphframes.classic.pregel import Pregel +from graphframes.classic.utils import storage_level_to_jvm def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": @@ -41,6 +43,10 @@ def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": def _java_api(jsc: SparkContext) -> Any: javaClassName = "org.graphframes.GraphFramePythonAPI" + if jsc._jvm is None: + raise RuntimeError( + "Spark Driver's JVM is dead or did not start properly. See driver logs for details." 
+ ) return ( jsc._jvm.Thread.currentThread() .getContextClassLoader() @@ -159,7 +165,11 @@ def dropIsolatedVertices(self) -> "GraphFrame": return _from_java_gf(jdf, self._spark) def bfs( - self, fromExpr: str, toExpr: str, edgeFilter: Optional[str] = None, maxPathLength: int = 10 + self, + fromExpr: str, + toExpr: str, + edgeFilter: Optional[str] = None, + maxPathLength: int = 10, ) -> DataFrame: builder = ( self._jvm_graph.bfs().fromExpr(fromExpr).toExpr(toExpr).maxPathLength(maxPathLength) @@ -171,35 +181,66 @@ def bfs( def aggregateMessages( self, - aggCol: Union[Column, str], - sendToSrc: Union[Column, str, None] = None, - sendToDst: Union[Column, str, None] = None, + aggCol: list[Column | str], + sendToSrc: list[Column | str], + sendToDst: list[Column | str], + intermediate_storage_level: StorageLevel, ) -> DataFrame: - # Check that either sendToSrc, sendToDst, or both are provided - if sendToSrc is None and sendToDst is None: - raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") builder = self._jvm_graph.aggregateMessages() - if sendToSrc is not None: - if isinstance(sendToSrc, Column): - builder.sendToSrc(sendToSrc._jc) - elif isinstance(sendToSrc, basestring): - builder.sendToSrc(sendToSrc) + builder = builder.setIntermediateStorageLevel( + storage_level_to_jvm(intermediate_storage_level, self._spark) + ) + if len(sendToSrc) == 1: + if isinstance(sendToSrc[0], Column): + builder.sendToSrc(sendToSrc[0]._jc) + elif isinstance(sendToSrc[0], basestring): + builder.sendToSrc(sendToSrc[0]) else: raise TypeError("Provide message either as `Column` or `str`") - if sendToDst is not None: - if isinstance(sendToDst, Column): - builder.sendToDst(sendToDst._jc) - elif isinstance(sendToDst, basestring): - builder.sendToDst(sendToDst) + elif len(sendToSrc) > 1: + if all(isinstance(x, Column) for x in sendToSrc): + send2src = [x._jc for x in sendToSrc] + builder.sendToSrc(send2src[0], _to_seq(self._sc, send2src[1:])) + elif all(isinstance(x, basestring) for x in sendToSrc): + builder.sendToSrc(sendToSrc[0], _to_seq(self._sc, sendToSrc[1:])) + else: + raise TypeError( + "Multiple messages should all be `Column` or `str`, not a mix of them." + ) + + if len(sendToDst) == 1: + if isinstance(sendToDst[0], Column): + builder.sendToDst(sendToDst[0]._jc) + elif isinstance(sendToDst[0], basestring): + builder.sendToDst(sendToDst[0]) else: raise TypeError("Provide message either as `Column` or `str`") - if isinstance(aggCol, Column): - jdf = builder.agg(aggCol._jc) - else: - jdf = builder.agg(aggCol) - return DataFrame(jdf, self._spark) + elif len(sendToDst) > 1: + if all(isinstance(x, Column) for x in sendToDst): + send2dst = [x._jc for x in sendToDst] + builder.sendToDst(send2dst[0], _to_seq(self._sc, send2dst[1:])) + elif all(isinstance(x, basestring) for x in sendToDst): + builder.sendToDst(sendToDst[0], _to_seq(self._sc, sendToDst[1:])) + else: + raise TypeError( + "Multiple messages should all be `Column` or `str`, not a mix of them." 
+ ) - # Standard algorithms + if len(aggCol) == 1: + if isinstance(aggCol[0], Column): + jdf = builder.aggCol(aggCol[0]._jc) + elif isinstance(aggCol[0], basestring): + jdf = builder.aggCol(aggCol[0]) + elif len(aggCol) > 1: + if all(isinstance(x, Column) for x in aggCol): + jdf = builder.aggCol(aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol])) + elif all(isinstance(x, basestring) for x in aggCol): + jdf = builder.aggCol(aggCol[0], _to_seq(self._sc, aggCol)) + else: + raise TypeError( + "Multiple agg cols should all be `Column` or `str`, not a mix of them." + ) + return DataFrame(jdf, self._spark) def connectedComponents( self, diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/classic/pregel.py similarity index 100% rename from python/graphframes/lib/pregel.py rename to python/graphframes/classic/pregel.py diff --git a/python/graphframes/classic/utils.py b/python/graphframes/classic/utils.py new file mode 100644 index 000000000..863ca0c5f --- /dev/null +++ b/python/graphframes/classic/utils.py @@ -0,0 +1,16 @@ +from py4j.java_gateway import JavaObject +from pyspark.storagelevel import StorageLevel +from typing_extensions import TYPE_CHECKING + +if TYPE_CHECKING: + from pyspark.sql.classic.dataframe import SparkSession + + +def storage_level_to_jvm(storage_level: StorageLevel, spark: SparkSession) -> JavaObject: + return spark._jvm.org.apache.spark.storage.StorageLevel.apply( + storage_level.useDisk, + storage_level.useMemory, + storage_level.useOffHeap, + storage_level.deserialized, + storage_level.replication, + ) diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index edb960728..a9d7671d8 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -15,7 +15,12 @@ from typing_extensions import Self from .proto import graphframes_pb2 as pb -from .utils import dataframe_to_proto, make_column_or_expr, make_str_or_long_id +from .utils import ( + dataframe_to_proto, + make_column_or_expr, + make_str_or_long_id, + storage_level_to_proto, +) # Spark 4 removed the withPlan method in favor of the constructor, but Spark 3 @@ -31,6 +36,31 @@ def _dataframe_from_plan(plan: LogicalPlan, session: SparkSession) -> DataFrame: class PregelConnect: + """Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. + + See `Malewicz et al., Pregel: a system for large-scale graph processing `_ + for a detailed description of the Pregel algorithm. + + You can construct a Pregel instance using either this constructor or :attr:`graphframes.GraphFrame.pregel`, + then use builder pattern to describe the operations, and then call :func:`run` to start a run. + It returns a DataFrame of vertices from the last iteration. + + When a run starts, it expands the vertices DataFrame using column expressions defined by :func:`withVertexColumn`. + Those additional vertex properties can be changed during Pregel iterations. + In each Pregel iteration, there are three phases: + - Given each edge triplet, generate messages and specify target vertices to send, + described by :func:`sendMsgToDst` and :func:`sendMsgToSrc`. + - Aggregate messages by target vertex IDs, described by :func:`aggMsgs`. + - Update additional vertex properties based on aggregated messages and states from previous iteration, + described by :func:`withVertexColumn`. + + Please find what columns you can reference at each phase in the method API docs. 
+ + You can control the number of iterations by :func:`setMaxIter` and check API docs for advanced controls. + + :param graph: a :class:`graphframes.GraphFrame` object holding a graph with vertices and edges stored as DataFrames. + """ # noqa: E501 + def __init__(self, graph: "GraphFrameConnect") -> None: self.graph = graph self._max_iter = 10 @@ -42,16 +72,42 @@ def __init__(self, graph: "GraphFrameConnect") -> None: self._send_msg_to_dst = [] self._agg_msg = None self._early_stopping = False + self._use_local_checkpoints = False + self._storage_level = StorageLevel.MEMORY_AND_DISK_DESER + self._initial_active_expr: Column | str | None = None + self._update_active_expr: Column | str | None = None + self._stop_if_all_non_active = False + self._skip_messages_from_non_active = False def setMaxIter(self, value: int) -> Self: + """Sets the max number of iterations (default: 2).""" self._max_iter = value return self def setCheckpointInterval(self, value: int) -> Self: + """Sets the number of iterations between two checkpoints (default: 2). + + This is an advanced control to balance query plan optimization and checkpoint data I/O cost. + In most cases, you should keep the default value. + + Checkpoint is disabled if this is set to 0. + """ self._checkpoint_interval = value return self def setEarlyStopping(self, value: bool) -> Self: + """Set should Pregel stop earlier in case of no new messages to send or not. + + Early stopping allows to terminate Pregel before reaching maxIter by checking if there are any non-null messages. + While in some cases it may gain significant performance boost, in other cases it can lead to performance degradation, + because checking if the messages DataFrame is empty or not is an action and requires materialization of the Spark Plan + with some additional computations. + + In the case when the user can assume a good value of maxIter, it is recommended to leave this value to the default "false". + In the case when it is hard to estimate the number of iterations required for convergence, + it is recommended to set this value to "false" to avoid iterating over convergence until reaching maxIter. + When this value is "true", maxIter can be set to a bigger value without risks. + """ # noqa: E501 self._early_stopping = value return self @@ -61,23 +117,138 @@ def withVertexColumn( initialExpr: Column | str, updateAfterAggMsgsExpr: Column | str, ) -> Self: + """Defines an additional vertex column at the start of run and how to update it in each iteration. + + You can call it multiple times to add more than one additional vertex columns. + + :param colName: the name of the additional vertex column. + It cannot be an existing vertex column in the graph. + :param initialExpr: the expression to initialize the additional vertex column. + You can reference all original vertex columns in this expression. + :param updateAfterAggMsgsExpr: the expression to update the additional vertex column after messages aggregation. + You can reference all original vertex columns, additional vertex columns, and the + aggregated message column using :func:`msg`. + If the vertex received no messages, the message column would be null. + """ # noqa: E501 self._col_name = colName self._initial_expr = initialExpr self._update_after_agg_msgs_expr = updateAfterAggMsgsExpr return self def sendMsgToSrc(self, msgExpr: Column | str) -> Self: + """Defines a message to send to the source vertex of each edge triplet. + + You can call it multiple times to send more than one messages. 
+ + See method :func:`sendMsgToDst`. + + :param msgExpr: the expression of the message to send to the source vertex given a (src, edge, dst) triplet. + Source/destination vertex properties and edge properties are nested under columns `src`, `dst`, + and `edge`, respectively. + You can reference them using :func:`src`, :func:`dst`, and :func:`edge`. + Null messages are not included in message aggregation. + """ # noqa: E501 self._send_msg_to_src.append(msgExpr) return self def sendMsgToDst(self, msgExpr: Column | str) -> Self: + """Defines a message to send to the destination vertex of each edge triplet. + + You can call it multiple times to send more than one messages. + + See method :func:`sendMsgToSrc`. + + :param msgExpr: the message expression to send to the destination vertex given a (`src`, `edge`, `dst`) triplet. + Source/destination vertex properties and edge properties are nested under columns `src`, `dst`, + and `edge`, respectively. + You can reference them using :func:`src`, :func:`dst`, and :func:`edge`. + Null messages are not included in message aggregation. + """ # noqa: E501 self._send_msg_to_dst.append(msgExpr) return self def aggMsgs(self, aggExpr: Column) -> Self: + """Defines how messages are aggregated after grouped by target vertex IDs. + + :param aggExpr: the message aggregation expression, such as `sum(Pregel.msg())`. + You can reference the message column by :func:`msg` and the vertex ID by `col("id")`, + while the latter is usually not used. + """ # noqa: E501 self._agg_msg = aggExpr return self + def setStopIfAllNonActiveVertices(self, value: bool) -> Self: + """Set should Pregel stop if all the vertices voted to halt. + + Activity (or vote) is determined based on the activity_col. + See methods :func:`setInitialActiveVertexExpression` and :func:`setUpdateActiveVertexExpression` for details + how to set and update activity_col. + + Be aware that checking of the vote is not free but a Spark Action. In case the + condition is not realistically reachable but set, it will just slow down the algorithm. + + :param value: the boolean value. + """ # noqa: E501 + self._stop_if_all_non_active = value + return self + + def setInitialActiveVertexExpression(self, value: Column | str) -> Self: + """Sets the initial expression for the active vertex column. + + The active vertex column is used to determine if a vertices voting result on each iteration of Pregel. + This expression is evaluated on the initial vertices DataFrame to set the initial state of the activity column. + + :param value: expression to compute the initial active state of vertices. + You can reference all original vertex columns in this expression. + """ # noqa: E501 + self._initial_active_expr = value + return self + + def setUpdateActiveVertexExpression(self, value: Column | str) -> Self: + """Sets the expression to update the active vertex column. + + The active vertex column is used to determine if a vertices voting result on each iteration of Pregel. + This expression is evaluated on the updated vertices DataFrame to set the new state of the activity column. + + :param value: expression to compute the new active state of vertices. + You can reference all original vertex columns and additional vertex columns in this expression. + """ # noqa: E501 + self._update_active_expr = value + return self + + def setSkipMessagesFromNonActiveVertices(self, value: bool) -> Self: + """Set should Pregel skip sending messages from non-active vertices. 
+ + When this option is enabled, messages will not be sent from vertices that are marked as inactive. + This can help optimize performance by avoiding unnecessary message propagation from inactive vertices. + + :param value: boolean value. + """ # noqa: E501 + self._skip_messages_from_non_active = value + return self + + def setUseLocalCheckpoints(self, value: bool) -> Self: + """Set should Pregel use local checkpoints. + + Local checkpoints are faster and do not require configuring a persistent storage. + At the same time, local checkpoints are less reliable and may create a big load on local disks of executors. + + :param value: boolean value. + """ # noqa: E501 + self._use_local_checkpoints = value + return self + + def setIntermediateStorageLevel(self, storage_level: StorageLevel) -> Self: + """Set the intermediate storage level. + On each iteration, Pregel cache results with a requested storage level. + + For very big graphs it is recommended to use DISK_ONLY. + + :param storage_level: storage level to use. + """ # noqa: E501 + self._storage_level = storage_level + return self + def run(self) -> DataFrame: class Pregel(LogicalPlan): def __init__( @@ -91,6 +262,12 @@ def __init__( send2src: list[Column | str], vertex_col_init: Column | str, vertex_col_upd: Column | str, + use_local_checkpoints: bool, + storage_level: StorageLevel, + initial_active_col: Column | str | None, + update_active_col: Column | str | None, + stop_if_all_non_active: bool, + skip_message_from_non_active: bool, vertices: DataFrame, edges: DataFrame, ) -> None: @@ -104,6 +281,12 @@ def __init__( self.send2src = send2src self.vertex_col_init = vertex_col_init self.vertex_col_upd = vertex_col_upd + self.use_local_checkpoints = use_local_checkpoints + self.storage_level = storage_level + self.initial_active_expr = initial_active_col + self.update_active_expr = update_active_col + self.stop_if_all_non_active = stop_if_all_non_active + self.skip_message_from_non_active = skip_message_from_non_active self.vertices = vertices self.edges = edges @@ -122,6 +305,16 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: additional_col_initial=make_column_or_expr(self.vertex_col_init, session), additional_col_upd=make_column_or_expr(self.vertex_col_upd, session), early_stopping=self.early_stopping, + use_local_checkpoints=self.use_local_checkpoints, + storage_level=storage_level_to_proto(self.storage_level), + stop_if_all_non_active=self.stop_if_all_non_active, + skip_messages_from_non_active=self.skip_message_from_non_active, + initial_active_expr=make_column_or_expr(self.initial_active_expr, session) + if self.initial_active_expr is not None + else None, + update_active_expr=make_column_or_expr(self.update_active_expr, session) + if self.update_active_expr is not None + else None, ) pb_message = pb.GraphFramesAPI( vertices=dataframe_to_proto(self.vertices, session), @@ -152,9 +345,15 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: agg_msg=self._agg_msg, send2dst=self._send_msg_to_dst, send2src=self._send_msg_to_src, + early_stopping=self._early_stopping, + use_local_checkpoints=self._use_local_checkpoints, + initial_active_col=self._initial_active_expr, + update_active_col=self._update_active_expr, + stop_if_all_non_active=self._stop_if_all_non_active, + skip_message_from_non_active=self._skip_messages_from_non_active, + storage_level=self._storage_level, vertices=self.graph._vertices, edges=self.graph._edges, - early_stopping=self._early_stopping, ), session=self.graph._spark, ) @@ -432,7 +631,7 
@@ def plan(self, session: SparkConnectClient) -> proto.Relation: return plan if edgeFilter is None: - edgeFilter = F.lit(True) + edgeFilter: Column = F.lit(True) return _dataframe_from_plan( BFS( @@ -448,18 +647,20 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def aggregateMessages( self, - aggCol: Column | str, - sendToSrc: Column | str | None = None, - sendToDst: Column | str | None = None, + aggCol: list[Column | str], + sendToSrc: list[Column | str], + sendToDst: list[Column | str], + intermediate_storage_level: StorageLevel, ) -> DataFrame: class AggregateMessages(LogicalPlan): def __init__( self, v: DataFrame, e: DataFrame, - agg_col: Column | str, - send2src: Column | str | None, - send2dst: Column | str | None, + agg_col: list[Column | str], + send2src: list[Column | str], + send2dst: list[Column | str], + storage_level: StorageLevel, ) -> None: super().__init__(None) self.v = v @@ -467,6 +668,7 @@ def __init__( self.agg_col = agg_col self.send2src = send2src self.send2dst = send2dst + self.storage_level = storage_level def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( @@ -474,17 +676,10 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: ) graphframes_api_call.aggregate_messages.CopyFrom( pb.AggregateMessages( - agg_col=make_column_or_expr(self.agg_col, session), - send_to_src=( - None - if self.send2src is None - else make_column_or_expr(self.send2src, session) - ), - send_to_dst=( - None - if self.send2dst is None - else make_column_or_expr(self.send2dst, session) - ), + agg_col=[make_column_or_expr(x, session) for x in self.agg_col], + send_to_src=[make_column_or_expr(x, session) for x in self.send2src], + send_to_dst=[make_column_or_expr(x, session) for x in self.send2dst], + storage_level=storage_level_to_proto(self.storage_level), ) ) plan = self._create_proto_relation() @@ -495,7 +690,14 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") return _dataframe_from_plan( - AggregateMessages(self._vertices, self._edges, aggCol, sendToSrc, sendToDst), + AggregateMessages( + self._vertices, + self._edges, + aggCol, + sendToSrc, + sendToDst, + intermediate_storage_level, + ), self._spark, ) diff --git a/python/graphframes/connect/proto/graphframes_pb2.py b/python/graphframes/connect/proto/graphframes_pb2.py index 7f11707c6..a34f42a42 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.py +++ b/python/graphframes/connect/proto/graphframes_pb2.py @@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xd6\x0c\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12k\n\x16\x64rop_isolated_vertices\x18\x06 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x07 
\x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\x08 \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\t \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12^\n\x11label_propagation\x18\n \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12\x46\n\tpage_rank\x18\x0b \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0c \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12w\n\x1apower_iteration_clustering\x18\r \x01(\x0b\x32\x37.org.graphframes.connect.proto.PowerIterationClusteringH\x00R\x18powerIterationClustering\x12?\n\x06pregel\x18\x0e \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x0f \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x10 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x11 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x12 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x13 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\xce\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold\x12\x37\n\x18use_labels_as_components\x18\x04 \x01(\x08R\x15useLabelsAsComponents"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 
\x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"v\n\x18PowerIterationClustering\x12\x0c\n\x01k\x18\x01 \x01(\x05R\x01k\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12"\n\nweight_col\x18\x03 \x01(\tH\x00R\tweightCol\x88\x01\x01\x42\r\n\x0b_weight_col"\x8f\x05\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd\x12*\n\x0e\x65\x61rly_stopping\x18\t \x01(\x08H\x00R\rearlyStopping\x88\x01\x01\x42\x11\n\x0f_early_stopping"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"\x0f\n\rTriangleCount"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' + b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xb3\r\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12k\n\x16\x64rop_isolated_vertices\x18\x06 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12[\n\x10\x64\x65tecting_cycles\x18\x07 
\x01(\x0b\x32..org.graphframes.connect.proto.DetectingCyclesH\x00R\x0f\x64\x65tectingCycles\x12O\n\x0c\x66ilter_edges\x18\x08 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\t \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\n \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12^\n\x11label_propagation\x18\x0b \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12\x46\n\tpage_rank\x18\x0c \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\r \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12w\n\x1apower_iteration_clustering\x18\x0e \x01(\x0b\x32\x37.org.graphframes.connect.proto.PowerIterationClusteringH\x00R\x18powerIterationClustering\x12?\n\x06pregel\x18\x0f \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x10 \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x11 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x12 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x13 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x14 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"\xd7\x02\n\x0cStorageLevel\x12\x1d\n\tdisk_only\x18\x01 \x01(\x08H\x00R\x08\x64iskOnly\x12 \n\x0b\x64isk_only_2\x18\x02 \x01(\x08H\x00R\tdiskOnly2\x12 \n\x0b\x64isk_only_3\x18\x03 \x01(\x08H\x00R\tdiskOnly3\x12(\n\x0fmemory_and_disk\x18\x04 \x01(\x08H\x00R\rmemoryAndDisk\x12+\n\x11memory_and_disk_2\x18\x05 \x01(\x08H\x00R\x0ememoryAndDisk2\x12\x33\n\x15memory_and_disk_deser\x18\x06 \x01(\x08H\x00R\x12memoryAndDiskDeser\x12!\n\x0bmemory_only\x18\x07 \x01(\x08H\x00R\nmemoryOnly\x12$\n\rmemory_only_2\x18\x08 \x01(\x08H\x00R\x0bmemoryOnly2B\x0f\n\rstorage_level"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xee\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12Q\n\x0bsend_to_src\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tsendToSrc\x12Q\n\x0bsend_to_dst\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tsendToDst\x12U\n\rstorage_level\x18\x04 \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\x86\x03\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 
\x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold\x12\x37\n\x18use_labels_as_components\x18\x04 \x01(\x08R\x15useLabelsAsComponents\x12\x32\n\x15use_local_checkpoints\x18\x05 \x01(\x08R\x13useLocalCheckpoints\x12\x19\n\x08max_iter\x18\x06 \x01(\x05R\x07maxIter\x12U\n\rstorage_level\x18\x07 \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\xdf\x01\n\x0f\x44\x65tectingCycles\x12\x32\n\x15use_local_checkpoints\x18\x01 \x01(\x08R\x13useLocalCheckpoints\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12U\n\rstorage_level\x18\x03 \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"\x99\x02\n\x10LabelPropagation\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x32\n\x15use_local_checkpoints\x18\x03 \x01(\x08R\x13useLocalCheckpoints\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12U\n\rstorage_level\x18\x05 \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"v\n\x18PowerIterationClustering\x12\x0c\n\x01k\x18\x01 \x01(\x05R\x01k\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12"\n\nweight_col\x18\x03 \x01(\tH\x00R\tweightCol\x88\x01\x01\x42\r\n\x0b_weight_col"\xe6\t\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd\x12*\n\x0e\x65\x61rly_stopping\x18\t \x01(\x08H\x00R\rearlyStopping\x88\x01\x01\x12\x32\n\x15use_local_checkpoints\x18\n 
\x01(\x08R\x13useLocalCheckpoints\x12U\n\rstorage_level\x18\x0b \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x01R\x0cstorageLevel\x88\x01\x01\x12\x37\n\x16stop_if_all_non_active\x18\x0c \x01(\x08H\x02R\x12stopIfAllNonActive\x88\x01\x01\x12\x66\n\x13initial_active_expr\x18\r \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x03R\x11initialActiveExpr\x88\x01\x01\x12\x64\n\x12update_active_expr\x18\x0e \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x04R\x10updateActiveExpr\x88\x01\x01\x12\x45\n\x1dskip_messages_from_non_active\x18\x0f \x01(\x08H\x05R\x19skipMessagesFromNonActive\x88\x01\x01\x42\x11\n\x0f_early_stoppingB\x10\n\x0e_storage_levelB\x19\n\x17_stop_if_all_non_activeB\x16\n\x14_initial_active_exprB\x15\n\x13_update_active_exprB \n\x1e_skip_messages_from_non_active"\xc8\x02\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks\x12\x1c\n\talgorithm\x18\x02 \x01(\tR\talgorithm\x12\x32\n\x15use_local_checkpoints\x18\x03 \x01(\x08R\x13useLocalCheckpoints\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12U\n\rstorage_level\x18\x05 \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"x\n\rTriangleCount\x12U\n\rstorage_level\x18\x01 \x01(\x0b\x32+.org.graphframes.connect.proto.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' ) _globals = globals() @@ -31,43 +31,47 @@ "DESCRIPTOR" ]._serialized_options = b"\n!com.org.graphframes.connect.protoB\020GraphframesProtoH\001P\001\240\001\001\242\002\004OGCP\252\002\035Org.Graphframes.Connect.Proto\312\002\035Org\\Graphframes\\Connect\\Proto\342\002)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\352\002 Org::Graphframes::Connect::Proto" _globals["_GRAPHFRAMESAPI"]._serialized_start = 53 - _globals["_GRAPHFRAMESAPI"]._serialized_end = 1675 - _globals["_COLUMNOREXPRESSION"]._serialized_start = 1677 - _globals["_COLUMNOREXPRESSION"]._serialized_end = 1754 - _globals["_STRINGORLONGID"]._serialized_start = 1756 - _globals["_STRINGORLONGID"]._serialized_end = 1836 - _globals["_AGGREGATEMESSAGES"]._serialized_start = 1839 - _globals["_AGGREGATEMESSAGES"]._serialized_end = 2142 - _globals["_BFS"]._serialized_start = 2145 - _globals["_BFS"]._serialized_end = 2430 - _globals["_CONNECTEDCOMPONENTS"]._serialized_start = 2433 - _globals["_CONNECTEDCOMPONENTS"]._serialized_end = 2639 - _globals["_DROPISOLATEDVERTICES"]._serialized_start = 2641 - _globals["_DROPISOLATEDVERTICES"]._serialized_end = 2663 - _globals["_FILTEREDGES"]._serialized_start = 2665 - _globals["_FILTEREDGES"]._serialized_end = 2759 - 
_globals["_FILTERVERTICES"]._serialized_start = 2761 - _globals["_FILTERVERTICES"]._serialized_end = 2858 - _globals["_FIND"]._serialized_start = 2860 - _globals["_FIND"]._serialized_end = 2892 - _globals["_LABELPROPAGATION"]._serialized_start = 2894 - _globals["_LABELPROPAGATION"]._serialized_end = 2939 - _globals["_PAGERANK"]._serialized_start = 2942 - _globals["_PAGERANK"]._serialized_end = 3168 - _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_start = 3171 - _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_end = 3351 - _globals["_POWERITERATIONCLUSTERING"]._serialized_start = 3353 - _globals["_POWERITERATIONCLUSTERING"]._serialized_end = 3471 - _globals["_PREGEL"]._serialized_start = 3474 - _globals["_PREGEL"]._serialized_end = 4129 - _globals["_SHORTESTPATHS"]._serialized_start = 4131 - _globals["_SHORTESTPATHS"]._serialized_end = 4223 - _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 4225 - _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 4281 - _globals["_SVDPLUSPLUS"]._serialized_start = 4284 - _globals["_SVDPLUSPLUS"]._serialized_end = 4498 - _globals["_TRIANGLECOUNT"]._serialized_start = 4500 - _globals["_TRIANGLECOUNT"]._serialized_end = 4515 - _globals["_TRIPLETS"]._serialized_start = 4517 - _globals["_TRIPLETS"]._serialized_end = 4527 + _globals["_GRAPHFRAMESAPI"]._serialized_end = 1768 + _globals["_STORAGELEVEL"]._serialized_start = 1771 + _globals["_STORAGELEVEL"]._serialized_end = 2114 + _globals["_COLUMNOREXPRESSION"]._serialized_start = 2116 + _globals["_COLUMNOREXPRESSION"]._serialized_end = 2193 + _globals["_STRINGORLONGID"]._serialized_start = 2195 + _globals["_STRINGORLONGID"]._serialized_end = 2275 + _globals["_AGGREGATEMESSAGES"]._serialized_start = 2278 + _globals["_AGGREGATEMESSAGES"]._serialized_end = 2644 + _globals["_BFS"]._serialized_start = 2647 + _globals["_BFS"]._serialized_end = 2932 + _globals["_CONNECTEDCOMPONENTS"]._serialized_start = 2935 + _globals["_CONNECTEDCOMPONENTS"]._serialized_end = 3325 + _globals["_DETECTINGCYCLES"]._serialized_start = 3328 + _globals["_DETECTINGCYCLES"]._serialized_end = 3551 + _globals["_DROPISOLATEDVERTICES"]._serialized_start = 3553 + _globals["_DROPISOLATEDVERTICES"]._serialized_end = 3575 + _globals["_FILTEREDGES"]._serialized_start = 3577 + _globals["_FILTEREDGES"]._serialized_end = 3671 + _globals["_FILTERVERTICES"]._serialized_start = 3673 + _globals["_FILTERVERTICES"]._serialized_end = 3770 + _globals["_FIND"]._serialized_start = 3772 + _globals["_FIND"]._serialized_end = 3804 + _globals["_LABELPROPAGATION"]._serialized_start = 3807 + _globals["_LABELPROPAGATION"]._serialized_end = 4088 + _globals["_PAGERANK"]._serialized_start = 4091 + _globals["_PAGERANK"]._serialized_end = 4317 + _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_start = 4320 + _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_end = 4500 + _globals["_POWERITERATIONCLUSTERING"]._serialized_start = 4502 + _globals["_POWERITERATIONCLUSTERING"]._serialized_end = 4620 + _globals["_PREGEL"]._serialized_start = 4623 + _globals["_PREGEL"]._serialized_end = 5877 + _globals["_SHORTESTPATHS"]._serialized_start = 5880 + _globals["_SHORTESTPATHS"]._serialized_end = 6208 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 6210 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 6266 + _globals["_SVDPLUSPLUS"]._serialized_start = 6269 + _globals["_SVDPLUSPLUS"]._serialized_end = 6483 + _globals["_TRIANGLECOUNT"]._serialized_start = 6485 + _globals["_TRIANGLECOUNT"]._serialized_end = 6605 + 
_globals["_TRIPLETS"]._serialized_start = 6607 + _globals["_TRIPLETS"]._serialized_end = 6617 # @@protoc_insertion_point(module_scope) diff --git a/python/graphframes/connect/proto/graphframes_pb2.pyi b/python/graphframes/connect/proto/graphframes_pb2.pyi index 26054f34b..ffe59932d 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.pyi +++ b/python/graphframes/connect/proto/graphframes_pb2.pyi @@ -18,6 +18,7 @@ class GraphFramesAPI(_message.Message): "bfs", "connected_components", "drop_isolated_vertices", + "detecting_cycles", "filter_edges", "filter_vertices", "find", @@ -38,6 +39,7 @@ class GraphFramesAPI(_message.Message): BFS_FIELD_NUMBER: _ClassVar[int] CONNECTED_COMPONENTS_FIELD_NUMBER: _ClassVar[int] DROP_ISOLATED_VERTICES_FIELD_NUMBER: _ClassVar[int] + DETECTING_CYCLES_FIELD_NUMBER: _ClassVar[int] FILTER_EDGES_FIELD_NUMBER: _ClassVar[int] FILTER_VERTICES_FIELD_NUMBER: _ClassVar[int] FIND_FIELD_NUMBER: _ClassVar[int] @@ -57,6 +59,7 @@ class GraphFramesAPI(_message.Message): bfs: BFS connected_components: ConnectedComponents drop_isolated_vertices: DropIsolatedVertices + detecting_cycles: DetectingCycles filter_edges: FilterEdges filter_vertices: FilterVertices find: Find @@ -78,6 +81,7 @@ class GraphFramesAPI(_message.Message): bfs: _Optional[_Union[BFS, _Mapping]] = ..., connected_components: _Optional[_Union[ConnectedComponents, _Mapping]] = ..., drop_isolated_vertices: _Optional[_Union[DropIsolatedVertices, _Mapping]] = ..., + detecting_cycles: _Optional[_Union[DetectingCycles, _Mapping]] = ..., filter_edges: _Optional[_Union[FilterEdges, _Mapping]] = ..., filter_vertices: _Optional[_Union[FilterVertices, _Mapping]] = ..., find: _Optional[_Union[Find, _Mapping]] = ..., @@ -97,6 +101,45 @@ class GraphFramesAPI(_message.Message): triplets: _Optional[_Union[Triplets, _Mapping]] = ..., ) -> None: ... +class StorageLevel(_message.Message): + __slots__ = ( + "disk_only", + "disk_only_2", + "disk_only_3", + "memory_and_disk", + "memory_and_disk_2", + "memory_and_disk_deser", + "memory_only", + "memory_only_2", + ) + DISK_ONLY_FIELD_NUMBER: _ClassVar[int] + DISK_ONLY_2_FIELD_NUMBER: _ClassVar[int] + DISK_ONLY_3_FIELD_NUMBER: _ClassVar[int] + MEMORY_AND_DISK_FIELD_NUMBER: _ClassVar[int] + MEMORY_AND_DISK_2_FIELD_NUMBER: _ClassVar[int] + MEMORY_AND_DISK_DESER_FIELD_NUMBER: _ClassVar[int] + MEMORY_ONLY_FIELD_NUMBER: _ClassVar[int] + MEMORY_ONLY_2_FIELD_NUMBER: _ClassVar[int] + disk_only: bool + disk_only_2: bool + disk_only_3: bool + memory_and_disk: bool + memory_and_disk_2: bool + memory_and_disk_deser: bool + memory_only: bool + memory_only_2: bool + def __init__( + self, + disk_only: _Optional[bool] = ..., + disk_only_2: _Optional[bool] = ..., + disk_only_3: _Optional[bool] = ..., + memory_and_disk: _Optional[bool] = ..., + memory_and_disk_2: _Optional[bool] = ..., + memory_and_disk_deser: _Optional[bool] = ..., + memory_only: _Optional[bool] = ..., + memory_only_2: _Optional[bool] = ..., + ) -> None: ... + class ColumnOrExpression(_message.Message): __slots__ = ("col", "expr") COL_FIELD_NUMBER: _ClassVar[int] @@ -114,18 +157,21 @@ class StringOrLongID(_message.Message): def __init__(self, long_id: _Optional[int] = ..., string_id: _Optional[str] = ...) -> None: ... 
class AggregateMessages(_message.Message): - __slots__ = ("agg_col", "send_to_src", "send_to_dst") + __slots__ = ("agg_col", "send_to_src", "send_to_dst", "storage_level") AGG_COL_FIELD_NUMBER: _ClassVar[int] SEND_TO_SRC_FIELD_NUMBER: _ClassVar[int] SEND_TO_DST_FIELD_NUMBER: _ClassVar[int] - agg_col: ColumnOrExpression - send_to_src: ColumnOrExpression - send_to_dst: ColumnOrExpression + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] + agg_col: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] + send_to_src: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] + send_to_dst: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] + storage_level: StorageLevel def __init__( self, - agg_col: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., - send_to_src: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., - send_to_dst: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + agg_col: _Optional[_Iterable[_Union[ColumnOrExpression, _Mapping]]] = ..., + send_to_src: _Optional[_Iterable[_Union[ColumnOrExpression, _Mapping]]] = ..., + send_to_dst: _Optional[_Iterable[_Union[ColumnOrExpression, _Mapping]]] = ..., + storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ..., ) -> None: ... class BFS(_message.Message): @@ -152,21 +198,48 @@ class ConnectedComponents(_message.Message): "checkpoint_interval", "broadcast_threshold", "use_labels_as_components", + "use_local_checkpoints", + "max_iter", + "storage_level", ) ALGORITHM_FIELD_NUMBER: _ClassVar[int] CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int] BROADCAST_THRESHOLD_FIELD_NUMBER: _ClassVar[int] USE_LABELS_AS_COMPONENTS_FIELD_NUMBER: _ClassVar[int] + USE_LOCAL_CHECKPOINTS_FIELD_NUMBER: _ClassVar[int] + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] algorithm: str checkpoint_interval: int broadcast_threshold: int use_labels_as_components: bool + use_local_checkpoints: bool + max_iter: int + storage_level: StorageLevel def __init__( self, algorithm: _Optional[str] = ..., checkpoint_interval: _Optional[int] = ..., broadcast_threshold: _Optional[int] = ..., - use_labels_as_components: bool = ..., + use_labels_as_components: _Optional[bool] = ..., + use_local_checkpoints: _Optional[bool] = ..., + max_iter: _Optional[int] = ..., + storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ..., + ) -> None: ... + +class DetectingCycles(_message.Message): + __slots__ = ("use_local_checkpoints", "checkpoint_interval", "storage_level") + USE_LOCAL_CHECKPOINTS_FIELD_NUMBER: _ClassVar[int] + CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int] + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] + use_local_checkpoints: bool + checkpoint_interval: int + storage_level: StorageLevel + def __init__( + self, + use_local_checkpoints: _Optional[bool] = ..., + checkpoint_interval: _Optional[int] = ..., + storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ..., ) -> None: ... class DropIsolatedVertices(_message.Message): @@ -196,10 +269,31 @@ class Find(_message.Message): def __init__(self, pattern: _Optional[str] = ...) -> None: ... 
class LabelPropagation(_message.Message): - __slots__ = ("max_iter",) + __slots__ = ( + "algorithm", + "max_iter", + "use_local_checkpoints", + "checkpoint_interval", + "storage_level", + ) + ALGORITHM_FIELD_NUMBER: _ClassVar[int] MAX_ITER_FIELD_NUMBER: _ClassVar[int] + USE_LOCAL_CHECKPOINTS_FIELD_NUMBER: _ClassVar[int] + CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int] + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] + algorithm: str max_iter: int - def __init__(self, max_iter: _Optional[int] = ...) -> None: ... + use_local_checkpoints: bool + checkpoint_interval: int + storage_level: StorageLevel + def __init__( + self, + algorithm: _Optional[str] = ..., + max_iter: _Optional[int] = ..., + use_local_checkpoints: _Optional[bool] = ..., + checkpoint_interval: _Optional[int] = ..., + storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ..., + ) -> None: ... class PageRank(_message.Message): __slots__ = ("reset_probability", "source_id", "max_iter", "tol") @@ -260,6 +354,12 @@ class Pregel(_message.Message): "additional_col_initial", "additional_col_upd", "early_stopping", + "use_local_checkpoints", + "storage_level", + "stop_if_all_non_active", + "initial_active_expr", + "update_active_expr", + "skip_messages_from_non_active", ) AGG_MSGS_FIELD_NUMBER: _ClassVar[int] SEND_MSG_TO_DST_FIELD_NUMBER: _ClassVar[int] @@ -270,6 +370,12 @@ class Pregel(_message.Message): ADDITIONAL_COL_INITIAL_FIELD_NUMBER: _ClassVar[int] ADDITIONAL_COL_UPD_FIELD_NUMBER: _ClassVar[int] EARLY_STOPPING_FIELD_NUMBER: _ClassVar[int] + USE_LOCAL_CHECKPOINTS_FIELD_NUMBER: _ClassVar[int] + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] + STOP_IF_ALL_NON_ACTIVE_FIELD_NUMBER: _ClassVar[int] + INITIAL_ACTIVE_EXPR_FIELD_NUMBER: _ClassVar[int] + UPDATE_ACTIVE_EXPR_FIELD_NUMBER: _ClassVar[int] + SKIP_MESSAGES_FROM_NON_ACTIVE_FIELD_NUMBER: _ClassVar[int] agg_msgs: ColumnOrExpression send_msg_to_dst: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] send_msg_to_src: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] @@ -279,6 +385,12 @@ class Pregel(_message.Message): additional_col_initial: ColumnOrExpression additional_col_upd: ColumnOrExpression early_stopping: bool + use_local_checkpoints: bool + storage_level: StorageLevel + stop_if_all_non_active: bool + initial_active_expr: ColumnOrExpression + update_active_expr: ColumnOrExpression + skip_messages_from_non_active: bool def __init__( self, agg_msgs: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., @@ -289,15 +401,40 @@ class Pregel(_message.Message): additional_col_name: _Optional[str] = ..., additional_col_initial: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., additional_col_upd: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., - early_stopping: bool = ..., + early_stopping: _Optional[bool] = ..., + use_local_checkpoints: _Optional[bool] = ..., + storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ..., + stop_if_all_non_active: _Optional[bool] = ..., + initial_active_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + update_active_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + skip_messages_from_non_active: _Optional[bool] = ..., ) -> None: ... 
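The new Pregel fields above (use_local_checkpoints, storage_level, stop_if_all_non_active, the active-vertex expressions, and skip_messages_from_non_active) mirror the builder setters added to the Pregel API earlier in this patch. A minimal sketch of how those setters might be chained is shown below; the input DataFrames, the out_degree column, and the activity expressions are assumptions made only for illustration and are not part of the patch.

    from pyspark.sql import functions as F
    from pyspark.storagelevel import StorageLevel
    from graphframes import GraphFrame
    from graphframes.lib import Pregel

    # Assumed inputs: vertices with "id" and "out_degree" columns, edges with "src"/"dst".
    g = GraphFrame(vertices_df, edges_df)

    ranks = (
        g.pregel
        .setMaxIter(20)
        .setUseLocalCheckpoints(True)                         # executor-local checkpoints
        .setIntermediateStorageLevel(StorageLevel.DISK_ONLY)  # suggested for very large graphs
        .setStopIfAllNonActiveVertices(True)                  # stop once every vertex votes to halt
        .setInitialActiveVertexExpression(F.lit(True))
        .setUpdateActiveVertexExpression(F.col("rank") > F.lit(0.0))  # any expression over vertex columns
        .setSkipMessagesFromNonActiveVertices(True)
        .withVertexColumn("rank", F.lit(1.0), F.coalesce(Pregel.msg(), F.lit(0.0)))
        .sendMsgToDst(Pregel.src("rank") / Pregel.src("out_degree"))
        .aggMsgs(F.sum(Pregel.msg()))
        .run()
    )
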
class ShortestPaths(_message.Message): - __slots__ = ("landmarks",) + __slots__ = ( + "landmarks", + "algorithm", + "use_local_checkpoints", + "checkpoint_interval", + "storage_level", + ) LANDMARKS_FIELD_NUMBER: _ClassVar[int] + ALGORITHM_FIELD_NUMBER: _ClassVar[int] + USE_LOCAL_CHECKPOINTS_FIELD_NUMBER: _ClassVar[int] + CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int] + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] landmarks: _containers.RepeatedCompositeFieldContainer[StringOrLongID] + algorithm: str + use_local_checkpoints: bool + checkpoint_interval: int + storage_level: StorageLevel def __init__( - self, landmarks: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ... + self, + landmarks: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ..., + algorithm: _Optional[str] = ..., + use_local_checkpoints: _Optional[bool] = ..., + checkpoint_interval: _Optional[int] = ..., + storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ..., ) -> None: ... class StronglyConnectedComponents(_message.Message): @@ -346,8 +483,10 @@ class SVDPlusPlus(_message.Message): ) -> None: ... class TriangleCount(_message.Message): - __slots__ = () - def __init__(self) -> None: ... + __slots__ = ("storage_level",) + STORAGE_LEVEL_FIELD_NUMBER: _ClassVar[int] + storage_level: StorageLevel + def __init__(self, storage_level: _Optional[_Union[StorageLevel, _Mapping]] = ...) -> None: ... class Triplets(_message.Message): __slots__ = () diff --git a/python/graphframes/connect/utils.py b/python/graphframes/connect/utils.py index 77152137e..9def22624 100644 --- a/python/graphframes/connect/utils.py +++ b/python/graphframes/connect/utils.py @@ -5,8 +5,15 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.expressions import Expression from pyspark.sql.connect.plan import LogicalPlan +from pyspark.storagelevel import StorageLevel -from .proto.graphframes_pb2 import ColumnOrExpression, StringOrLongID +from .proto.graphframes_pb2 import ( + ColumnOrExpression, +) +from .proto.graphframes_pb2 import StorageLevel as StorageLevelProto +from .proto.graphframes_pb2 import ( + StringOrLongID, +) def dataframe_to_proto(df: DataFrame, client: SparkConnectClient) -> bytes: @@ -35,3 +42,24 @@ def make_str_or_long_id(str_or_long: str | int) -> StringOrLongID: return StringOrLongID(string_id=str_or_long) else: return StringOrLongID(long_id=str_or_long) + + +def storage_level_to_proto(storage_level: StorageLevel) -> StorageLevelProto: + if storage_level == StorageLevel.DISK_ONLY: + return StorageLevelProto(disk_only=True) + elif storage_level == StorageLevel.DISK_ONLY_2: + return StorageLevelProto(disk_only_2=True) + elif storage_level == StorageLevel.DISK_ONLY_3: + return StorageLevelProto(disk_only_3=True) + elif storage_level == StorageLevel.MEMORY_AND_DISK: + return StorageLevelProto(memory_and_disk=True) + elif storage_level == StorageLevel.MEMORY_AND_DISK_2: + return StorageLevelProto(memory_and_disk_2=True) + elif storage_level == StorageLevel.MEMORY_ONLY: + return StorageLevelProto(memory_only=True) + elif storage_level == StorageLevel.MEMORY_ONLY_2: + return StorageLevelProto(memory_only_2=True) + elif storage_level == StorageLevel.MEMORY_AND_DISK_DESER: + return StorageLevelProto(memory_and_disk_deser=True) + else: + raise ValueError(f"Unknown storage level: {storage_level}") diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 53c58ecb1..b32014ae9 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -17,10 +17,12 
@@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +import warnings +from typing import TYPE_CHECKING, Any from pyspark.storagelevel import StorageLevel from pyspark.version import __version__ +from typing_extensions import override if __version__[:3] >= "3.4": from pyspark.sql.utils import is_remote @@ -89,6 +91,12 @@ def edges(self) -> DataFrame: """ return self._impl.edges + @property + def nodes(self) -> DataFrame: + """Alias to vertices.""" + return self.vertices + + @override def __repr__(self) -> str: return self._impl.__repr__() @@ -233,9 +241,10 @@ def bfs( def aggregateMessages( self, - aggCol: Column | str, - sendToSrc: Column | str | None = None, - sendToDst: Column | str | None = None, + aggCol: list[Column | str] | Column, + sendToSrc: list[Column | str] | Column | str = list(), + sendToDst: list[Column | str] | Column | str = list(), + intermediate_storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK, ) -> DataFrame: """ Aggregates messages from the neighbours. @@ -245,16 +254,59 @@ def aggregateMessages( See Scala documentation for more details. - :param aggCol: the requested aggregation output either as + Warning! The result of this method is persisted DataFrame object! Users should handle unpersist + to avoid possible memory leaks! + + :param aggCol: the requested aggregation output either as a collection of :class:`pyspark.sql.Column` or SQL expression string :param sendToSrc: message sent to the source vertex of each triplet either as - :class:`pyspark.sql.Column` or SQL expression string (default: None) + a collection of :class:`pyspark.sql.Column` or SQL expression string (default: None) :param sendToDst: message sent to the destination vertex of each triplet either as - :class:`pyspark.sql.Column` or SQL expression string (default: None) + collection of :class:`pyspark.sql.Column` or SQL expression string (default: None) + :param intermediate_storage_level: the level of intermediate storage that will be used + for both intermediate result and the output. - :return: DataFrame with columns for the vertex ID and the resulting aggregated message - """ - return self._impl.aggregateMessages(aggCol=aggCol, sendToSrc=sendToSrc, sendToDst=sendToDst) + :return: DataFrame with columns for the vertex ID and the resulting aggregated message. + The name of the resulted message column is based on the alias of the provided aggCol! 
+ """ # noqa: E501 + + # Back-compatibility workaround + if not isinstance(aggCol, list): + warnings.warn( + "Passing single column to aggCol is deprecated, use list", + DeprecationWarning, + ) + return self.aggregateMessages( + [aggCol], sendToSrc, sendToDst, intermediate_storage_level + ) + if not isinstance(sendToSrc, list): + warnings.warn( + "Passing single column to sendToSrc is deprecated, use list", + DeprecationWarning, + ) + return self.aggregateMessages( + aggCol, [sendToSrc], sendToDst, intermediate_storage_level + ) + if not isinstance(sendToDst, list): + warnings.warn( + "Passing single column to sendToDst is deprecated, use list", + DeprecationWarning, + ) + return self.aggregateMessages( + aggCol, sendToSrc, [sendToDst], intermediate_storage_level + ) + + if len(aggCol) == 0: + raise TypeError("At least one aggregation column should be provided!") + + if (len(sendToSrc) == 0) and (len(sendToDst) == 0): + raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") + return self._impl.aggregateMessages( + aggCol=aggCol, + sendToSrc=sendToSrc, + sendToDst=sendToDst, + intermediate_storage_level=intermediate_storage_level, + ) # Standard algorithms @@ -301,9 +353,9 @@ def labelPropagation(self, maxIter: int) -> DataFrame: def pageRank( self, resetProbability: float = 0.15, - sourceId: Optional[Any] = None, - maxIter: Optional[int] = None, - tol: Optional[float] = None, + sourceId: Any | None = None, + maxIter: int | None = None, + tol: float | None = None, ) -> "GraphFrame": """ Runs the PageRank algorithm on the graph. @@ -331,8 +383,8 @@ def pageRank( def parallelPersonalizedPageRank( self, resetProbability: float = 0.15, - sourceIds: Optional[list[Any]] = None, - maxIter: Optional[int] = None, + sourceIds: list[Any] | None = None, + maxIter: int | None = None, ) -> "GraphFrame": """ Run the personalized PageRank algorithm on the graph, @@ -413,7 +465,7 @@ def triangleCount(self) -> DataFrame: return self._impl.triangleCount() def powerIterationClustering( - self, k: int, maxIter: int, weightCol: Optional[str] = None + self, k: int, maxIter: int, weightCol: str | None = None ) -> DataFrame: """ Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen. 
diff --git a/python/graphframes/lib/__init__.py b/python/graphframes/lib/__init__.py deleted file mode 100644 index 076dd5232..000000000 --- a/python/graphframes/lib/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .aggregate_messages import AggregateMessages -from .pregel import Pregel - -__all__ = ["AggregateMessages", "Pregel"] From 3aed3979082e102f708c074f51ceb33b8ab923d1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 1 Oct 2025 07:58:19 +0200 Subject: [PATCH 06/17] WIP --- .../graphframes/classic/aggregate_messages.py | 74 --- python/graphframes/classic/graphframe.py | 188 +++++--- python/graphframes/classic/utils.py | 4 +- ...hframe_client.py => graphframes_client.py} | 443 +++++++++++------- python/graphframes/graphframe.py | 188 ++++++-- python/graphframes/lib/__init__.py | 4 + python/graphframes/lib/aggregate_messages.py | 46 ++ python/graphframes/{classic => lib}/pregel.py | 173 ++++--- 8 files changed, 703 insertions(+), 417 deletions(-) delete mode 100644 python/graphframes/classic/aggregate_messages.py rename python/graphframes/connect/{graphframe_client.py => graphframes_client.py} (75%) create mode 100644 python/graphframes/lib/__init__.py create mode 100644 python/graphframes/lib/aggregate_messages.py rename python/graphframes/{classic => lib}/pregel.py (58%) diff --git a/python/graphframes/classic/aggregate_messages.py b/python/graphframes/classic/aggregate_messages.py deleted file mode 100644 index a14618288..000000000 --- a/python/graphframes/classic/aggregate_messages.py +++ /dev/null @@ -1,74 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Any - -from pyspark import SparkContext -from pyspark.sql import Column -from pyspark.sql import functions as sqlfunctions - - -def _java_api(jsc: SparkContext) -> Any: - javaClassName = "org.graphframes.GraphFramePythonAPI" - return ( - jsc._jvm.Thread.currentThread() - .getContextClassLoader() - .loadClass(javaClassName) - .newInstance() - ) - - -class _ClassProperty: - """Custom read-only class property descriptor. - - The underlying method should take the class as the sole argument. 
- """ - - def __init__(self, f: callable) -> None: - self.f = f - self.__doc__ = f.__doc__ - - def __get__(self, instance: Any, owner: type) -> Any: - return self.f(owner) - - -class AggregateMessages: - """Collection of utilities usable with :meth:`graphframes.GraphFrame.aggregateMessages()`.""" - - @_ClassProperty - def src(cls) -> Column: - """Reference for source column, used for specifying messages.""" - jvm_gf_api = _java_api(SparkContext) - return sqlfunctions.col(jvm_gf_api.SRC()) - - @_ClassProperty - def dst(cls) -> Column: - """Reference for destination column, used for specifying messages.""" - jvm_gf_api = _java_api(SparkContext) - return sqlfunctions.col(jvm_gf_api.DST()) - - @_ClassProperty - def edge(cls) -> Column: - """Reference for edge column, used for specifying messages.""" - jvm_gf_api = _java_api(SparkContext) - return sqlfunctions.col(jvm_gf_api.EDGE()) - - @_ClassProperty - def msg(cls) -> Column: - """Reference for message column, used for specifying aggregation function.""" - jvm_gf_api = _java_api(SparkContext) - return sqlfunctions.col(jvm_gf_api.aggregateMessages().MSG_COL_NAME()) diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index 11901d010..73cf2ab16 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -16,21 +16,21 @@ # from __future__ import annotations -import sys -from typing import Any, Optional, Union +from typing import final -if sys.version > "3": - basestring = str +from py4j.java_gateway import JavaObject from pyspark.sql.classic.column import Column, _to_seq -from pyspark.sql.classic.dataframe import DataFrame, SparkContext, SparkSession +from pyspark.sql.classic.dataframe import DataFrame +from pyspark.sql import SparkSession +from pyspark.core.context import SparkContext from pyspark.storagelevel import StorageLevel -from graphframes.classic.pregel import Pregel +from graphframes.lib import Pregel from graphframes.classic.utils import storage_level_to_jvm -def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": +def _from_java_gf(jgf: JavaObject, spark: SparkSession) -> "GraphFrame": """ (internal) creates a python GraphFrame wrapper from a java GraphFrame. 
@@ -41,7 +41,7 @@ def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": return GraphFrame(pv, pe) -def _java_api(jsc: SparkContext) -> Any: +def _java_api(jsc: SparkContext) -> JavaObject: javaClassName = "org.graphframes.GraphFramePythonAPI" if jsc._jvm is None: raise RuntimeError( @@ -55,6 +55,7 @@ def _java_api(jsc: SparkContext) -> Any: ) +@final class GraphFrame: def __init__(self, v: DataFrame, e: DataFrame) -> None: self._vertices = v @@ -63,10 +64,10 @@ def __init__(self, v: DataFrame, e: DataFrame) -> None: self._sc = self._spark._sc self._jvm_gf_api = _java_api(self._sc) - self.ID = self._jvm_gf_api.ID() - self.SRC = self._jvm_gf_api.SRC() - self.DST = self._jvm_gf_api.DST() - self._ATTR = self._jvm_gf_api.ATTR() + self.ID: str = "id" + self.SRC: str = "src" + self.DST: str = "edge" + self._ATTR: str = self._jvm_gf_api.ATTR() # Check that provided DataFrames contain required columns if self.ID not in v.columns: @@ -98,14 +99,16 @@ def vertices(self) -> DataFrame: def edges(self) -> DataFrame: return self._edges - def __repr__(self): + def __repr__(self) -> str: return self._jvm_graph.toString() def cache(self) -> "GraphFrame": self._jvm_graph.cache() return self - def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": + def persist( + self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + ) -> "GraphFrame": javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jvm_graph.persist(javaStorageLevel) return self @@ -142,22 +145,38 @@ def find(self, pattern: str) -> DataFrame: jdf = self._jvm_graph.find(pattern) return DataFrame(jdf, self._spark) - def filterVertices(self, condition: Union[str, Column]) -> "GraphFrame": - if isinstance(condition, basestring): + def filterVertices(self, condition: str | Column) -> "GraphFrame": + if isinstance(condition, str): jdf = self._jvm_graph.filterVertices(condition) - elif isinstance(condition, Column): - jdf = self._jvm_graph.filterVertices(condition._jc) else: - raise TypeError("condition should be string or Column") + jdf = self._jvm_graph.filterVertices(condition._jc) + return _from_java_gf(jdf, self._spark) - def filterEdges(self, condition: Union[str, Column]) -> "GraphFrame": - if isinstance(condition, basestring): + def filterEdges(self, condition: str | Column) -> "GraphFrame": + if isinstance(condition, str): jdf = self._jvm_graph.filterEdges(condition) - elif isinstance(condition, Column): - jdf = self._jvm_graph.filterEdges(condition._jc) else: - raise TypeError("condition should be string or Column") + jdf = self._jvm_graph.filterEdges(condition._jc) + + return _from_java_gf(jdf, self._spark) + + def detectingCycles( + self, + checkpoint_interval: int = 2, + use_local_checkpoints: bool = False, + intermediate_storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER, + ) -> DataFrame: + jdf = ( + self._jvm_graph.detectingCycles() + .setUseLocalCheckpoints(use_local_checkpoints) + .setCheckpointInterval(checkpoint_interval) + .setIntermediateStorageLevel( + storage_level_to_jvm(intermediate_storage_level, self._spark) + ) + .run() + ) + return _from_java_gf(jdf, self._spark) def dropIsolatedVertices(self) -> "GraphFrame": @@ -168,11 +187,14 @@ def bfs( self, fromExpr: str, toExpr: str, - edgeFilter: Optional[str] = None, + edgeFilter: str | None = None, maxPathLength: int = 10, ) -> DataFrame: builder = ( - self._jvm_graph.bfs().fromExpr(fromExpr).toExpr(toExpr).maxPathLength(maxPathLength) + self._jvm_graph.bfs() + .fromExpr(fromExpr) + .toExpr(toExpr) + 
.maxPathLength(maxPathLength) ) if edgeFilter is not None: builder.edgeFilter(edgeFilter) @@ -193,7 +215,7 @@ def aggregateMessages( if len(sendToSrc) == 1: if isinstance(sendToSrc[0], Column): builder.sendToSrc(sendToSrc[0]._jc) - elif isinstance(sendToSrc[0], basestring): + elif isinstance(sendToSrc[0], str): builder.sendToSrc(sendToSrc[0]) else: raise TypeError("Provide message either as `Column` or `str`") @@ -201,7 +223,7 @@ def aggregateMessages( if all(isinstance(x, Column) for x in sendToSrc): send2src = [x._jc for x in sendToSrc] builder.sendToSrc(send2src[0], _to_seq(self._sc, send2src[1:])) - elif all(isinstance(x, basestring) for x in sendToSrc): + elif all(isinstance(x, str) for x in sendToSrc): builder.sendToSrc(sendToSrc[0], _to_seq(self._sc, sendToSrc[1:])) else: raise TypeError( @@ -211,7 +233,7 @@ def aggregateMessages( if len(sendToDst) == 1: if isinstance(sendToDst[0], Column): builder.sendToDst(sendToDst[0]._jc) - elif isinstance(sendToDst[0], basestring): + elif isinstance(sendToDst[0], str): builder.sendToDst(sendToDst[0]) else: raise TypeError("Provide message either as `Column` or `str`") @@ -219,7 +241,7 @@ def aggregateMessages( if all(isinstance(x, Column) for x in sendToDst): send2dst = [x._jc for x in sendToDst] builder.sendToDst(send2dst[0], _to_seq(self._sc, send2dst[1:])) - elif all(isinstance(x, basestring) for x in sendToDst): + elif all(isinstance(x, str) for x in sendToDst): builder.sendToDst(sendToDst[0], _to_seq(self._sc, sendToDst[1:])) else: raise TypeError( @@ -229,13 +251,15 @@ def aggregateMessages( if len(aggCol) == 1: if isinstance(aggCol[0], Column): jdf = builder.aggCol(aggCol[0]._jc) - elif isinstance(aggCol[0], basestring): + elif isinstance(aggCol[0], str): jdf = builder.aggCol(aggCol[0]) elif len(aggCol) > 1: if all(isinstance(x, Column) for x in aggCol): - jdf = builder.aggCol(aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol])) - elif all(isinstance(x, basestring) for x in aggCol): - jdf = builder.aggCol(aggCol[0], _to_seq(self._sc, aggCol)) + jdf = builder.aggCol( + aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol]) + ) + elif all(isinstance(x, str) for x in aggCol): + jdf = builder.aggCol(aggCol[0], _to_seq(self._sc, aggCol[1:])) else: raise TypeError( "Multiple agg cols should all be `Column` or `str`, not a mix of them." 
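As an aside, a minimal sketch of calling the new cycle-detection entry point wired up in this file (detectingCycles). The toy graph and the SparkSession `spark` are illustrative assumptions, not part of the patch:

    from graphframes import GraphFrame

    vertices = spark.createDataFrame([(1,), (2,), (3,), (4,)], ["id"])
    edges = spark.createDataFrame([(1, 2), (2, 3), (3, 1), (3, 4)], ["src", "dst"])
    g = GraphFrame(vertices, edges)

    # Rocha-Thatte cycle detection: every vertex on a cycle reports the whole cycle,
    # so 1 -> 2 -> 3 -> 1 is returned once per member vertex and the caller is
    # expected to deduplicate. Local checkpoints avoid configuring a checkpoint dir.
    cycles = g.detectingCycles(checkpoint_interval=2, use_local_checkpoints=True)
    cycles.show(truncate=False)
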
@@ -244,10 +268,13 @@ def aggregateMessages( def connectedComponents( self, - algorithm: str = "graphframes", - checkpointInterval: int = 2, - broadcastThreshold: int = 1000000, - useLabelsAsComponents: bool = False, + algorithm: str, + checkpointInterval: int, + broadcastThreshold: int, + useLabelsAsComponents: bool, + use_local_checkpoints: bool, + max_iter: int, + storage_level: StorageLevel, ) -> DataFrame: jdf = ( self._jvm_graph.connectedComponents() @@ -255,20 +282,42 @@ def connectedComponents( .setCheckpointInterval(checkpointInterval) .setBroadcastThreshold(broadcastThreshold) .setUseLabelsAsComponents(useLabelsAsComponents) + .setUseLocalCheckpoints(use_local_checkpoints) + .maxIter(max_iter) + .setIntermediateStorageLevel( + storage_level_to_jvm(storage_level, self._spark) + ) .run() ) return DataFrame(jdf, self._spark) - def labelPropagation(self, maxIter: int) -> DataFrame: - jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() + def labelPropagation( + self, + maxIter: int, + algorithm: str, + use_local_checkpoints: bool, + checkpoint_interval: int, + storage_level: StorageLevel, + ) -> DataFrame: + jdf = ( + self._jvm_graph.labelPropagation() + .maxIter(maxIter) + .setAlgorithm(algorithm) + .setUseLocalCheckpoints(use_local_checkpoints) + .setCheckpointInterval(checkpoint_interval) + .setIntermediateStorageLevel( + storage_level_to_jvm(storage_level, self._spark) + ) + .run() + ) return DataFrame(jdf, self._spark) def pageRank( self, resetProbability: float = 0.15, - sourceId: Optional[Any] = None, - maxIter: Optional[int] = None, - tol: Optional[float] = None, + sourceId: str | int | None = None, + maxIter: int | None = None, + tol: float | None = None, ) -> "GraphFrame": builder = self._jvm_graph.pageRank().resetProbability(resetProbability) if sourceId is not None: @@ -285,12 +334,12 @@ def pageRank( def parallelPersonalizedPageRank( self, resetProbability: float = 0.15, - sourceIds: Optional[list[Any]] = None, - maxIter: Optional[int] = None, + sourceIds: list[str | int] | None = None, + maxIter: int | None = None, ) -> "GraphFrame": - assert ( - sourceIds is not None and len(sourceIds) > 0 - ), "Source vertices Ids sourceIds must be provided" + assert sourceIds is not None and len(sourceIds) > 0, ( + "Source vertices Ids sourceIds must be provided" + ) assert maxIter is not None, "Max number of iterations maxIter must be provided" sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) builder = self._jvm_graph.parallelPersonalizedPageRank() @@ -300,8 +349,25 @@ def parallelPersonalizedPageRank( jgf = builder.run() return _from_java_gf(jgf, self._spark) - def shortestPaths(self, landmarks: list[Any]) -> DataFrame: - jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() + def shortestPaths( + self, + landmarks: list[str | int], + algorithm: str, + use_local_checkpoints: bool, + checkpoint_interval: int, + storage_level: StorageLevel, + ) -> DataFrame: + jdf = ( + self._jvm_graph.shortestPaths() + .landmarks(landmarks) + .setAlgorithm(algorithm) + .setUseLocalCheckpoints(use_local_checkpoints) + .setCheckpointInterval(checkpoint_interval) + .setIntermediateStorageLevel( + storage_level_to_jvm(storage_level, self._spark) + ) + .run() + ) return DataFrame(jdf, self._spark) def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: @@ -333,7 +399,7 @@ def triangleCount(self) -> DataFrame: return DataFrame(jdf, self._spark) def powerIterationClustering( - self, k: int, maxIter: int, weightCol: Optional[str] = None + self, k: int, maxIter: int, 
weightCol: str | None = None ) -> DataFrame: if weightCol: weightCol = self._spark._jvm.scala.Option.apply(weightCol) @@ -341,23 +407,3 @@ def powerIterationClustering( weightCol = self._spark._jvm.scala.Option.empty() jdf = self._jvm_graph.powerIterationClustering(k, maxIter, weightCol) return DataFrame(jdf, self._spark) - - -def _test(): - import doctest - - import graphframe - - globs = graphframe.__dict__.copy() - globs["sc"] = SparkContext("local[4]", "PythonTest", batchSize=2) - globs["spark"] = SparkSession(globs["sc"]).builder.getOrCreate() - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE - ) - globs["sc"].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/graphframes/classic/utils.py b/python/graphframes/classic/utils.py index 863ca0c5f..b6b39ddea 100644 --- a/python/graphframes/classic/utils.py +++ b/python/graphframes/classic/utils.py @@ -6,7 +6,9 @@ from pyspark.sql.classic.dataframe import SparkSession -def storage_level_to_jvm(storage_level: StorageLevel, spark: SparkSession) -> JavaObject: +def storage_level_to_jvm( + storage_level: StorageLevel, spark: SparkSession +) -> JavaObject: return spark._jvm.org.apache.spark.storage.StorageLevel.apply( storage_level.useDisk, storage_level.useMemory, diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframes_client.py similarity index 75% rename from python/graphframes/connect/graphframe_client.py rename to python/graphframes/connect/graphframes_client.py index a9d7671d8..68ba8d00e 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframes_client.py @@ -1,4 +1,6 @@ from __future__ import annotations +from typing import final +from typing_extensions import override from pyspark.sql.connect import functions as F from pyspark.sql.connect import proto @@ -35,32 +37,8 @@ def _dataframe_from_plan(plan: LogicalPlan, session: SparkSession) -> DataFrame: return DataFrame(plan, session) +@final class PregelConnect: - """Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. - - See `Malewicz et al., Pregel: a system for large-scale graph processing `_ - for a detailed description of the Pregel algorithm. - - You can construct a Pregel instance using either this constructor or :attr:`graphframes.GraphFrame.pregel`, - then use builder pattern to describe the operations, and then call :func:`run` to start a run. - It returns a DataFrame of vertices from the last iteration. - - When a run starts, it expands the vertices DataFrame using column expressions defined by :func:`withVertexColumn`. - Those additional vertex properties can be changed during Pregel iterations. - In each Pregel iteration, there are three phases: - - Given each edge triplet, generate messages and specify target vertices to send, - described by :func:`sendMsgToDst` and :func:`sendMsgToSrc`. - - Aggregate messages by target vertex IDs, described by :func:`aggMsgs`. - - Update additional vertex properties based on aggregated messages and states from previous iteration, - described by :func:`withVertexColumn`. - - Please find what columns you can reference at each phase in the method API docs. - - You can control the number of iterations by :func:`setMaxIter` and check API docs for advanced controls. - - :param graph: a :class:`graphframes.GraphFrame` object holding a graph with vertices and edges stored as DataFrames. 
- """ # noqa: E501 - def __init__(self, graph: "GraphFrameConnect") -> None: self.graph = graph self._max_iter = 10 @@ -68,8 +46,8 @@ def __init__(self, graph: "GraphFrameConnect") -> None: self._col_name = None self._initial_expr = None self._update_after_agg_msgs_expr = None - self._send_msg_to_src = [] - self._send_msg_to_dst = [] + self._send_msg_to_src: list[Column | str] = [] + self._send_msg_to_dst: list[Column | str] = [] self._agg_msg = None self._early_stopping = False self._use_local_checkpoints = False @@ -80,34 +58,14 @@ def __init__(self, graph: "GraphFrameConnect") -> None: self._skip_messages_from_non_active = False def setMaxIter(self, value: int) -> Self: - """Sets the max number of iterations (default: 2).""" self._max_iter = value return self def setCheckpointInterval(self, value: int) -> Self: - """Sets the number of iterations between two checkpoints (default: 2). - - This is an advanced control to balance query plan optimization and checkpoint data I/O cost. - In most cases, you should keep the default value. - - Checkpoint is disabled if this is set to 0. - """ self._checkpoint_interval = value return self def setEarlyStopping(self, value: bool) -> Self: - """Set should Pregel stop earlier in case of no new messages to send or not. - - Early stopping allows to terminate Pregel before reaching maxIter by checking if there are any non-null messages. - While in some cases it may gain significant performance boost, in other cases it can lead to performance degradation, - because checking if the messages DataFrame is empty or not is an action and requires materialization of the Spark Plan - with some additional computations. - - In the case when the user can assume a good value of maxIter, it is recommended to leave this value to the default "false". - In the case when it is hard to estimate the number of iterations required for convergence, - it is recommended to set this value to "false" to avoid iterating over convergence until reaching maxIter. - When this value is "true", maxIter can be set to a bigger value without risks. - """ # noqa: E501 self._early_stopping = value return self @@ -117,139 +75,49 @@ def withVertexColumn( initialExpr: Column | str, updateAfterAggMsgsExpr: Column | str, ) -> Self: - """Defines an additional vertex column at the start of run and how to update it in each iteration. - - You can call it multiple times to add more than one additional vertex columns. - - :param colName: the name of the additional vertex column. - It cannot be an existing vertex column in the graph. - :param initialExpr: the expression to initialize the additional vertex column. - You can reference all original vertex columns in this expression. - :param updateAfterAggMsgsExpr: the expression to update the additional vertex column after messages aggregation. - You can reference all original vertex columns, additional vertex columns, and the - aggregated message column using :func:`msg`. - If the vertex received no messages, the message column would be null. - """ # noqa: E501 self._col_name = colName self._initial_expr = initialExpr self._update_after_agg_msgs_expr = updateAfterAggMsgsExpr return self def sendMsgToSrc(self, msgExpr: Column | str) -> Self: - """Defines a message to send to the source vertex of each edge triplet. - - You can call it multiple times to send more than one messages. - - See method :func:`sendMsgToDst`. - - :param msgExpr: the expression of the message to send to the source vertex given a (src, edge, dst) triplet. 
- Source/destination vertex properties and edge properties are nested under columns `src`, `dst`, - and `edge`, respectively. - You can reference them using :func:`src`, :func:`dst`, and :func:`edge`. - Null messages are not included in message aggregation. - """ # noqa: E501 self._send_msg_to_src.append(msgExpr) return self def sendMsgToDst(self, msgExpr: Column | str) -> Self: - """Defines a message to send to the destination vertex of each edge triplet. - - You can call it multiple times to send more than one messages. - - See method :func:`sendMsgToSrc`. - - :param msgExpr: the message expression to send to the destination vertex given a (`src`, `edge`, `dst`) triplet. - Source/destination vertex properties and edge properties are nested under columns `src`, `dst`, - and `edge`, respectively. - You can reference them using :func:`src`, :func:`dst`, and :func:`edge`. - Null messages are not included in message aggregation. - """ # noqa: E501 self._send_msg_to_dst.append(msgExpr) return self def aggMsgs(self, aggExpr: Column) -> Self: - """Defines how messages are aggregated after grouped by target vertex IDs. - - :param aggExpr: the message aggregation expression, such as `sum(Pregel.msg())`. - You can reference the message column by :func:`msg` and the vertex ID by `col("id")`, - while the latter is usually not used. - """ # noqa: E501 self._agg_msg = aggExpr return self def setStopIfAllNonActiveVertices(self, value: bool) -> Self: - """Set should Pregel stop if all the vertices voted to halt. - - Activity (or vote) is determined based on the activity_col. - See methods :func:`setInitialActiveVertexExpression` and :func:`setUpdateActiveVertexExpression` for details - how to set and update activity_col. - - Be aware that checking of the vote is not free but a Spark Action. In case the - condition is not realistically reachable but set, it will just slow down the algorithm. - - :param value: the boolean value. - """ # noqa: E501 self._stop_if_all_non_active = value return self def setInitialActiveVertexExpression(self, value: Column | str) -> Self: - """Sets the initial expression for the active vertex column. - - The active vertex column is used to determine if a vertices voting result on each iteration of Pregel. - This expression is evaluated on the initial vertices DataFrame to set the initial state of the activity column. - - :param value: expression to compute the initial active state of vertices. - You can reference all original vertex columns in this expression. - """ # noqa: E501 self._initial_active_expr = value return self def setUpdateActiveVertexExpression(self, value: Column | str) -> Self: - """Sets the expression to update the active vertex column. - - The active vertex column is used to determine if a vertices voting result on each iteration of Pregel. - This expression is evaluated on the updated vertices DataFrame to set the new state of the activity column. - - :param value: expression to compute the new active state of vertices. - You can reference all original vertex columns and additional vertex columns in this expression. - """ # noqa: E501 self._update_active_expr = value return self def setSkipMessagesFromNonActiveVertices(self, value: bool) -> Self: - """Set should Pregel skip sending messages from non-active vertices. - - When this option is enabled, messages will not be sent from vertices that are marked as inactive. - This can help optimize performance by avoiding unnecessary message propagation from inactive vertices. - - :param value: boolean value. 
- """ # noqa: E501 self._skip_messages_from_non_active = value return self def setUseLocalCheckpoints(self, value: bool) -> Self: - """Set should Pregel use local checkpoints. - - Local checkpoints are faster and do not require configuring a persistent storage. - At the same time, local checkpoints are less reliable and may create a big load on local disks of executors. - - :param value: boolean value. - """ # noqa: E501 self._use_local_checkpoints = value return self def setIntermediateStorageLevel(self, storage_level: StorageLevel) -> Self: - """Set the intermediate storage level. - On each iteration, Pregel cache results with a requested storage level. - - For very big graphs it is recommended to use DISK_ONLY. - - :param storage_level: storage level to use. - """ # noqa: E501 self._storage_level = storage_level return self def run(self) -> DataFrame: + @final class Pregel(LogicalPlan): def __init__( self, @@ -290,6 +158,7 @@ def __init__( self.vertices = vertices self.edges = edges + @override def plan(self, session: SparkConnectClient) -> proto.Relation: pregel = pb.Pregel( agg_msgs=make_column_or_expr(self.agg_msg, session), @@ -302,17 +171,25 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: checkpoint_interval=self.checkpoint_interval, max_iter=self.max_iter, additional_col_name=self.vertex_col_name, - additional_col_initial=make_column_or_expr(self.vertex_col_init, session), - additional_col_upd=make_column_or_expr(self.vertex_col_upd, session), + additional_col_initial=make_column_or_expr( + self.vertex_col_init, session + ), + additional_col_upd=make_column_or_expr( + self.vertex_col_upd, session + ), early_stopping=self.early_stopping, use_local_checkpoints=self.use_local_checkpoints, storage_level=storage_level_to_proto(self.storage_level), stop_if_all_non_active=self.stop_if_all_non_active, skip_messages_from_non_active=self.skip_message_from_non_active, - initial_active_expr=make_column_or_expr(self.initial_active_expr, session) + initial_active_expr=make_column_or_expr( + self.initial_active_expr, session + ) if self.initial_active_expr is not None else None, - update_active_expr=make_column_or_expr(self.update_active_expr, session) + update_active_expr=make_column_or_expr( + self.update_active_expr, session + ) if self.update_active_expr is not None else None, ) @@ -375,11 +252,12 @@ def edge(colName: str) -> Column: return F.col("edge." 
+ colName) +@final class GraphFrameConnect: - ID = "id" - SRC = "src" - DST = "dst" - EDGE = "edge" + ID: str = "id" + SRC: str = "src" + DST: str = "dst" + EDGE: str = "edge" def __init__(self, v: DataFrame, e: DataFrame) -> None: self._vertices = v @@ -422,6 +300,7 @@ def vertices(self) -> DataFrame: def edges(self) -> DataFrame: return self._edges + @override def __repr__(self) -> str: # Exactly like in the scala core v_cols = [self.ID] + [col for col in self.vertices.columns if col != self.ID] @@ -438,7 +317,9 @@ def cache(self) -> "GraphFrameConnect": new_edges = self._edges.cache() return GraphFrameConnect(new_vertices, new_edges) - def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrameConnect": + def persist( + self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + ) -> "GraphFrameConnect": new_vertices = self._vertices.persist(storageLevel=storageLevel) new_edges = self._edges.persist(storageLevel=storageLevel) return GraphFrameConnect(new_vertices, new_edges) @@ -463,19 +344,23 @@ def inDegrees(self) -> DataFrame: @property def degrees(self) -> DataFrame: return ( - self._edges.select(F.explode(F.array(F.col(self.SRC), F.col(self.DST))).alias(self.ID)) + self._edges.select( + F.explode(F.array(F.col(self.SRC), F.col(self.DST))).alias(self.ID) + ) .groupBy(self.ID) .agg(F.count("*").alias("degree")) ) @property def triplets(self) -> DataFrame: + @final class Triplets(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: super().__init__(None) self.v = v self.e = e + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -492,6 +377,7 @@ def pregel(self): return PregelConnect(self) def find(self, pattern: str) -> DataFrame: + @final class Find(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame, pattern: str) -> None: super().__init__(None) @@ -499,6 +385,7 @@ def __init__(self, v: DataFrame, e: DataFrame, pattern: str) -> None: self.e = e self.p = pattern + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -508,16 +395,22 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - return _dataframe_from_plan(Find(self._vertices, self._edges, pattern), self._spark) + return _dataframe_from_plan( + Find(self._vertices, self._edges, pattern), self._spark + ) def filterVertices(self, condition: str | Column) -> "GraphFrameConnect": + @final class FilterVertices(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame, condition: str | Column) -> None: + def __init__( + self, v: DataFrame, e: DataFrame, condition: str | Column + ) -> None: super().__init__(None) self.v = v self.e = e self.c = condition + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -546,19 +439,25 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: return GraphFrameConnect(new_vertices, new_edges) def filterEdges(self, condition: str | Column) -> "GraphFrameConnect": + @final class FilterEdges(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame, condition: str | Column) -> None: + def __init__( + self, v: DataFrame, e: DataFrame, condition: str | Column + ) -> None: super().__init__(None) self.v = v self.e = e self.c = condition + 
@override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) col_or_expr = make_column_or_expr(self.c, session) - graphframes_api_call.filter_edges.CopyFrom(pb.FilterEdges(condition=col_or_expr)) + graphframes_api_call.filter_edges.CopyFrom( + pb.FilterEdges(condition=col_or_expr) + ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -568,18 +467,72 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: ) return GraphFrameConnect(self._vertices, new_edges) + def detectingCycles( + self, + checkpoint_interval: int, + use_local_checkpoints: bool, + intermediate_storage_level: StorageLevel, + ) -> DataFrame: + @final + class DetectingCycles(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + checkpoint_interval: int, + use_local_checkpoints: bool, + storage_level: StorageLevel, + ) -> None: + super().__init__(None) + self.v = v + self.e = e + self.checkpoint_interval = checkpoint_interval + self.use_local_checkpoints = use_local_checkpoints + self.storage_level = storage_level + + @override + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.detecting_cycles.CopyFrom( + pb.DetectingCycles( + use_local_checkpoints=self.use_local_checkpoints, + checkpoint_interval=self.checkpoint_interval, + storage_level=storage_level_to_proto(self.storage_level), + ) + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + return _dataframe_from_plan( + DetectingCycles( + self._vertices, + self._edges, + checkpoint_interval, + use_local_checkpoints, + intermediate_storage_level, + ), + self._spark, + ) + def dropIsolatedVertices(self) -> "GraphFrameConnect": + @final class DropIsolatedVertices(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: super().__init__(None) self.v = v self.e = e + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.drop_isolated_vertices.CopyFrom(pb.DropIsolatedVertices()) + graphframes_api_call.drop_isolated_vertices.CopyFrom( + pb.DropIsolatedVertices() + ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -596,6 +549,7 @@ def bfs( edgeFilter: Column | str | None = None, maxPathLength: int = 10, ) -> DataFrame: + @final class BFS(LogicalPlan): def __init__( self, @@ -614,6 +568,7 @@ def __init__( self.edge_filter = edge_filter self.max_path_len = max_path_len + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -652,6 +607,7 @@ def aggregateMessages( sendToDst: list[Column | str], intermediate_storage_level: StorageLevel, ) -> DataFrame: + @final class AggregateMessages(LogicalPlan): def __init__( self, @@ -670,6 +626,7 @@ def __init__( self.send2dst = send2dst self.storage_level = storage_level + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -677,8 +634,12 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.aggregate_messages.CopyFrom( pb.AggregateMessages( 
agg_col=[make_column_or_expr(x, session) for x in self.agg_col], - send_to_src=[make_column_or_expr(x, session) for x in self.send2src], - send_to_dst=[make_column_or_expr(x, session) for x in self.send2dst], + send_to_src=[ + make_column_or_expr(x, session) for x in self.send2src + ], + send_to_dst=[ + make_column_or_expr(x, session) for x in self.send2dst + ], storage_level=storage_level_to_proto(self.storage_level), ) ) @@ -686,8 +647,10 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - if sendToSrc is None and sendToDst is None: - raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") + if (len(sendToSrc) == 0) and (len(sendToDst) == 0): + raise ValueError( + "Either `sendToSrc`, `sendToDst`, or both have to be provided" + ) return _dataframe_from_plan( AggregateMessages( @@ -703,11 +666,15 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def connectedComponents( self, - algorithm: str = "graphframes", - checkpointInterval: int = 2, - broadcastThreshold: int = 1000000, - useLabelsAsComponents: bool = False, + algorithm: str, + checkpointInterval: int, + broadcastThreshold: int, + useLabelsAsComponents: bool, + use_local_checkpoints: bool, + max_iter: int, + storage_level: StorageLevel, ) -> DataFrame: + @final class ConnectedComponents(LogicalPlan): def __init__( self, @@ -717,6 +684,9 @@ def __init__( checkpoint_interval: int, broadcast_threshold: int, use_labels_as_components: bool, + use_local_checkpoints: bool, + max_iter: int, + storage_level: StorageLevel, ) -> None: super().__init__(None) self.v = v @@ -725,7 +695,11 @@ def __init__( self.checkpoint_interval = checkpoint_interval self.broadcast_threshold = broadcast_threshold self.use_labels_as_components = use_labels_as_components + self.use_local_checkpoints = use_local_checkpoints + self.max_iter = max_iter + self.storage_level = storage_level + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -736,6 +710,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: checkpoint_interval=self.checkpoint_interval, broadcast_threshold=self.broadcast_threshold, use_labels_as_components=self.use_labels_as_components, + use_local_checkpoints=self.use_local_checkpoints, + max_iter=self.max_iter, + storage_level=storage_level_to_proto(self.storage_level), ) ) plan = self._create_proto_relation() @@ -750,34 +727,76 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: checkpointInterval, broadcastThreshold, useLabelsAsComponents, + use_local_checkpoints, + max_iter, + storage_level, ), self._spark, ) - def labelPropagation(self, maxIter: int) -> DataFrame: + def labelPropagation( + self, + maxIter: int, + algorithm: str, + use_local_checkpoints: bool, + checkpoint_interval: int, + storage_level: StorageLevel, + ) -> DataFrame: + @final class LabelPropagation(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: + def __init__( + self, + v: DataFrame, + e: DataFrame, + max_iter: int, + algorithm: str, + use_local_checkpoints: bool, + checkpoint_interval: int, + storage_level: StorageLevel, + ) -> None: super().__init__(None) self.v = v self.e = e self.max_iter = max_iter + self.algorithm = algorithm + self.use_local_checkpoints = use_local_checkpoints + self.checkpoint_interval = checkpoint_interval + self.storage_level = storage_level + @override def plan(self, 
session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) graphframes_api_call.label_propagation.CopyFrom( - pb.LabelPropagation(max_iter=self.max_iter) + pb.LabelPropagation( + algorithm=self.algorithm, + max_iter=self.max_iter, + use_local_checkpoints=self.use_local_checkpoints, + checkpoint_interval=self.checkpoint_interval, + storage_level=storage_level_to_proto(self.storage_level), + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan return _dataframe_from_plan( - LabelPropagation(self._vertices, self._edges, maxIter), self._spark + LabelPropagation( + self._vertices, + self._edges, + maxIter, + algorithm, + use_local_checkpoints, + checkpoint_interval, + storage_level, + ), + self._spark, ) - def _update_page_rank_edge_weights(self, new_vertices: DataFrame) -> "GraphFrameConnect": + def _update_page_rank_edge_weights( + self, new_vertices: DataFrame + ) -> "GraphFrameConnect": cols2select = self.edges.columns + ["weight"] new_edges = ( self._edges.join( @@ -802,6 +821,7 @@ def pageRank( maxIter: int | None = None, tol: float | None = None, ) -> "GraphFrameConnect": + @final class PageRank(LogicalPlan): def __init__( self, @@ -820,6 +840,7 @@ def __init__( self.max_iter = max_iter self.tol = tol + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -828,7 +849,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: pb.PageRank( reset_probability=self.reset_prob, source_id=( - None if self.source_id is None else make_str_or_long_id(self.source_id) + None + if self.source_id is None + else make_str_or_long_id(self.source_id) ), max_iter=self.max_iter, tol=self.tol, @@ -863,6 +886,7 @@ def parallelPersonalizedPageRank( sourceIds: list[str | int] | None = None, maxIter: int | None = None, ) -> "GraphFrameConnect": + @final class ParallelPersonalizedPageRank(LogicalPlan): def __init__( self, @@ -879,6 +903,7 @@ def __init__( self.source_ids = source_ids self.max_iter = max_iter + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -886,7 +911,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.parallel_personalized_page_rank.CopyFrom( pb.ParallelPersonalizedPageRank( reset_probability=self.reset_prob, - source_ids=[make_str_or_long_id(raw_id) for raw_id in self.source_ids], + source_ids=[ + make_str_or_long_id(raw_id) for raw_id in self.source_ids + ], max_iter=self.max_iter, ) ) @@ -894,9 +921,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - assert ( - sourceIds is not None and len(sourceIds) > 0 - ), "Source vertices Ids sourceIds must be provided" + assert sourceIds is not None and len(sourceIds) > 0, ( + "Source vertices Ids sourceIds must be provided" + ) assert maxIter is not None, "Max number of iterations maxIter must be provided" new_vertices = _dataframe_from_plan( @@ -914,6 +941,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def powerIterationClustering( self, k: int, maxIter: int, weightCol: str | None = None ) -> DataFrame: + @final class PowerIterationClustering(LogicalPlan): def __init__( self, @@ -930,6 +958,7 @@ def __init__( self.max_iter = max_iter self.weight_col = weight_col + @override def 
plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -946,25 +975,55 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: return plan return _dataframe_from_plan( - PowerIterationClustering(self._vertices, self._edges, k, maxIter, weightCol), + PowerIterationClustering( + self._vertices, self._edges, k, maxIter, weightCol + ), self._spark, ) - def shortestPaths(self, landmarks: list[str | int]) -> DataFrame: + def shortestPaths( + self, + landmarks: list[str | int], + algorithm: str, + use_local_checkpoints: bool, + checkpoint_interval: int, + storage_level: StorageLevel, + ) -> DataFrame: + @final class ShortestPaths(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame, landmarks: list[str | int]) -> None: + def __init__( + self, + v: DataFrame, + e: DataFrame, + landmarks: list[str | int], + algorithm: str, + use_local_checkpoints: bool, + checkpoint_interval: int, + storage_level: StorageLevel, + ) -> None: super().__init__(None) self.v = v self.e = e self.landmarks = landmarks + self.algorithm = algorithm + self.use_local_checkpoints = use_local_checkpoints + self.checkpoint_interval = checkpoint_interval + self.storage_level = storage_level + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) graphframes_api_call.shortest_paths.CopyFrom( pb.ShortestPaths( - landmarks=[make_str_or_long_id(raw_id) for raw_id in self.landmarks] + landmarks=[ + make_str_or_long_id(raw_id) for raw_id in self.landmarks + ], + algorithm=self.algorithm, + use_local_checkpoints=self.use_local_checkpoints, + checkpoint_interval=self.checkpoint_interval, + storage_level=storage_level_to_proto(self.storage_level), ) ) plan = self._create_proto_relation() @@ -972,10 +1031,20 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: return plan return _dataframe_from_plan( - ShortestPaths(self._vertices, self._edges, landmarks), self._spark + ShortestPaths( + self._vertices, + self._edges, + landmarks, + algorithm, + use_local_checkpoints, + checkpoint_interval, + storage_level, + ), + self._spark, ) def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: + @final class StronglyConnectedComponents(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: super().__init__(None) @@ -983,6 +1052,7 @@ def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: self.e = e self.max_iter = max_iter + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -1011,6 +1081,7 @@ def svdPlusPlus( gamma7: float = 0.015, return_loss: bool = False, # TODO: should it be True to mimic the classic API? 
) -> tuple[DataFrame, float]: + @final class SVDPlusPlus(LogicalPlan): def __init__( self, @@ -1037,6 +1108,7 @@ def __init__( self.gamma6 = gamma6 self.gamma7 = gamma7 + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session @@ -1079,20 +1151,31 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: else: return (output.drop("loss"), -1.0) - def triangleCount(self) -> DataFrame: + def triangleCount(self, storage_level: StorageLevel) -> DataFrame: + @final class TriangleCount(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame) -> None: + def __init__( + self, v: DataFrame, e: DataFrame, storage_level: StorageLevel + ) -> None: super().__init__(None) self.v = v self.e = e + self.storage_level = storage_level + @override def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.triangle_count.CopyFrom(pb.TriangleCount()) + graphframes_api_call.triangle_count.CopyFrom( + pb.TriangleCount( + storage_level=storage_level_to_proto(self.storage_level) + ) + ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan - return _dataframe_from_plan(TriangleCount(self._vertices, self._edges), self._spark) + return _dataframe_from_plan( + TriangleCount(self._vertices, self._edges), self._spark + ) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index b32014ae9..b82d3aeab 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any from pyspark.storagelevel import StorageLevel +from pyspark.sql import functions as F from pyspark.version import __version__ from typing_extensions import override @@ -32,15 +33,25 @@ def is_remote() -> bool: return False -from pyspark.sql import SparkSession - from graphframes.classic.graphframe import GraphFrame as GraphFrameClassic from graphframes.lib import Pregel if TYPE_CHECKING: from pyspark.sql import Column, DataFrame - from graphframes.connect.graphframe_client import GraphFrameConnect + from graphframes.connect.graphframes_client import GraphFrameConnect + +"""Constant for the vertices ID column name.""" +ID = "id" + +"""Constant for the edge src column name.""" +SRC = "src" + +"""Constant for the edge dst column name.""" +DST = "dst" + +"""Constant for the edge column name.""" +EDGE = "edge" class GraphFrame: @@ -62,17 +73,27 @@ class GraphFrame: """ @staticmethod - def _from_impl(impl: GraphFrameClassic | "GraphFrameConnect") -> "GraphFrame": + def _from_impl(impl: "GraphFrameClassic | GraphFrameConnect") -> "GraphFrame": return GraphFrame(impl.vertices, impl.edges) def __init__(self, v: DataFrame, e: DataFrame) -> None: - self._impl: GraphFrameClassic | "GraphFrameConnect" + """ + Initialize a GraphFrame from vertex DataFrame and edges DataFrame. + + :param v: :class:`DataFrame` holding vertex information. + Must contain a column named "id" that stores unique + vertex IDs. + :param e: :class:`DataFrame` holding edge information. + Must contain two columns "src" and "dst" storing source + vertex IDs and destination vertex IDs of edges, respectively. 
+ """ + self._impl: "GraphFrameClassic | GraphFrameConnect" if is_remote(): - from graphframes.connect.graphframe_client import GraphFrameConnect + from graphframes.connect.graphframes_client import GraphFrameConnect - self._impl = GraphFrameConnect(v, e) + self._impl = GraphFrameConnect(v, e) # ty: ignore[invalid-argument-type] else: - self._impl = GraphFrameClassic(v, e) + self._impl = GraphFrameClassic(v, e) # ty: ignore[invalid-argument-type] @property def vertices(self) -> DataFrame: @@ -106,7 +127,9 @@ def cache(self) -> "GraphFrame": """ return GraphFrame._from_impl(self._impl.cache()) - def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": + def persist( + self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + ) -> "GraphFrame": """Persist the dataframe representation of vertices and edges of the graph with the given storage level. """ @@ -174,7 +197,9 @@ def triplets(self) -> DataFrame: @property def pregel(self) -> Pregel: """ - Get the :class:`graphframes.lib.Pregel` object for running pregel. + Get the :class:`graphframes.classic.pregel.Pregel` + or :class`graphframes.connect.graphframes_client.Pregel` + object for running pregel. See :class:`graphframes.lib.Pregel` for more details. """ @@ -218,6 +243,35 @@ def dropIsolatedVertices(self) -> "GraphFrame": """ return GraphFrame._from_impl(self._impl.dropIsolatedVertices()) + def detectingCycles( + self, + checkpoint_interval: int = 2, + use_local_checkpoints: bool = False, + storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER, + ) -> DataFrame: + """Find all cycles in the graph. + + An implementation of the Rocha–Thatte cycle detection algorithm. + Rocha, Rodrigo Caetano, and Bhalchandra D. Thatte. "Distributed cycle detection in + large-scale sparse graphs." Proceedings of Simpósio Brasileiro de Pesquisa Operacional + (SBPO’15) (2015): 1-11. + + Returns a DataFrame with ID and cycles, ID are not unique if there are multiple cycles + starting from this ID. For the case of cycle 1 -> 2 -> 3 -> 1 all the vertices will have the + same cycle! E.g.: 1 -> [1, 2, 3, 1] 2 -> [2, 3, 1, 2] 3 -> [3, 1, 2, 3] + + Deduplication of cycles should be done by the user! + + :param checkpoint_interval: Pregel checkpoint interval, default is 2 + :param use_local_checkpoints: should local checkpoints be used instead of checkpointDir + :storage_level: the level of storage for both intermediate results and an output DataFrame + + :return: Persisted DataFrame with all the cycles + """ + return self._impl.detectingCycles( + checkpoint_interval, use_local_checkpoints, storage_level + ) + def bfs( self, fromExpr: str, @@ -242,9 +296,9 @@ def bfs( def aggregateMessages( self, aggCol: list[Column | str] | Column, - sendToSrc: list[Column | str] | Column | str = list(), - sendToDst: list[Column | str] | Column | str = list(), - intermediate_storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK, + sendToSrc: list[Column | str] | Column | str | None = None, + sendToDst: list[Column | str] | Column | str | None = None, + intermediate_storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER, ) -> DataFrame: """ Aggregates messages from the neighbours. @@ -266,10 +320,15 @@ def aggregateMessages( :param intermediate_storage_level: the level of intermediate storage that will be used for both intermediate result and the output. - :return: DataFrame with columns for the vertex ID and the resulting aggregated message. 
+ :return: Persisted DataFrame with columns for the vertex ID and the resulting aggregated message. The name of the resulted message column is based on the alias of the provided aggCol! """ # noqa: E501 + if sendToDst is None: + sendToDst = [] + if sendToSrc is None: + sendToSrc = [] + # Back-compatibility workaround if not isinstance(aggCol, list): warnings.warn( @@ -300,7 +359,9 @@ def aggregateMessages( raise TypeError("At least one aggregation column should be provided!") if (len(sendToSrc) == 0) and (len(sendToDst) == 0): - raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") + raise ValueError( + "Either `sendToSrc`, `sendToDst`, or both have to be provided" + ) return self._impl.aggregateMessages( aggCol=aggCol, sendToSrc=sendToSrc, @@ -316,6 +377,9 @@ def connectedComponents( checkpointInterval: int = 2, broadcastThreshold: int = 1000000, useLabelsAsComponents: bool = False, + use_local_checkpoints: bool = False, + max_iter: int = 2 ^ 31 - 2, + storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER, ) -> DataFrame: """ Computes the connected components of the graph. @@ -337,6 +401,9 @@ def connectedComponents( checkpointInterval=checkpointInterval, broadcastThreshold=broadcastThreshold, useLabelsAsComponents=useLabelsAsComponents, + use_local_checkpoints=use_local_checkpoints, + max_iter=max_iter, + storage_level=storage_level, ) def labelPropagation(self, maxIter: int) -> DataFrame: @@ -480,21 +547,86 @@ def powerIterationClustering( """ # noqa: E501 return self._impl.powerIterationClustering(k, maxIter, weightCol) + def validate( + self, + check_vertices: bool = True, + intermediate_storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER, + ) -> None: + """ + Validates the consistency and integrity of a graph by performing checks on the vertices and + edges. + + :param check_vertices: a flag to indicate whether additional vertex consistency checks + should be performed. If true, the method will verify that all vertices in the vertex + DataFrame are represented in the edge DataFrame and vice versa. It is slow on big graphs. + :param intermediate_storage_level: the storage level to be used when persisting + intermediate DataFrame computations during the validation process. + :return: Unit, as the method performs validation checks and throws an exception if + validation fails. + :raises ValueError: if there are any inconsistencies in the graph, such as duplicate + vertices, mismatched vertices between edges and vertex DataFrames or missing + connections. + """ + persisted_vertices = self.vertices.persist(intermediate_storage_level) + row = persisted_vertices.select(F.count_distinct(F.col(ID))).first() + assert row is not None # for type checker + count_distinct_vertices = row[0] + assert isinstance(count_distinct_vertices, int) # for type checker + total_count_vertices = persisted_vertices.count() + if count_distinct_vertices != total_count_vertices: + raise ValueError( + f"Graph contains ({total_count_vertices - count_distinct_vertices}) duplicate vertices." + ) + if check_vertices: + vertices_set_from_edges = ( + self.edges.select(F.col(SRC).alias(ID)) + .union(self.edges.select(F.col(DST).alias(ID))) + .distinct() + .persist(intermediate_storage_level) + ) + count_vertices_from_edges = vertices_set_from_edges.count() + if count_vertices_from_edges > count_distinct_vertices: + raise ValueError( + f"Graph is inconsistent: edges has {count_vertices_from_edges} " + + f"vertices, but vertices has {count_distinct_vertices} vertices." 
+ ) -def _test(): - import doctest + combined = vertices_set_from_edges.join(self.vertices, ID, "left_anti") + count_of_bad_vertices = combined.count() + if count_of_bad_vertices > 0: + raise ValueError( + "Vertices DataFrame does not contain all edges src/dst. " + + f"Found {count_of_bad_vertices} edges src/dst that are not in the vertices DataFrame." + ) + _ = persisted_vertices.unpersist() + _ = vertices_set_from_edges.unpersist() - import graphframe + def as_undirected(self) -> "GraphFrame": + """ + Converts the directed graph into an undirected graph by ensuring that all directed edges are + bidirectional. For every directed edge (src, dst), a corresponding edge (dst, src) is added. + + :return: A new GraphFrame representing the undirected graph. + """ + + edge_attr_columns = [c for c in self.edges.columns if c not in [SRC, DST]] + + # Create the undirected edges by duplicating each edge in both directions + forward_edges = self.edges.select( + F.col(SRC), F.col(DST), F.struct(*edge_attr_columns).alias(EDGE) + ) + backward_edges = self.edges.select( + F.col(DST).alias(SRC), + F.col(SRC).alias(DST), + F.struct(*edge_attr_columns).alias(EDGE), + ) + new_edges = forward_edges.union(backward_edges).select(SRC, DST, EDGE) - globs = graphframe.__dict__.copy() - globs["spark"] = SparkSession.builder.master("local[4]").appName("PythonTest").getOrCreate() - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE - ) - globs["spark"].stop() - if failure_count: - exit(-1) + # Preserve additional edge attributes + edge_columns = [F.col(EDGE).getField(c).alias(c) for c in edge_attr_columns] + # Select all columns including the new edge attributes + selected_columns = [F.col(SRC), F.col(DST)] + edge_columns + new_edges = new_edges.select(*selected_columns) -if __name__ == "__main__": - _test() + return GraphFrame(self.vertices, new_edges) diff --git a/python/graphframes/lib/__init__.py b/python/graphframes/lib/__init__.py new file mode 100644 index 000000000..076dd5232 --- /dev/null +++ b/python/graphframes/lib/__init__.py @@ -0,0 +1,4 @@ +from .aggregate_messages import AggregateMessages +from .pregel import Pregel + +__all__ = ["AggregateMessages", "Pregel"] diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/lib/aggregate_messages.py new file mode 100644 index 000000000..a164aeca6 --- /dev/null +++ b/python/graphframes/lib/aggregate_messages.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from pyspark.sql import Column +from pyspark.sql import functions as F + +from graphframes import graphframe + + +class AggregateMessages: + """Collection of utilities usable with :meth:`graphframes.GraphFrame.aggregateMessages()`.""" + + @staticmethod + def src() -> Column: + """Reference for source column, used for specifying messages.""" + return F.col(graphframe.SRC) + + @staticmethod + def dst() -> Column: + """Reference for destination column, used for specifying messages.""" + return F.col(graphframe.DST) + + @staticmethod + def edge() -> Column: + """Reference for edge column, used for specifying messages.""" + return F.col(graphframe.EDGE) + + @staticmethod + def msg() -> Column: + """Reference for message column, used for specifying aggregation function.""" + return F.col("MSG") diff --git a/python/graphframes/classic/pregel.py b/python/graphframes/lib/pregel.py similarity index 58% rename from python/graphframes/classic/pregel.py rename to python/graphframes/lib/pregel.py index 11bf12e5e..2dfdf7888 100644 --- a/python/graphframes/classic/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -15,23 +15,23 @@ # limitations under the License. # -import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, final -if sys.version > "3": - basestring = str +from graphframes.classic.utils import storage_level_to_jvm from pyspark.ml.wrapper import JavaWrapper from pyspark.sql import DataFrame, SparkSession from pyspark.sql.functions import col if TYPE_CHECKING: - from graphframes import GraphFrame + from graphframes.classic.graphframe import GraphFrame + from pyspark.sql import Column + from pyspark.storagelevel import StorageLevel +@final class Pregel(JavaWrapper): - """ - Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. + """Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. See `Malewicz et al., Pregel: a system for large-scale graph processing `_ for a detailed description of the Pregel algorithm. @@ -54,49 +54,26 @@ class Pregel(JavaWrapper): You can control the number of iterations by :func:`setMaxIter` and check API docs for advanced controls. :param graph: a :class:`graphframes.GraphFrame` object holding a graph with vertices and edges stored as DataFrames. - - >>> from graphframes import GraphFrame - >>> from pyspark.sql.functions import coalesce, col, lit, sum, when - >>> from graphframes.lib import Pregel - >>> edges = spark.createDataFrame([[0, 1], - ... [1, 2], - ... [2, 4], - ... [2, 0], - ... [3, 4], # 3 has no in-links - ... [4, 0], - ... [4, 2]], ["src", "dst"]) - >>> edges.cache() - >>> vertices = spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) - >>> numVertices = vertices.count() - >>> vertices = GraphFrame(vertices, edges).outDegrees - >>> vertices.cache() - >>> graph = GraphFrame(vertices, edges) - >>> alpha = 0.15 - >>> ranks = graph.pregel \ - ... .setMaxIter(5) \ - ... .withVertexColumn("rank", lit(1.0 / numVertices), \ - ... coalesce(Pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + lit(alpha / numVertices)) \ - ... .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) \ - ... .aggMsgs(sum(Pregel.msg())) \ - ... 
.run() """ # noqa: E501 def __init__(self, graph: "GraphFrame") -> None: super(Pregel, self).__init__() self.graph = graph - self._java_obj = self._new_java_obj("org.graphframes.lib.Pregel", graph._jvm_graph) + self._java_obj = self._new_java_obj( + "org.graphframes.lib.Pregel", graph._jvm_graph + ) def setMaxIter(self, value: int) -> "Pregel": - """ - Sets the max number of iterations (default: 10). + """Sets the max number of iterations (default: 10). + + :param value: the number of Pregel iterations """ self._java_obj.setMaxIter(int(value)) return self def setCheckpointInterval(self, value: int) -> "Pregel": - """ - Sets the number of iterations between two checkpoints (default: 2). + """Sets the number of iterations between two checkpoints (default: 2). This is an advanced control to balance query plan optimization and checkpoint data I/O cost. In most cases, you should keep the default value. @@ -107,8 +84,7 @@ def setCheckpointInterval(self, value: int) -> "Pregel": return self def setEarlyStopping(self, value: bool) -> "Pregel": - """ - Set should Pregel stop earlier in case of no new messages to send or not. + """Set should Pregel stop earlier in case of no new messages to send or not. Early stopping allows to terminate Pregel before reaching maxIter by checking if there are any non-null messages. While in some cases it may gain significant performance boost, in other cases it can lead to performance degradation, @@ -124,10 +100,9 @@ def setEarlyStopping(self, value: bool) -> "Pregel": return self def withVertexColumn( - self, colName: str, initialExpr: Any, updateAfterAggMsgsExpr: Any + self, colName: str, initialExpr: Column, updateAfterAggMsgsExpr: Column ) -> "Pregel": - """ - Defines an additional vertex column at the start of run and how to update it in each iteration. + """Defines an additional vertex column at the start of run and how to update it in each iteration. You can call it multiple times to add more than one additional vertex columns. @@ -140,12 +115,13 @@ def withVertexColumn( aggregated message column using :func:`msg`. If the vertex received no messages, the message column would be null. """ # noqa: E501 - self._java_obj.withVertexColumn(colName, initialExpr._jc, updateAfterAggMsgsExpr._jc) + self._java_obj.withVertexColumn( + colName, initialExpr._jc, updateAfterAggMsgsExpr._jc + ) return self - def sendMsgToSrc(self, msgExpr: Any) -> "Pregel": - """ - Defines a message to send to the source vertex of each edge triplet. + def sendMsgToSrc(self, msgExpr: Column) -> "Pregel": + """Defines a message to send to the source vertex of each edge triplet. You can call it multiple times to send more than one messages. @@ -160,9 +136,8 @@ def sendMsgToSrc(self, msgExpr: Any) -> "Pregel": self._java_obj.sendMsgToSrc(msgExpr._jc) return self - def sendMsgToDst(self, msgExpr: Any) -> "Pregel": - """ - Defines a message to send to the destination vertex of each edge triplet. + def sendMsgToDst(self, msgExpr: Column) -> "Pregel": + """Defines a message to send to the destination vertex of each edge triplet. You can call it multiple times to send more than one messages. @@ -177,9 +152,8 @@ def sendMsgToDst(self, msgExpr: Any) -> "Pregel": self._java_obj.sendMsgToDst(msgExpr._jc) return self - def aggMsgs(self, aggExpr: Any) -> "Pregel": - """ - Defines how messages are aggregated after grouped by target vertex IDs. + def aggMsgs(self, aggExpr: Column) -> "Pregel": + """Defines how messages are aggregated after grouped by target vertex IDs. 
 :param aggExpr: the message aggregation expression, such as `sum(Pregel.msg())`.
             You can reference the message column by :func:`msg` and the vertex ID by `col("id")`,
@@ -188,27 +162,100 @@ def aggMsgs(self, aggExpr: Any) -> "Pregel":
         self._java_obj.aggMsgs(aggExpr._jc)
         return self
 
-    def run(self) -> DataFrame:
+    def setStopIfAllNonActiveVertices(self, value: bool) -> Self:
+        """Sets whether Pregel should stop when all the vertices have voted to halt.
+
+        Activity (the vote) is determined based on the activity_col.
+        See :func:`setInitialActiveVertexExpression` and :func:`setUpdateActiveVertexExpression` for details
+        on how to set and update the activity_col.
+
+        Be aware that checking the vote is not free: it triggers a Spark action. If this
+        condition is enabled but not realistically reachable, it will only slow down the algorithm.
+
+        :param value: the boolean value.
         """
-        Runs the defined Pregel algorithm.
+        self._java_obj.setStopIfAllNonActiveVertices(value)
+        return self
+
+    def setInitialActiveVertexExpression(self, value: Column) -> Self:
+        """Sets the initial expression for the active vertex column.
+
+        The active vertex column is used to determine each vertex's voting result on each iteration of Pregel.
+        This expression is evaluated on the initial vertices DataFrame to set the initial state of the activity column.
+
+        :param value: expression to compute the initial active state of vertices.
+            You can reference all original vertex columns in this expression.
+        """  # noqa: E501
+        self._java_obj.setInitialActiveVertexExpression(value._jc)
+        return self
+
+    def setUpdateActiveVertexExpression(self, value: Column) -> Self:
+        """Sets the expression to update the active vertex column.
+
+        The active vertex column is used to determine each vertex's voting result on each iteration of Pregel.
+        This expression is evaluated on the updated vertices DataFrame to set the new state of the activity column.
+
+        :param value: expression to compute the new active state of vertices.
+            You can reference all original vertex columns and additional vertex columns in this expression.
+        """  # noqa: E501
+        self._java_obj.setUpdateActiveVertexExpression(value._jc)
+        return self
+
+    def setSkipMessagesFromNonActiveVertices(self, value: bool) -> Self:
+        """Sets whether Pregel should skip sending messages from non-active vertices.
+
+        When this option is enabled, messages will not be sent from vertices that are marked as inactive.
+        This can help optimize performance by avoiding unnecessary message propagation from inactive vertices.
+
+        :param value: boolean value.
+        """  # noqa: E501
+        self._java_obj.setSkipMessagesFromNonActiveVertices(value)
+        return self
+
+    def setUseLocalCheckpoints(self, value: bool) -> Self:
+        """Sets whether Pregel should use local checkpoints.
+
+        Local checkpoints are faster and do not require configuring persistent storage.
+        At the same time, local checkpoints are less reliable and may create a big load on the executors' local disks.
+
+        :param value: boolean value.
+        """  # noqa: E501
+        self._java_obj.setUseLocalCheckpoints(value)
+        return self
+
+    def setIntermediateStorageLevel(self, storage_level: StorageLevel) -> Self:
+        """Sets the intermediate storage level.
+        On each iteration, Pregel caches results with the requested storage level.
+
+        For very big graphs it is recommended to use DISK_ONLY.
+
+        :param storage_level: storage level to use.
+ """ # noqa: E501 + self._java_obj.setIntermediateStorageLevel( + storage_level_to_jvm(storage_level, self.graph.vertices.sparkSession) + ) + + def run(self) -> DataFrame: + """Runs the defined Pregel algorithm. :return: the result vertex DataFrame from the final iteration including both original and additional columns. """ # noqa: E501 - return DataFrame(self._java_obj.run(), SparkSession.getActiveSession()) + spark = SparkSession.getActiveSession() + if spark is None: + raise ValueError("SparkSession is dead or did not started.") + return DataFrame(self._java_obj.run(), spark) @staticmethod - def msg() -> Any: - """ - References the message column in aggregating messages and updating additional vertex columns. + def msg() -> Column: + """References the message column in aggregating messages and updating additional vertex columns. See :func:`aggMsgs` and :func:`withVertexColumn` """ # noqa: E501 return col("_pregel_msg_") @staticmethod - def src(colName: str) -> Any: - """ - References a source vertex column in generating messages to send. + def src(colName: str) -> Column: + """References a source vertex column in generating messages to send. See :func:`sendMsgToSrc` and :func:`sendMsgToDst` @@ -217,7 +264,7 @@ def src(colName: str) -> Any: return col("src." + colName) @staticmethod - def dst(colName: str) -> Any: + def dst(colName: str) -> Column: """ References a destination vertex column in generating messages to send. @@ -228,7 +275,7 @@ def dst(colName: str) -> Any: return col("dst." + colName) @staticmethod - def edge(colName: str) -> Any: + def edge(colName: str) -> Column: """ References an edge column in generating messages to send. From eaabc8f0498a038f2fff8c90a4d12573e0e0bec5 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 4 Oct 2025 10:18:22 +0200 Subject: [PATCH 07/17] simplify py api --- python/graphframes/classic/graphframe.py | 98 +-------- python/graphframes/classic/utils.py | 4 +- .../graphframes/connect/graphframes_client.py | 191 ++++-------------- python/graphframes/graphframe.py | 103 +++++++--- python/graphframes/lib/pregel.py | 18 +- 5 files changed, 128 insertions(+), 286 deletions(-) diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index 73cf2ab16..f9191f8cb 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -18,16 +18,15 @@ from typing import final - from py4j.java_gateway import JavaObject +from pyspark.core.context import SparkContext +from pyspark.sql import SparkSession from pyspark.sql.classic.column import Column, _to_seq from pyspark.sql.classic.dataframe import DataFrame -from pyspark.sql import SparkSession -from pyspark.core.context import SparkContext from pyspark.storagelevel import StorageLevel -from graphframes.lib import Pregel from graphframes.classic.utils import storage_level_to_jvm +from graphframes.lib import Pregel def _from_java_gf(jgf: JavaObject, spark: SparkSession) -> "GraphFrame": @@ -64,74 +63,10 @@ def __init__(self, v: DataFrame, e: DataFrame) -> None: self._sc = self._spark._sc self._jvm_gf_api = _java_api(self._sc) - self.ID: str = "id" - self.SRC: str = "src" - self.DST: str = "edge" self._ATTR: str = self._jvm_gf_api.ATTR() - # Check that provided DataFrames contain required columns - if self.ID not in v.columns: - raise ValueError( - "Vertex ID column {} missing from vertex DataFrame, which has columns: {}".format( - self.ID, ",".join(v.columns) - ) - ) - if self.SRC not in e.columns: - raise ValueError( - "Source 
vertex ID column {} missing from edge DataFrame, which has columns: {}".format( # noqa: E501 - self.SRC, ",".join(e.columns) - ) - ) - if self.DST not in e.columns: - raise ValueError( - "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}".format( # noqa: E501 - self.DST, ",".join(e.columns) - ) - ) - self._jvm_graph = self._jvm_gf_api.createGraph(v._jdf, e._jdf) - @property - def vertices(self) -> DataFrame: - return self._vertices - - @property - def edges(self) -> DataFrame: - return self._edges - - def __repr__(self) -> str: - return self._jvm_graph.toString() - - def cache(self) -> "GraphFrame": - self._jvm_graph.cache() - return self - - def persist( - self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY - ) -> "GraphFrame": - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) - self._jvm_graph.persist(javaStorageLevel) - return self - - def unpersist(self, blocking: bool = False) -> "GraphFrame": - self._jvm_graph.unpersist(blocking) - return self - - @property - def outDegrees(self) -> DataFrame: - jdf = self._jvm_graph.outDegrees() - return DataFrame(jdf, self._spark) - - @property - def inDegrees(self) -> DataFrame: - jdf = self._jvm_graph.inDegrees() - return DataFrame(jdf, self._spark) - - @property - def degrees(self) -> DataFrame: - jdf = self._jvm_graph.degrees() - return DataFrame(jdf, self._spark) - @property def triplets(self) -> DataFrame: jdf = self._jvm_graph.triplets() @@ -191,10 +126,7 @@ def bfs( maxPathLength: int = 10, ) -> DataFrame: builder = ( - self._jvm_graph.bfs() - .fromExpr(fromExpr) - .toExpr(toExpr) - .maxPathLength(maxPathLength) + self._jvm_graph.bfs().fromExpr(fromExpr).toExpr(toExpr).maxPathLength(maxPathLength) ) if edgeFilter is not None: builder.edgeFilter(edgeFilter) @@ -255,9 +187,7 @@ def aggregateMessages( jdf = builder.aggCol(aggCol[0]) elif len(aggCol) > 1: if all(isinstance(x, Column) for x in aggCol): - jdf = builder.aggCol( - aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol]) - ) + jdf = builder.aggCol(aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol])) elif all(isinstance(x, str) for x in aggCol): jdf = builder.aggCol(aggCol[0], _to_seq(self._sc, aggCol[1:])) else: @@ -284,9 +214,7 @@ def connectedComponents( .setUseLabelsAsComponents(useLabelsAsComponents) .setUseLocalCheckpoints(use_local_checkpoints) .maxIter(max_iter) - .setIntermediateStorageLevel( - storage_level_to_jvm(storage_level, self._spark) - ) + .setIntermediateStorageLevel(storage_level_to_jvm(storage_level, self._spark)) .run() ) return DataFrame(jdf, self._spark) @@ -305,9 +233,7 @@ def labelPropagation( .setAlgorithm(algorithm) .setUseLocalCheckpoints(use_local_checkpoints) .setCheckpointInterval(checkpoint_interval) - .setIntermediateStorageLevel( - storage_level_to_jvm(storage_level, self._spark) - ) + .setIntermediateStorageLevel(storage_level_to_jvm(storage_level, self._spark)) .run() ) return DataFrame(jdf, self._spark) @@ -337,9 +263,9 @@ def parallelPersonalizedPageRank( sourceIds: list[str | int] | None = None, maxIter: int | None = None, ) -> "GraphFrame": - assert sourceIds is not None and len(sourceIds) > 0, ( - "Source vertices Ids sourceIds must be provided" - ) + assert ( + sourceIds is not None and len(sourceIds) > 0 + ), "Source vertices Ids sourceIds must be provided" assert maxIter is not None, "Max number of iterations maxIter must be provided" sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) builder = self._jvm_graph.parallelPersonalizedPageRank() @@ -363,9 +289,7 @@ def 
shortestPaths( .setAlgorithm(algorithm) .setUseLocalCheckpoints(use_local_checkpoints) .setCheckpointInterval(checkpoint_interval) - .setIntermediateStorageLevel( - storage_level_to_jvm(storage_level, self._spark) - ) + .setIntermediateStorageLevel(storage_level_to_jvm(storage_level, self._spark)) .run() ) return DataFrame(jdf, self._spark) diff --git a/python/graphframes/classic/utils.py b/python/graphframes/classic/utils.py index b6b39ddea..863ca0c5f 100644 --- a/python/graphframes/classic/utils.py +++ b/python/graphframes/classic/utils.py @@ -6,9 +6,7 @@ from pyspark.sql.classic.dataframe import SparkSession -def storage_level_to_jvm( - storage_level: StorageLevel, spark: SparkSession -) -> JavaObject: +def storage_level_to_jvm(storage_level: StorageLevel, spark: SparkSession) -> JavaObject: return spark._jvm.org.apache.spark.storage.StorageLevel.apply( storage_level.useDisk, storage_level.useMemory, diff --git a/python/graphframes/connect/graphframes_client.py b/python/graphframes/connect/graphframes_client.py index 68ba8d00e..da53c8a50 100644 --- a/python/graphframes/connect/graphframes_client.py +++ b/python/graphframes/connect/graphframes_client.py @@ -1,6 +1,6 @@ from __future__ import annotations + from typing import final -from typing_extensions import override from pyspark.sql.connect import functions as F from pyspark.sql.connect import proto @@ -10,6 +10,7 @@ from pyspark.sql.connect.plan import LogicalPlan from pyspark.sql.connect.session import SparkSession from pyspark.storagelevel import StorageLevel +from typing_extensions import override try: from typing import Self @@ -171,25 +172,17 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: checkpoint_interval=self.checkpoint_interval, max_iter=self.max_iter, additional_col_name=self.vertex_col_name, - additional_col_initial=make_column_or_expr( - self.vertex_col_init, session - ), - additional_col_upd=make_column_or_expr( - self.vertex_col_upd, session - ), + additional_col_initial=make_column_or_expr(self.vertex_col_init, session), + additional_col_upd=make_column_or_expr(self.vertex_col_upd, session), early_stopping=self.early_stopping, use_local_checkpoints=self.use_local_checkpoints, storage_level=storage_level_to_proto(self.storage_level), stop_if_all_non_active=self.stop_if_all_non_active, skip_messages_from_non_active=self.skip_message_from_non_active, - initial_active_expr=make_column_or_expr( - self.initial_active_expr, session - ) + initial_active_expr=make_column_or_expr(self.initial_active_expr, session) if self.initial_active_expr is not None else None, - update_active_expr=make_column_or_expr( - self.update_active_expr, session - ) + update_active_expr=make_column_or_expr(self.update_active_expr, session) if self.update_active_expr is not None else None, ) @@ -254,35 +247,16 @@ def edge(colName: str) -> Column: @final class GraphFrameConnect: - ID: str = "id" - SRC: str = "src" - DST: str = "dst" - EDGE: str = "edge" + _ID: str = "id" + _SRC: str = "src" + _DST: str = "dst" + _EDGE: str = "edge" def __init__(self, v: DataFrame, e: DataFrame) -> None: self._vertices = v self._edges = e self._spark = v.sparkSession - if self.ID not in v.columns: - raise ValueError( - "Vertex ID column {} missing from vertex DataFrame, which has columns: {}".format( - self.ID, ",".join(v.columns) - ) - ) - if self.SRC not in e.columns: - raise ValueError( - "Source vertex ID column {} missing from edge DataFrame, which has columns: {}".format( # noqa: E501 - self.SRC, ",".join(e.columns) - ) - ) - if self.DST not in 
e.columns: - raise ValueError( - "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}".format( # noqa: E501 - self.DST, ",".join(e.columns) - ) - ) - @staticmethod def _get_pb_api_message( vertices: DataFrame, edges: DataFrame, client: SparkConnectClient @@ -292,65 +266,6 @@ def _get_pb_api_message( edges=dataframe_to_proto(edges, client), ) - @property - def vertices(self) -> DataFrame: - return self._vertices - - @property - def edges(self) -> DataFrame: - return self._edges - - @override - def __repr__(self) -> str: - # Exactly like in the scala core - v_cols = [self.ID] + [col for col in self.vertices.columns if col != self.ID] - e_cols = [self.SRC, self.DST] + [ - col for col in self.edges.columns if col not in {self.SRC, self.DST} - ] - v = self.vertices.select(*v_cols).__repr__() - e = self.edges.select(*e_cols).__repr__() - - return f"GraphFrame(v:{v}, e:{e})" - - def cache(self) -> "GraphFrameConnect": - new_vertices = self._vertices.cache() - new_edges = self._edges.cache() - return GraphFrameConnect(new_vertices, new_edges) - - def persist( - self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY - ) -> "GraphFrameConnect": - new_vertices = self._vertices.persist(storageLevel=storageLevel) - new_edges = self._edges.persist(storageLevel=storageLevel) - return GraphFrameConnect(new_vertices, new_edges) - - def unpersist(self, blocking: bool = False) -> "GraphFrameConnect": - new_vertices = self._vertices.unpersist(blocking=blocking) - new_edges = self._edges.unpersist(blocking=blocking) - return GraphFrameConnect(new_vertices, new_edges) - - @property - def outDegrees(self) -> DataFrame: - return self._edges.groupBy(F.col(self.SRC).alias(self.ID)).agg( - F.count("*").alias("outDegree") - ) - - @property - def inDegrees(self) -> DataFrame: - return self._edges.groupBy(F.col(self.DST).alias(self.ID)).agg( - F.count("*").alias("inDegree") - ) - - @property - def degrees(self) -> DataFrame: - return ( - self._edges.select( - F.explode(F.array(F.col(self.SRC), F.col(self.DST))).alias(self.ID) - ) - .groupBy(self.ID) - .agg(F.count("*").alias("degree")) - ) - @property def triplets(self) -> DataFrame: @final @@ -395,16 +310,12 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - return _dataframe_from_plan( - Find(self._vertices, self._edges, pattern), self._spark - ) + return _dataframe_from_plan(Find(self._vertices, self._edges, pattern), self._spark) def filterVertices(self, condition: str | Column) -> "GraphFrameConnect": @final class FilterVertices(LogicalPlan): - def __init__( - self, v: DataFrame, e: DataFrame, condition: str | Column - ) -> None: + def __init__(self, v: DataFrame, e: DataFrame, condition: str | Column) -> None: super().__init__(None) self.v = v self.e = e @@ -428,12 +339,12 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: ) # Exactly like in the scala-core new_edges = self._edges.join( - new_vertices.withColumn(self.SRC, F.col(self.ID)), - on=[self.SRC], + new_vertices.withColumn(self._SRC, F.col(self._ID)), + on=[self._SRC], how="left_semi", ).join( - new_vertices.withColumn(self.DST, F.col(self.ID)), - on=[self.DST], + new_vertices.withColumn(self._DST, F.col(self._ID)), + on=[self._DST], how="left_semi", ) return GraphFrameConnect(new_vertices, new_edges) @@ -441,9 +352,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def filterEdges(self, condition: str | Column) -> "GraphFrameConnect": @final class 
FilterEdges(LogicalPlan): - def __init__( - self, v: DataFrame, e: DataFrame, condition: str | Column - ) -> None: + def __init__(self, v: DataFrame, e: DataFrame, condition: str | Column) -> None: super().__init__(None) self.v = v self.e = e @@ -455,9 +364,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: self.v, self.e, session ) col_or_expr = make_column_or_expr(self.c, session) - graphframes_api_call.filter_edges.CopyFrom( - pb.FilterEdges(condition=col_or_expr) - ) + graphframes_api_call.filter_edges.CopyFrom(pb.FilterEdges(condition=col_or_expr)) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -530,9 +437,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.drop_isolated_vertices.CopyFrom( - pb.DropIsolatedVertices() - ) + graphframes_api_call.drop_isolated_vertices.CopyFrom(pb.DropIsolatedVertices()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -634,12 +539,8 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.aggregate_messages.CopyFrom( pb.AggregateMessages( agg_col=[make_column_or_expr(x, session) for x in self.agg_col], - send_to_src=[ - make_column_or_expr(x, session) for x in self.send2src - ], - send_to_dst=[ - make_column_or_expr(x, session) for x in self.send2dst - ], + send_to_src=[make_column_or_expr(x, session) for x in self.send2src], + send_to_dst=[make_column_or_expr(x, session) for x in self.send2dst], storage_level=storage_level_to_proto(self.storage_level), ) ) @@ -648,9 +549,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: return plan if (len(sendToSrc) == 0) and (len(sendToDst) == 0): - raise ValueError( - "Either `sendToSrc`, `sendToDst`, or both have to be provided" - ) + raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") return _dataframe_from_plan( AggregateMessages( @@ -794,19 +693,17 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: self._spark, ) - def _update_page_rank_edge_weights( - self, new_vertices: DataFrame - ) -> "GraphFrameConnect": + def _update_page_rank_edge_weights(self, new_vertices: DataFrame) -> "GraphFrameConnect": cols2select = self.edges.columns + ["weight"] new_edges = ( self._edges.join( - new_vertices.withColumn(self.SRC, F.col(self.ID)), - on=[self.SRC], + new_vertices.withColumn(self._SRC, F.col(self._ID)), + on=[self._SRC], how="inner", ) .join( - self.outDegrees.withColumn(self.SRC, F.col(self.ID)), - on=[self.SRC], + self.outDegrees.withColumn(self._SRC, F.col(self._ID)), + on=[self._SRC], how="inner", ) .withColumn("weight", F.col("pagerank") / F.col("outDegree")) @@ -849,9 +746,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: pb.PageRank( reset_probability=self.reset_prob, source_id=( - None - if self.source_id is None - else make_str_or_long_id(self.source_id) + None if self.source_id is None else make_str_or_long_id(self.source_id) ), max_iter=self.max_iter, tol=self.tol, @@ -911,9 +806,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.parallel_personalized_page_rank.CopyFrom( pb.ParallelPersonalizedPageRank( reset_probability=self.reset_prob, - source_ids=[ - make_str_or_long_id(raw_id) for raw_id in self.source_ids - ], + source_ids=[make_str_or_long_id(raw_id) for raw_id in self.source_ids], max_iter=self.max_iter, ) ) @@ -921,9 
+814,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - assert sourceIds is not None and len(sourceIds) > 0, ( - "Source vertices Ids sourceIds must be provided" - ) + assert ( + sourceIds is not None and len(sourceIds) > 0 + ), "Source vertices Ids sourceIds must be provided" assert maxIter is not None, "Max number of iterations maxIter must be provided" new_vertices = _dataframe_from_plan( @@ -975,9 +868,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: return plan return _dataframe_from_plan( - PowerIterationClustering( - self._vertices, self._edges, k, maxIter, weightCol - ), + PowerIterationClustering(self._vertices, self._edges, k, maxIter, weightCol), self._spark, ) @@ -1017,9 +908,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: ) graphframes_api_call.shortest_paths.CopyFrom( pb.ShortestPaths( - landmarks=[ - make_str_or_long_id(raw_id) for raw_id in self.landmarks - ], + landmarks=[make_str_or_long_id(raw_id) for raw_id in self.landmarks], algorithm=self.algorithm, use_local_checkpoints=self.use_local_checkpoints, checkpoint_interval=self.checkpoint_interval, @@ -1154,9 +1043,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def triangleCount(self, storage_level: StorageLevel) -> DataFrame: @final class TriangleCount(LogicalPlan): - def __init__( - self, v: DataFrame, e: DataFrame, storage_level: StorageLevel - ) -> None: + def __init__(self, v: DataFrame, e: DataFrame, storage_level: StorageLevel) -> None: super().__init__(None) self.v = v self.e = e @@ -1168,14 +1055,10 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: self.v, self.e, session ) graphframes_api_call.triangle_count.CopyFrom( - pb.TriangleCount( - storage_level=storage_level_to_proto(self.storage_level) - ) + pb.TriangleCount(storage_level=storage_level_to_proto(self.storage_level)) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan - return _dataframe_from_plan( - TriangleCount(self._vertices, self._edges), self._spark - ) + return _dataframe_from_plan(TriangleCount(self._vertices, self._edges), self._spark) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index b82d3aeab..b4c41979b 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -20,8 +20,8 @@ import warnings from typing import TYPE_CHECKING, Any -from pyspark.storagelevel import StorageLevel from pyspark.sql import functions as F +from pyspark.storagelevel import StorageLevel from pyspark.version import __version__ from typing_extensions import override @@ -72,6 +72,11 @@ class GraphFrame: >>> g = GraphFrame(v, e) """ + ID: str = ID + SRC: str = SRC + DST: str = DST + EDGE: str = EDGE + @staticmethod def _from_impl(impl: "GraphFrameClassic | GraphFrameConnect") -> "GraphFrame": return GraphFrame(impl.vertices, impl.edges) @@ -88,6 +93,24 @@ def __init__(self, v: DataFrame, e: DataFrame) -> None: vertex IDs and destination vertex IDs of edges, respectively. 
""" self._impl: "GraphFrameClassic | GraphFrameConnect" + if self.ID not in v.columns: + raise ValueError( + "Vertex ID column {} missing from vertex DataFrame, which has columns: {}".format( + self.ID, ",".join(v.columns) + ) + ) + if self.SRC not in e.columns: + raise ValueError( + "Source vertex ID column {} missing from edge DataFrame, which has columns: {}".format( # noqa: E501 + self.SRC, ",".join(e.columns) + ) + ) + if self.DST not in e.columns: + raise ValueError( + "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}".format( # noqa: E501 + self.DST, ",".join(e.columns) + ) + ) if is_remote(): from graphframes.connect.graphframes_client import GraphFrameConnect @@ -101,7 +124,7 @@ def vertices(self) -> DataFrame: :class:`DataFrame` holding vertex information, with unique column "id" for vertex IDs. """ - return self._impl.vertices + return self._impl._vertices @property def edges(self) -> DataFrame: @@ -110,7 +133,7 @@ def edges(self) -> DataFrame: "dst" storing source vertex IDs and destination vertex IDs of edges, respectively. """ - return self._impl.edges + return self._impl._edges @property def nodes(self) -> DataFrame: @@ -119,27 +142,39 @@ def nodes(self) -> DataFrame: @override def __repr__(self) -> str: - return self._impl.__repr__() + # Exactly like in the scala core + v_cols = [self.ID] + [col for col in self._impl._vertices.columns if col != self.ID] + e_cols = [self.SRC, self.DST] + [ + col for col in self._impl._edges.columns if col not in {self.SRC, self.DST} + ] + v = self._impl._vertices.select(*v_cols).__repr__() + e = self._impl._edges.select(*e_cols).__repr__() + + return f"GraphFrame(v:{v}, e:{e})" def cache(self) -> "GraphFrame": """Persist the dataframe representation of vertices and edges of the graph with the default storage level. """ - return GraphFrame._from_impl(self._impl.cache()) + new_vertices = self._impl._vertices.cache() + new_edges = self._impl._edges.cache() + return GraphFrame(new_vertices, new_edges) - def persist( - self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY - ) -> "GraphFrame": + def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": """Persist the dataframe representation of vertices and edges of the graph with the given storage level. """ - return GraphFrame._from_impl(self._impl.persist(storageLevel=storageLevel)) + new_vertices = self._impl._vertices.persist(storageLevel=storageLevel) + new_edges = self._impl._edges.persist(storageLevel=storageLevel) + return GraphFrame(new_vertices, new_edges) def unpersist(self, blocking: bool = False) -> "GraphFrame": """Mark the dataframe representation of vertices and edges of the graph as non-persistent, and remove all blocks for it from memory and disk. 
""" - return GraphFrame._from_impl(self._impl.unpersist(blocking=blocking)) + new_vertices = self._impl._vertices.unpersist(blocking=blocking) + new_edges = self._impl._edges.unpersist(blocking=blocking) + return GraphFrame(new_vertices, new_edges) @property def outDegrees(self) -> DataFrame: @@ -152,7 +187,9 @@ def outDegrees(self) -> DataFrame: :return: DataFrame with new vertices column "outDegree" """ - return self._impl.outDegrees + return self._impl._edges.groupBy(F.col(self.SRC).alias(self.ID)).agg( + F.count("*").alias("outDegree") + ) @property def inDegrees(self) -> DataFrame: @@ -165,7 +202,9 @@ def inDegrees(self) -> DataFrame: :return: DataFrame with new vertices column "inDegree" """ - return self._impl.inDegrees + return self._impl._edges.groupBy(F.col(self.DST).alias(self.ID)).agg( + F.count("*").alias("inDegree") + ) @property def degrees(self) -> DataFrame: @@ -178,7 +217,13 @@ def degrees(self) -> DataFrame: :return: DataFrame with new vertices column "degree" """ - return self._impl.degrees + return ( + self._impl._edges.select( + F.explode(F.array(F.col(self.SRC), F.col(self.DST))).alias(self.ID) + ) + .groupBy(self.ID) + .agg(F.count("*").alias("degree")) + ) @property def triplets(self) -> DataFrame: @@ -268,9 +313,7 @@ def detectingCycles( :return: Persisted DataFrame with all the cycles """ - return self._impl.detectingCycles( - checkpoint_interval, use_local_checkpoints, storage_level - ) + return self._impl.detectingCycles(checkpoint_interval, use_local_checkpoints, storage_level) def bfs( self, @@ -359,9 +402,7 @@ def aggregateMessages( raise TypeError("At least one aggregation column should be provided!") if (len(sendToSrc) == 0) and (len(sendToDst) == 0): - raise ValueError( - "Either `sendToSrc`, `sendToDst`, or both have to be provided" - ) + raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") return self._impl.aggregateMessages( aggCol=aggCol, sendToSrc=sendToSrc, @@ -566,7 +607,7 @@ def validate( :raises ValueError: if there are any inconsistencies in the graph, such as duplicate vertices, mismatched vertices between edges and vertex DataFrames or missing connections. - """ + """ # noqa: E501 persisted_vertices = self.vertices.persist(intermediate_storage_level) row = persisted_vertices.select(F.count_distinct(F.col(ID))).first() assert row is not None # for type checker @@ -574,9 +615,9 @@ def validate( assert isinstance(count_distinct_vertices, int) # for type checker total_count_vertices = persisted_vertices.count() if count_distinct_vertices != total_count_vertices: - raise ValueError( - f"Graph contains ({total_count_vertices - count_distinct_vertices}) duplicate vertices." - ) + _msg = "Graph contains ({}) duplicate vertices." + + raise ValueError(_msg.format(total_count_vertices - count_distinct_vertices)) if check_vertices: vertices_set_from_edges = ( self.edges.select(F.col(SRC).alias(ID)) @@ -586,20 +627,18 @@ def validate( ) count_vertices_from_edges = vertices_set_from_edges.count() if count_vertices_from_edges > count_distinct_vertices: - raise ValueError( - f"Graph is inconsistent: edges has {count_vertices_from_edges} " - + f"vertices, but vertices has {count_distinct_vertices} vertices." - ) + _msg = "Graph is inconsistent: edges has {count_vertices_from_edges} " + _msg += "vertices, but vertices has {} vertices." 
+ raise ValueError(_msg.format(count_distinct_vertices)) combined = vertices_set_from_edges.join(self.vertices, ID, "left_anti") count_of_bad_vertices = combined.count() if count_of_bad_vertices > 0: - raise ValueError( - "Vertices DataFrame does not contain all edges src/dst. " - + f"Found {count_of_bad_vertices} edges src/dst that are not in the vertices DataFrame." - ) - _ = persisted_vertices.unpersist() + _msg = "Vertices DataFrame does not contain all edges src/dst. " + _msg += "Found {} edges src/dst that are not in the vertices DataFrame." + raise ValueError(_msg.format(count_of_bad_vertices)) _ = vertices_set_from_edges.unpersist() + _ = persisted_vertices.unpersist() def as_undirected(self) -> "GraphFrame": """ diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/lib/pregel.py index 2dfdf7888..2f2049184 100644 --- a/python/graphframes/lib/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -17,16 +17,18 @@ from typing import TYPE_CHECKING, final -from graphframes.classic.utils import storage_level_to_jvm - from pyspark.ml.wrapper import JavaWrapper from pyspark.sql import DataFrame, SparkSession from pyspark.sql.functions import col +from graphframes.classic.utils import storage_level_to_jvm + if TYPE_CHECKING: - from graphframes.classic.graphframe import GraphFrame from pyspark.sql import Column from pyspark.storagelevel import StorageLevel + from typing_extensions import Self + + from graphframes.classic.graphframe import GraphFrame @final @@ -60,9 +62,7 @@ def __init__(self, graph: "GraphFrame") -> None: super(Pregel, self).__init__() self.graph = graph - self._java_obj = self._new_java_obj( - "org.graphframes.lib.Pregel", graph._jvm_graph - ) + self._java_obj = self._new_java_obj("org.graphframes.lib.Pregel", graph._jvm_graph) def setMaxIter(self, value: int) -> "Pregel": """Sets the max number of iterations (default: 10). @@ -115,9 +115,7 @@ def withVertexColumn( aggregated message column using :func:`msg`. If the vertex received no messages, the message column would be null. """ # noqa: E501 - self._java_obj.withVertexColumn( - colName, initialExpr._jc, updateAfterAggMsgsExpr._jc - ) + self._java_obj.withVertexColumn(colName, initialExpr._jc, updateAfterAggMsgsExpr._jc) return self def sendMsgToSrc(self, msgExpr: Column) -> "Pregel": @@ -173,7 +171,7 @@ def setStopIfAllNonActiveVertices(self, value: bool) -> Self: condition is not realistically reachable but set, it will just slow down the algorithm. :param value: the boolean value. 
- """ + """ # noqa: E501 self._java_obj.setStopIfAllNonActiveVertices(value) return self From a4b600440b415e002267e9d9c47edfb2a5083426 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 4 Oct 2025 10:23:20 +0200 Subject: [PATCH 08/17] fix import --- python/graphframes/classic/graphframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index f9191f8cb..a3b5eff27 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -19,7 +19,7 @@ from typing import final from py4j.java_gateway import JavaObject -from pyspark.core.context import SparkContext +from pyspark import SparkContext from pyspark.sql import SparkSession from pyspark.sql.classic.column import Column, _to_seq from pyspark.sql.classic.dataframe import DataFrame From 5d2a60f6282664071031f88c21d66e0a99cdf0b3 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 5 Oct 2025 10:20:16 +0200 Subject: [PATCH 09/17] updates --- python/dev/build_jar.py | 8 +- python/graphframes/classic/graphframe.py | 10 +- python/graphframes/classic/utils.py | 5 +- python/graphframes/graphframe.py | 82 ++++- python/graphframes/lib/aggregate_messages.py | 30 +- python/graphframes/lib/pregel.py | 18 +- python/run-tests.sh | 75 ---- python/tests/conftest.py | 29 +- python/tests/test_graphframes.py | 358 ++++++++----------- 9 files changed, 281 insertions(+), 334 deletions(-) delete mode 100755 python/run-tests.sh diff --git a/python/dev/build_jar.py b/python/dev/build_jar.py index ff081ffd7..94f8bbc0e 100644 --- a/python/dev/build_jar.py +++ b/python/dev/build_jar.py @@ -4,18 +4,20 @@ from pathlib import Path -def build(spark_versions: Sequence[str] = ["3.5.5"]): +def build(spark_versions: Sequence[str] = ["4.0.1"]): for spark_version in spark_versions: print("Building GraphFrames JAR...") print(f"SPARK_VERSION: {spark_version[:3]}") assert spark_version[:3] in {"3.5", "4.0"}, "Unsupported spark version!" 
project_root = Path(__file__).parent.parent.parent - sbt_executable = project_root.joinpath("build").joinpath("sbt").absolute().__str__() + sbt_executable = ( + project_root.joinpath("build").joinpath("sbt").absolute().__str__() + ) sbt_build_command = [ sbt_executable, f"-Dspark.version={spark_version}", - "package" + "package", ] sbt_build = subprocess.Popen( sbt_build_command, diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index a3b5eff27..3686271f9 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -112,7 +112,7 @@ def detectingCycles( .run() ) - return _from_java_gf(jdf, self._spark) + return DataFrame(jdf, self._spark) def dropIsolatedVertices(self) -> "GraphFrame": jdf = self._jvm_graph.dropIsolatedVertices() @@ -318,8 +318,12 @@ def svdPlusPlus( v = DataFrame(jdf, self._spark) return (v, loss) - def triangleCount(self) -> DataFrame: - jdf = self._jvm_graph.triangleCount().run() + def triangleCount(self, storage_level: StorageLevel) -> DataFrame: + jdf = ( + self._jvm_graph.triangleCount() + .setIntermediateStorageLevel(storage_level_to_jvm(storage_level, self._spark)) + .run() + ) return DataFrame(jdf, self._spark) def powerIterationClustering( diff --git a/python/graphframes/classic/utils.py b/python/graphframes/classic/utils.py index 863ca0c5f..646c1afa4 100644 --- a/python/graphframes/classic/utils.py +++ b/python/graphframes/classic/utils.py @@ -1,9 +1,6 @@ from py4j.java_gateway import JavaObject +from pyspark.sql import SparkSession from pyspark.storagelevel import StorageLevel -from typing_extensions import TYPE_CHECKING - -if TYPE_CHECKING: - from pyspark.sql.classic.dataframe import SparkSession def storage_level_to_jvm(storage_level: StorageLevel, spark: SparkSession) -> JavaObject: diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index b4c41979b..e7b6f82b2 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -79,7 +79,7 @@ class GraphFrame: @staticmethod def _from_impl(impl: "GraphFrameClassic | GraphFrameConnect") -> "GraphFrame": - return GraphFrame(impl.vertices, impl.edges) + return GraphFrame(impl._vertices, impl._edges) def __init__(self, v: DataFrame, e: DataFrame) -> None: """ @@ -447,16 +447,42 @@ def connectedComponents( storage_level=storage_level, ) - def labelPropagation(self, maxIter: int) -> DataFrame: + def labelPropagation( + self, + maxIter: int, + algorithm: str = "graphx", + use_local_checkpoints: bool = False, + checkpoint_interval: int = 2, + storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER, + ) -> DataFrame: """ Runs static label propagation for detecting communities in networks. See Scala documentation for more details. :param maxIter: the number of iterations to be performed - :return: DataFrame with new vertices column "label" - """ - return self._impl.labelPropagation(maxIter=maxIter) + :param algorithm: implementation to use, posible values are "graphframes" and "graphx"; + "graphx" is faster for small-medium sized graphs, + "graphframes" requires less amount of memory + :param use_local_checkpoints: should local checkpoints be used, default false; + local checkpoints are faster and does not require to set + a persistent checkpointDir; from the other side, local + checkpoints are less reliable and require executors to have + big enough local disks. 
+        :param checkpoint_interval: How often the intermediate result should be checkpointed;
+                              using a big value here may lead to huge logical plan growth due
+                              to the iterative nature of the algorithm.
+        :param storage_level: storage level for both intermediate and final dataframes.
+
+        :return: Persisted DataFrame with new vertices column "label"
+        """
+        return self._impl.labelPropagation(
+            maxIter=maxIter,
+            algorithm=algorithm,
+            use_local_checkpoints=use_local_checkpoints,
+            checkpoint_interval=checkpoint_interval,
+            storage_level=storage_level,
+        )
 
     def pageRank(
         self,
@@ -511,16 +537,42 @@ def parallelPersonalizedPageRank(
             )
         )
 
-    def shortestPaths(self, landmarks: list[Any]) -> DataFrame:
+    def shortestPaths(
+        self,
+        landmarks: list[str | int],
+        algorithm: str = "graphx",
+        use_local_checkpoints: bool = False,
+        checkpoint_interval: int = 2,
+        storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER,
+    ) -> DataFrame:
         """
         Runs the shortest path algorithm from a set of landmark vertices in the graph.
 
         See Scala documentation for more details.
 
         :param landmarks: a set of one or more landmarks
-        :return: DataFrame with new vertices column "distances"
-        """
-        return self._impl.shortestPaths(landmarks=landmarks)
+        :param algorithm: implementation to use, possible values are "graphframes" and "graphx";
+                          "graphx" is faster for small to medium sized graphs,
+                          "graphframes" requires less memory
+        :param use_local_checkpoints: whether local checkpoints should be used, default false;
+                                      local checkpoints are faster and do not require setting
+                                      a persistent checkpointDir; on the other hand, local
+                                      checkpoints are less reliable and require executors to have
+                                      big enough local disks.
+        :param checkpoint_interval: How often the intermediate result should be checkpointed;
+                              using a big value here may lead to huge logical plan growth due
+                              to the iterative nature of the algorithm.
+        :param storage_level: storage level for both intermediate and final dataframes.
+
+        :return: Persisted DataFrame with new vertices column "distances"
+        """  # noqa: E501
+        return self._impl.shortestPaths(
+            landmarks=landmarks,
+            algorithm=algorithm,
+            use_local_checkpoints=use_local_checkpoints,
+            checkpoint_interval=checkpoint_interval,
+            storage_level=storage_level,
+        )
 
     def stronglyConnectedComponents(self, maxIter: int) -> DataFrame:
         """
@@ -562,15 +614,21 @@ def svdPlusPlus(
             gamma7=gamma7,
         )
 
-    def triangleCount(self) -> DataFrame:
+    def triangleCount(self, storage_level: StorageLevel) -> DataFrame:
         """
         Counts the number of triangles passing through each vertex in this graph.
+        This implementation is based on computing the intersection of
+        vertex neighborhoods. It requires collecting the whole neighborhood of
+        each vertex. It may fail with memory errors on graphs with a power-law
+        degree distribution (graphs with a few very high-degree vertices). Consider
+        sampling edges in that case to get an approximate triangle count.
 
-        See Scala documentation for more details.
+        :param storage_level: storage level that is used for both
+                              intermediate and final dataframes.
:return: DataFrame with new vertex column "count" """ - return self._impl.triangleCount() + return self._impl.triangleCount(storage_level=storage_level) def powerIterationClustering( self, k: int, maxIter: int, weightCol: str | None = None diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/lib/aggregate_messages.py index a164aeca6..fce2cc478 100644 --- a/python/graphframes/lib/aggregate_messages.py +++ b/python/graphframes/lib/aggregate_messages.py @@ -16,31 +16,45 @@ # +from typing import Any + from pyspark.sql import Column from pyspark.sql import functions as F -from graphframes import graphframe + +class _ClassProperty: + """Custom read-only class property descriptor. + + The underlying method should take the class as the sole argument. + """ + + def __init__(self, f: Any) -> None: + self.f = f + self.__doc__ = f.__doc__ + + def __get__(self, instance: Any, owner: type) -> Any: + return self.f(owner) class AggregateMessages: """Collection of utilities usable with :meth:`graphframes.GraphFrame.aggregateMessages()`.""" - @staticmethod + @_ClassProperty def src() -> Column: """Reference for source column, used for specifying messages.""" - return F.col(graphframe.SRC) + return F.col("src") - @staticmethod + @_ClassProperty def dst() -> Column: """Reference for destination column, used for specifying messages.""" - return F.col(graphframe.DST) + return F.col("dst") - @staticmethod + @_ClassProperty def edge() -> Column: """Reference for edge column, used for specifying messages.""" - return F.col(graphframe.EDGE) + return F.col("edge") - @staticmethod + @_ClassProperty def msg() -> Column: """Reference for message column, used for specifying aggregation function.""" return F.col("MSG") diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/lib/pregel.py index 2f2049184..205b347e9 100644 --- a/python/graphframes/lib/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -15,21 +15,16 @@ # limitations under the License. # -from typing import TYPE_CHECKING, final +from typing import final from pyspark.ml.wrapper import JavaWrapper -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import Column, DataFrame, SparkSession from pyspark.sql.functions import col +from pyspark.storagelevel import StorageLevel +from typing_extensions import Self from graphframes.classic.utils import storage_level_to_jvm -if TYPE_CHECKING: - from pyspark.sql import Column - from pyspark.storagelevel import StorageLevel - from typing_extensions import Self - - from graphframes.classic.graphframe import GraphFrame - @final class Pregel(JavaWrapper): @@ -58,7 +53,7 @@ class Pregel(JavaWrapper): :param graph: a :class:`graphframes.GraphFrame` object holding a graph with vertices and edges stored as DataFrames. """ # noqa: E501 - def __init__(self, graph: "GraphFrame") -> None: + def __init__(self, graph: "GraphFrame") -> None: # noqa: F821 super(Pregel, self).__init__() self.graph = graph @@ -230,8 +225,9 @@ def setIntermediateStorageLevel(self, storage_level: StorageLevel) -> Self: :param storage_level: storage level to use. """ # noqa: E501 self._java_obj.setIntermediateStorageLevel( - storage_level_to_jvm(storage_level, self.graph.vertices.sparkSession) + storage_level_to_jvm(storage_level, self.graph._spark) ) + return self def run(self) -> DataFrame: """Runs the defined Pregel algorithm. 
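
As a usage reference (not part of the patch): the options documented above compose with the
existing fluent Pregel API. The sketch below re-creates the PageRank computation from the
doctest that this series removes, with the new tuning options wired in. It assumes a graph
`g` whose `vertices` already carry an `outDegree` column and a precomputed `num_vertices`,
both as in that doctest; the storage level and iteration counts are illustrative only. The
vote-to-halt controls (`setInitialActiveVertexExpression`, `setUpdateActiveVertexExpression`,
`setStopIfAllNonActiveVertices`, `setSkipMessagesFromNonActiveVertices`) plug into the same
chain and take Column expressions evaluated over the vertex columns.

    from pyspark.sql import functions as F
    from pyspark.storagelevel import StorageLevel

    from graphframes import GraphFrame
    from graphframes.lib import Pregel

    # Assumed inputs, as in the removed doctest: vertices include "id" and "outDegree",
    # edges include "src" and "dst", and num_vertices = vertices.count().
    g = GraphFrame(vertices, edges)
    alpha = 0.15

    ranks = (
        g.pregel
        .setMaxIter(10)
        .setEarlyStopping(True)                               # stop once no messages are produced
        .setCheckpointInterval(2)
        .setUseLocalCheckpoints(True)                         # no persistent checkpoint dir needed
        .setIntermediateStorageLevel(StorageLevel.DISK_ONLY)  # per-iteration cache level
        .withVertexColumn(
            "rank",
            F.lit(1.0 / num_vertices),
            F.coalesce(Pregel.msg(), F.lit(0.0)) * F.lit(1.0 - alpha) + F.lit(alpha / num_vertices),
        )
        .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
        .aggMsgs(F.sum(Pregel.msg()))
        .run()
    )
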
diff --git a/python/run-tests.sh b/python/run-tests.sh deleted file mode 100755 index 93a102168..000000000 --- a/python/run-tests.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Return on any failure -set -e - -# assumes run from python/ directory -if [ -z "$SPARK_HOME" ]; then - echo 'You need to set $SPARK_HOME to run these tests.' >&2 - exit 1 -fi - -# Honor the choice of python driver -if [ -z "$PYSPARK_PYTHON" ]; then - PYSPARK_PYTHON=`which python` -fi -# Override the python driver version as well to make sure we are in sync in the tests. -export PYSPARK_DRIVER_PYTHON=$PYSPARK_PYTHON -python_major=$($PYSPARK_PYTHON -c 'import sys; print(".".join(map(str, sys.version_info[:1])))') - -echo $pyver - -LIBS="" -for lib in "$SPARK_HOME/python/lib"/*zip ; do - LIBS=$LIBS:$lib -done - -# The current directory of the script. -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -a=( ${SCALA_VERSION//./ } ) -scala_version_major_minor="${a[0]}.${a[1]}" -echo "List of assembly jars found, the last one will be used:" -assembly_path="$DIR/../target/scala-$scala_version_major_minor" -echo `ls $assembly_path/graphframes-assembly*.jar` -JAR_PATH="" -for assembly in $assembly_path/graphframes-assembly*.jar ; do - JAR_PATH=$assembly -done - -export PYSPARK_SUBMIT_ARGS="--driver-memory 2g --executor-memory 2g --jars $JAR_PATH pyspark-shell " - -export PYTHONPATH=$PYTHONPATH:$SPARK_HOME/python:$LIBS:. - -export PYTHONPATH=$PYTHONPATH:graphframes - - -# Run test suites -poetry run python -m "pytest" -v $DIR/graphframes/tests.py 2>&1 | grep -vE "INFO (ParquetOutputFormat|SparkContext|ContextCleaner|ShuffleBlockFetcherIterator|MapOutputTrackerMaster|TaskSetManager|Executor|MemoryStore|CacheManager|BlockManager|DAGScheduler|PythonRDD|TaskSchedulerImpl|ZippedPartitionsRDD2)"; - -# Exit immediately if the tests fail. 
-# Since we pipe to remove the output, we need to use some horrible BASH features: -# http://stackoverflow.com/questions/1221833/bash-pipe-output-and-capture-exit-status -test ${PIPESTATUS[0]} -eq 0 || exit 1; - -# Run doc tests -cd "$DIR" - -poetry run python -u ./graphframes/graphframe.py "$@" diff --git a/python/tests/conftest.py b/python/tests/conftest.py index c4c87ca1b..91d83a96d 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os import pathlib import tempfile -from typing import Optional, Tuple import warnings import pytest @@ -18,11 +19,12 @@ def is_remote() -> bool: return False + spark_major_version = __version__[:1] scala_version = os.environ.get("SCALA_VERSION", "2.12" if __version__ < "4" else "2.13") -def get_gf_jar_locations() -> Tuple[str, str, str]: +def get_gf_jar_locations() -> tuple[str, str, str]: """ Returns a location of the GraphFrames JAR and GraphFrames Connect JAR. @@ -34,9 +36,9 @@ def get_gf_jar_locations() -> Tuple[str, str, str]: core_dir = project_root / "core" / "target" / f"scala-{scala_version}" connect_dir = project_root / "connect" / "target" / f"scala-{scala_version}" - graphx_jar: Optional[str] = None - core_jar: Optional[str] = None - connect_jar: Optional[str] = None + graphx_jar: str | None = None + core_jar: str | None = None + connect_jar: str | None = None for pp in graphx_dir.glob(f"graphframes-graphx-spark{spark_major_version}*.jar"): assert isinstance(pp, pathlib.PosixPath) # type checking @@ -64,7 +66,7 @@ def get_gf_jar_locations() -> Tuple[str, str, str]: raise ValueError( f"Failed to find graphframes connect jar for Spark {spark_major_version} in {connect_dir}" ) - + return core_jar, connect_jar, graphx_jar @@ -76,21 +78,26 @@ def spark(): (core_jar, connect_jar, graphx_jar) = get_gf_jar_locations() with tempfile.TemporaryDirectory() as tmp_dir: - builder = (SparkSession.Builder() + builder = ( + SparkSession.Builder() .appName("GraphFramesTest") .config("spark.sql.shuffle.partitions", 4) .config("spark.checkpoint.dir", tmp_dir) .config("spark.jars", f"{core_jar},{connect_jar},{graphx_jar}") + .config("spark.driver.memory", "4g") ) if spark_major_version == "3": # Spark 3 does not include connect by default - builder = builder.config("spark.jars.packages", f"org.apache.spark:spark-connect_{scala_version}:{__version__}") + builder = builder.config( + "spark.jars.packages", + f"org.apache.spark:spark-connect_{scala_version}:{__version__}", + ) if is_remote(): - builder = (builder - .remote("local[4]") - .config("spark.connect.extensions.relation.classes", "org.apache.spark.sql.graphframes.GraphFramesConnect") + builder = builder.remote("local[4]").config( + "spark.connect.extensions.relation.classes", + "org.apache.spark.sql.graphframes.GraphFramesConnect", ) else: builder = builder.master("local[4]") diff --git a/python/tests/test_graphframes.py b/python/tests/test_graphframes.py index a6a4b03a9..4da3fc59b 100644 --- a/python/tests/test_graphframes.py +++ b/python/tests/test_graphframes.py @@ -16,17 +16,49 @@ # +from dataclasses import dataclass +from pyspark.storagelevel import StorageLevel import pytest -from pyspark.sql import functions as sqlfunctions -from pyspark.sql.utils import is_remote +from pyspark.sql import DataFrame, SparkSession, functions as sqlfunctions from graphframes.classic.graphframe import _from_java_gf from graphframes.examples import BeliefPropagation, Graphs from graphframes.graphframe import GraphFrame -from graphframes.lib import 
AggregateMessages as AM -def test_construction(spark, local_g): +@dataclass +class PregelArguments: + algorithm: str + use_local_checkpoints: bool + checkpoint_interval: int + storage_level: StorageLevel + + +PREGEL_ARGUMENTS = [ + PregelArguments("graphframes", True, 5, StorageLevel.MEMORY_AND_DISK), + PregelArguments("graphx", False, 3, StorageLevel.DISK_ONLY), + PregelArguments("graphframes", False, 7, StorageLevel.MEMORY_ONLY), + PregelArguments("graphframes", True, 1, StorageLevel.DISK_ONLY_3), +] +PREGEL_IDS: list[str] = [ + "graphframes,local,5,MEMORY_AND_DISK", + "graphx,global,3,DISK_ONLY", + "graphframes,global,7,MEMORY_ONLY", + "graphframes,local,1,DISK_ONLY_3", +] +STORAGE_LEVELS = [ + StorageLevel.MEMORY_AND_DISK_2, + StorageLevel.DISK_ONLY, + StorageLevel.MEMORY_ONLY, +] +STORAGE_LEVELS_IDS = [ + "MEMORY_AND_DISK_2", + "DISK_ONLY", + "MEMORY_ONLY", +] + + +def test_construction(spark: SparkSession, local_g: GraphFrame): vertexIDs = [row[0] for row in local_g.vertices.select("id").collect()] assert sorted(vertexIDs) == [1, 2, 3] @@ -48,15 +80,15 @@ def test_construction(spark, local_g): [(1, 2), (2, 3), (3, 1)], ["invalid_colname_3", "invalid_colname_4"] ) with pytest.raises(ValueError): - GraphFrame(v_invalid, e_invalid) + _ = GraphFrame(v_invalid, e_invalid) -def test_cache(local_g): - local_g.cache() - local_g.unpersist() +def test_cache(local_g: GraphFrame): + _ = local_g.cache() + _ = local_g.unpersist() -def test_degrees(local_g): +def test_degrees(local_g: GraphFrame): outDeg = local_g.outDegrees assert set(outDeg.columns) == {"id", "outDegree"} inDeg = local_g.inDegrees @@ -65,13 +97,13 @@ def test_degrees(local_g): assert set(deg.columns) == {"id", "degree"} -def test_motif_finding(local_g): +def test_motif_finding(local_g: GraphFrame): motifs = local_g.find("(a)-[e]->(b)") assert motifs.count() == 3 assert set(motifs.columns) == {"a", "e", "b"} -def test_filterVertices(local_g): +def test_filterVertices(local_g: GraphFrame): conditions = ["id < 3", local_g.vertices.id < 3] expected_v = [(1, "A"), (2, "B")] expected_e = [(1, 2, "love"), (2, 1, "hate")] @@ -85,7 +117,7 @@ def test_filterVertices(local_g): assert set(e2) == set(expected_e) -def test_filterEdges(local_g): +def test_filterEdges(local_g: GraphFrame): conditions = ["dst > 2", local_g.edges.dst > 2] expected_v = [(1, "A"), (2, "B"), (3, "C")] expected_e = [(2, 3, "follow")] @@ -99,7 +131,7 @@ def test_filterEdges(local_g): assert set(e2) == set(expected_e) -def test_dropIsolatedVertices(local_g): +def test_dropIsolatedVertices(local_g: GraphFrame): g2 = local_g.filterEdges("dst > 2").dropIsolatedVertices() v2 = g2.vertices.select("id", "name").collect() e2 = g2.edges.select("src", "dst", "action").collect() @@ -111,7 +143,7 @@ def test_dropIsolatedVertices(local_g): assert set(e2) == set(expected_e) -def test_bfs(local_g): +def test_bfs(local_g: GraphFrame): paths = local_g.bfs("name='A'", "name='C'") assert paths is not None assert paths.count() == 1 @@ -127,7 +159,7 @@ def test_bfs(local_g): assert paths3.count() == 0 -def test_power_iteration_clustering(spark): +def test_power_iteration_clustering(spark: SparkSession): vertices = [ (1, 0, 0.5), (2, 0, 0.5), @@ -150,18 +182,15 @@ def test_power_iteration_clustering(spark): v=spark.createDataFrame(edges).toDF("id"), e=spark.createDataFrame(vertices).toDF("src", "dst", "weight"), ) + clusters_df = g.powerIterationClustering(k=2, maxIter=40, weightCol="weight") - clusters = [ - r["cluster"] - for r in g.powerIterationClustering(k=2, maxIter=40, 
weightCol="weight") - .sort("id") - .collect() - ] + clusters = [r["cluster"] for r in clusters_df.sort("id").collect()] assert clusters == [0, 0, 0, 0, 1, 0] + _ = clusters_df.unpersist() -def test_page_rank(spark): +def test_page_rank(spark: SparkSession): edges = spark.createDataFrame( [ [0, 1], @@ -174,13 +203,13 @@ def test_page_rank(spark): ], ["src", "dst"], ) - edges.cache() + _ = edges.cache() vertices = spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) numVertices = vertices.count() vertices = GraphFrame(vertices, edges).outDegrees - vertices.toPandas().head() - vertices.cache() + _ = vertices.toPandas().head() + _ = vertices.cache() # Construct a new GraphFrame with the updated vertices DataFrame. graph = GraphFrame(vertices, edges) @@ -206,8 +235,11 @@ def test_page_rank(spark): # Compare each result with its expected value using a tolerance of 1e-3. for a, b in zip(result, expected): assert a == pytest.approx(b, abs=1e-3) + _ = ranks.unpersist() + -def test_pregel_early_stopping(spark): +@pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) +def test_pregel_early_stopping(spark: SparkSession, args: PregelArguments): edges = spark.createDataFrame( [ [0, 1], @@ -220,20 +252,24 @@ def test_pregel_early_stopping(spark): ], ["src", "dst"], ) - edges.cache() + _ = edges.cache() vertices = spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) numVertices = vertices.count() vertices = GraphFrame(vertices, edges).outDegrees - vertices.toPandas().head() - vertices.cache() + _ = vertices.toPandas().head() + _ = vertices.cache() # Construct a new GraphFrame with the updated vertices DataFrame. graph = GraphFrame(vertices, edges) alpha = 0.15 pregel = graph.pregel ranks = ( - graph.pregel.setMaxIter(5).setEarlyStopping(True) + graph.pregel.setMaxIter(5) + .setEarlyStopping(True) + .setUseLocalCheckpoints(args.use_local_checkpoints) + .setIntermediateStorageLevel(args.storage_level) + .setCheckpointInterval(args.checkpoint_interval) .withVertexColumn( "rank", sqlfunctions.lit(1.0 / numVertices), @@ -252,129 +288,75 @@ def test_pregel_early_stopping(spark): # Compare each result with its expected value using a tolerance of 1e-3. for a, b in zip(result, expected): assert a == pytest.approx(b, abs=1e-3) + _ = ranks.unpersist() -def _hasCols(graph, vcols=[], ecols=[]): + +def _hasCols(graph: GraphFrame, vcols: list[str] = [], ecols: list[str] = []): for c in vcols: assert c in graph.vertices.columns, f"Vertex DataFrame missing column: {c}" for c in ecols: assert c in graph.edges.columns, f"Edge DataFrame missing column: {c}" -def _df_hasCols(df, vcols=[]): +def _df_hasCols(df: DataFrame, vcols: list[str] = []): for c in vcols: assert c in df.columns, f"DataFrame missing column: {c}" -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_aggregate_messages(examples, spark): - g = _from_java_gf(getattr(examples, "friends")(), spark) - # For each user, sum the ages of the adjacent users, - # plus 1 for the src's sum if the edge is "friend". - sendToSrc = AM.dst["age"] + sqlfunctions.when( - AM.edge["relationship"] == "friend", sqlfunctions.lit(1) - ).otherwise(0) - sendToDst = AM.src["age"] - agg = g.aggregateMessages( - sqlfunctions.sum(AM.msg).alias("summedAges"), - sendToSrc=sendToSrc, - sendToDst=sendToDst, - ) - # Run the aggregation again using SQL expressions as Strings. 
- agg2 = g.aggregateMessages( - "sum(MSG) AS `summedAges`", - sendToSrc="(dst['age'] + CASE WHEN (edge['relationship'] = 'friend') THEN 1 ELSE 0 END)", # noqa: E501 - sendToDst="src['age']", - ) - # Build mappings from id to the aggregated message. - aggMap = {row.id: row.summedAges for row in agg.select("id", "summedAges").collect()} - agg2Map = {row.id: row.summedAges for row in agg2.select("id", "summedAges").collect()} - # Compute the expected aggregation via brute force. - user2age = {row.id: row.age for row in g.vertices.select("id", "age").collect()} - trueAgg = {} - for src, dst, rel in g.edges.select("src", "dst", "relationship").collect(): - trueAgg[src] = trueAgg.get(src, 0) + user2age[dst] + (1 if rel == "friend" else 0) - trueAgg[dst] = trueAgg.get(dst, 0) + user2age[src] - # Verify both aggregations match the expected results. - assert aggMap == trueAgg, f"aggMap {aggMap} does not equal expected {trueAgg}" - assert agg2Map == trueAgg, f"agg2Map {agg2Map} does not equal expected {trueAgg}" - # Check that passing a wrong type for messages raises a TypeError. - with pytest.raises(TypeError): - g.aggregateMessages("sum(MSG) AS `summedAges`", sendToSrc=object(), sendToDst="src['age']") - with pytest.raises(TypeError): - g.aggregateMessages("sum(MSG) AS `summedAges`", sendToSrc=dst["age"], sendToDst=object()) - - -def test_connected_components(spark): +@pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) +@pytest.mark.parametrize( + "cc_args", + [(-1, True), (10000, True), (-1, False), (10000, False)], + ids=["aqe,local", "skewed,local", "aqe,checkpoints", "skewed,checkpoints"], +) +def test_connected_components( + spark: SparkSession, args: PregelArguments, cc_args: tuple[int, bool] +): v = spark.createDataFrame([(0, "a", "b")], ["id", "vattr", "gender"]) e = spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter("src > 10") v = spark.createDataFrame([(0, "a", "b")], ["id", "vattr", "gender"]) e = spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter("src > 10") g = GraphFrame(v, e) - comps = g.connectedComponents() + comps = g.connectedComponents( + algorithm=args.algorithm, + checkpointInterval=args.checkpoint_interval, + use_local_checkpoints=args.use_local_checkpoints, + storage_level=args.storage_level, + broadcastThreshold=cc_args[0], + useLabelsAsComponents=cc_args[1], + ) _df_hasCols(comps, vcols=["id", "component", "vattr", "gender"]) assert comps.count() == 1 - - -def test_connected_components2(spark): + _ = comps.unpersist() + + +@pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) +@pytest.mark.parametrize( + "cc_args", + [(-1, True), (10000, True), (-1, False), (10000, False)], + ids=["aqe,local", "skewed,local", "aqe,checkpoints", "skewed,checkpoints"], +) +def test_connected_components2( + spark: SparkSession, args: PregelArguments, cc_args: tuple[int, bool] +): v = spark.createDataFrame([(0, "a0", "b0"), (1, "a1", "b1")], ["id", "A", "B"]) e = spark.createDataFrame([(0, 1, "a01", "b01")], ["src", "dst", "A", "B"]) g = GraphFrame(v, e) - comps = g.connectedComponents() + comps = g.connectedComponents( + algorithm=args.algorithm, + checkpointInterval=args.checkpoint_interval, + use_local_checkpoints=args.use_local_checkpoints, + storage_level=args.storage_level, + broadcastThreshold=cc_args[0], + useLabelsAsComponents=cc_args[1], + ) _df_hasCols(comps, vcols=["id", "component", "A", "B"]) assert comps.count() == 2 + _ = comps.unpersist() -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def 
test_connected_components_friends(examples, spark): - g = _from_java_gf(getattr(examples, "friends")(), spark) - comps_tests = [ - g.connectedComponents(), - g.connectedComponents(broadcastThreshold=1), - g.connectedComponents(checkpointInterval=0), - g.connectedComponents(checkpointInterval=10), - g.connectedComponents(algorithm="graphx"), - g.connectedComponents(useLabelsAsComponents=True), - ] - for c in comps_tests: - assert c.groupBy("component").count().count() == 2 - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_label_progagation(examples, spark): - n = 5 - g = _from_java_gf(getattr(examples, "twoBlobs")(n), spark) - labels = g.labelPropagation(maxIter=4 * n) - labels1 = labels.filter("id < 5").select("label").collect() - all1 = {row.label for row in labels1} - assert len(all1) == 1 - labels2 = labels.filter("id >= 5").select("label").collect() - all2 = {row.label for row in labels2} - assert len(all2) == 1 - assert all1 != all2 - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_page_rank_2(examples, spark): - n = 100 - g = _from_java_gf(getattr(examples, "star")(n), spark) - resetProb = 0.15 - errorTol = 1.0e-5 - pr = g.pageRank(resetProb, tol=errorTol) - _hasCols(pr, vcols=["id", "pagerank"], ecols=["src", "dst", "weight"]) - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_parallel_personalized_page_rank(examples, spark): - n = 100 - g = _from_java_gf(getattr(examples, "star")(n), spark) - resetProb = 0.15 - maxIter = 15 - sourceIds = [1, 2, 3, 4] - pr = g.parallelPersonalizedPageRank(resetProb, sourceIds=sourceIds, maxIter=maxIter) - _hasCols(pr, vcols=["id", "pageranks"], ecols=["src", "dst", "weight"]) - - -def test_shortest_paths(spark): +@pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) +def test_shortest_paths(spark: SparkSession, args: PregelArguments): edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)] # Create bidirectional edges. 
all_edges = [z for (a, b) in edges for z in [(a, b), (b, a)]] @@ -383,97 +365,59 @@ def test_shortest_paths(spark): edgesDF = spark.createDataFrame(all_edges, ["src", "dst"]) vertices = spark.createDataFrame([(i,) for i in range(1, 7)], ["id"]) g = GraphFrame(vertices, edgesDF) - landmarks = [1, 4] - v2 = g.shortestPaths(landmarks) + landmarks: list[str | int] = [1, 4] + v2 = g.shortestPaths( + landmarks=landmarks, + algorithm=args.algorithm, + use_local_checkpoints=args.use_local_checkpoints, + checkpoint_interval=args.checkpoint_interval, + storage_level=args.storage_level, + ) _df_hasCols(v2, vcols=["id", "distances"]) + _ = v2.unpersist() -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_svd_plus_plus(examples, spark): - g = _from_java_gf(getattr(examples, "ALSSyntheticData")(), spark) - (v2, cost) = g.svdPlusPlus() - _df_hasCols(v2, vcols=["id", "column1", "column2", "column3", "column4"]) - - -def test_strongly_connected_components(spark): +def test_strongly_connected_components(spark: SparkSession): # Simple island test vertices = spark.createDataFrame([(i,) for i in range(1, 6)], ["id"]) edges = spark.createDataFrame([(7, 8)], ["src", "dst"]) g = GraphFrame(vertices, edges) c = g.stronglyConnectedComponents(5) for row in c.collect(): - assert ( - row.id == row.component - ), f"Vertex {row.id} not equal to its component {row.component}" + assert row.id == row.component, ( + f"Vertex {row.id} not equal to its component {row.component}" + ) + _ = c.unpersist() -def test_triangle_counts(spark): +@pytest.mark.parametrize("storage_level", STORAGE_LEVELS, ids=STORAGE_LEVELS_IDS) +def test_triangle_counts(spark: SparkSession, storage_level: StorageLevel): edges = spark.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"]) vertices = spark.createDataFrame([(0,), (1,), (2,)], ["id"]) g = GraphFrame(vertices, edges) - c = g.triangleCount() + c = g.triangleCount(storage_level=storage_level) for row in c.select("id", "count").collect(): - assert row.asDict()["count"] == 1, f"Triangle count for vertex {row.id} is not 1" - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_mutithreaded_sparksession_usage(spark): - # Test that the GraphFrame API works correctly from multiple threads. - localVertices = [(1, "A"), (2, "B"), (3, "C")] - localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] - v = spark.createDataFrame(localVertices, ["id", "name"]) - e = spark.createDataFrame(localEdges, ["src", "dst", "action"]) - - exc = None - - def run_graphframe() -> None: - nonlocal exc - try: - GraphFrame(v, e) - except Exception as _e: - exc = _e - - import threading - - thread = threading.Thread(target=run_graphframe) - thread.start() - thread.join() - assert exc is None, f"Exception was raised in thread: {exc}" - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_belief_propagation(spark): - # Create a graphical model g of size 3x3. - g = Graphs(spark).gridIsingModel(3) - # Run Belief Propagation (BP) for 5 iterations. - numIter = 5 - results = BeliefPropagation.runBPwithGraphFrames(g, numIter) - # Check that each belief is a valid probability in [0, 1]. - for row in results.vertices.select("belief").collect(): - belief = row["belief"] - assert 0 <= belief <= 1, f"Expected belief to be probability in [0,1], but found {belief}" - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_graph_friends(spark): - # Construct the graph. 
- g = Graphs(spark).friends() - # Check that the result is an instance of GraphFrame. - assert isinstance(g, GraphFrame) - - -@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") -def test_graph_grid_ising_model(spark): - # Construct a grid Ising model graph. - n = 3 - g = Graphs(spark).gridIsingModel(n) - # Collect the vertex ids - ids = [v["id"] for v in g.vertices.collect()] - # Verify that every expected vertex id appears. - for i in range(n): - for j in range(n): - assert f"{i},{j}" in ids - - -if __name__ == "__main__": - pytest.main() + assert row.asDict()["count"] == 1, ( + f"Triangle count for vertex {row.id} is not 1" + ) + _ = c.unpersist() + + +@pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) +def test_cycles_finding(spark: SparkSession, args: PregelArguments) -> None: + vertices = spark.createDataFrame( + [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")], ["id", "attr"] + ) + edges = spark.createDataFrame( + [(1, 2), (2, 3), (3, 1), (1, 4), (2, 5)], ["src", "dst"] + ) + graph = GraphFrame(vertices, edges) + res = graph.detectingCycles( + checkpoint_interval=args.checkpoint_interval, + use_local_checkpoints=args.use_local_checkpoints, + storage_level=args.storage_level, + ) + assert res.count() == 3 + collected = res.sort("id").select("found_cycles").collect() + assert [row[0] for row in collected] == [[1, 2, 3, 1], [2, 3, 1, 2], [3, 1, 2, 3]] + _ = res.unpersist() From b43b52c0ca4638865cdfd3c281e698a2bc3ed40e Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 5 Oct 2025 12:38:16 +0200 Subject: [PATCH 10/17] fixes --- python/graphframes/classic/graphframe.py | 12 +++++++++--- python/graphframes/connect/graphframes_client.py | 4 +++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index 3686271f9..6b1d761c3 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -19,10 +19,16 @@ from typing import final from py4j.java_gateway import JavaObject -from pyspark import SparkContext +from pyspark import SparkContext, __version__ from pyspark.sql import SparkSession -from pyspark.sql.classic.column import Column, _to_seq -from pyspark.sql.classic.dataframe import DataFrame + +if __version__.startswith("4"): + from pyspark.sql.classic.column import Column, _to_seq + from pyspark.sql.classic.dataframe import DataFrame +else: + from pyspark.sql.column import Column, _to_seq + from pyspark.sql import DataFrame + from pyspark.storagelevel import StorageLevel from graphframes.classic.utils import storage_level_to_jvm diff --git a/python/graphframes/connect/graphframes_client.py b/python/graphframes/connect/graphframes_client.py index da53c8a50..2e64d7a41 100644 --- a/python/graphframes/connect/graphframes_client.py +++ b/python/graphframes/connect/graphframes_client.py @@ -1061,4 +1061,6 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - return _dataframe_from_plan(TriangleCount(self._vertices, self._edges), self._spark) + return _dataframe_from_plan( + TriangleCount(self._vertices, self._edges, storage_level), self._spark + ) From 7db2bee212a2c934c713c582557f48b8c5198158 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 5 Oct 2025 12:51:47 +0200 Subject: [PATCH 11/17] more tests & fixes --- python/graphframes/graphframe.py | 4 +- python/tests/test_graphframes.py | 98 ++++++++++++++++++++++++++------ 2 files changed, 82 insertions(+), 
20 deletions(-) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index e7b6f82b2..710227d3d 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -685,9 +685,9 @@ def validate( ) count_vertices_from_edges = vertices_set_from_edges.count() if count_vertices_from_edges > count_distinct_vertices: - _msg = "Graph is inconsistent: edges has {count_vertices_from_edges} " + _msg = "Graph is inconsistent: edges has {} " _msg += "vertices, but vertices has {} vertices." - raise ValueError(_msg.format(count_distinct_vertices)) + raise ValueError(_msg.format(count_vertices_from_edges, count_distinct_vertices)) combined = vertices_set_from_edges.join(self.vertices, ID, "left_anti") count_of_bad_vertices = combined.count() diff --git a/python/tests/test_graphframes.py b/python/tests/test_graphframes.py index 4da3fc59b..74791e742 100644 --- a/python/tests/test_graphframes.py +++ b/python/tests/test_graphframes.py @@ -58,7 +58,7 @@ class PregelArguments: ] -def test_construction(spark: SparkSession, local_g: GraphFrame): +def test_construction(spark: SparkSession, local_g: GraphFrame) -> None: vertexIDs = [row[0] for row in local_g.vertices.select("id").collect()] assert sorted(vertexIDs) == [1, 2, 3] @@ -83,12 +83,70 @@ def test_construction(spark: SparkSession, local_g: GraphFrame): _ = GraphFrame(v_invalid, e_invalid) -def test_cache(local_g: GraphFrame): +def test_validate(spark: SparkSession) -> None: + good_g = GraphFrame( + spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")]).toDF("id", "attr"), + spark.createDataFrame([(1, 2), (2, 1), (2, 3)]).toDF("src", "dst"), + ) + good_g.validate() # no exception should be thrown + + not_distinct_vertices = GraphFrame( + spark.createDataFrame([(1, "a"), (2, "b"), (3, "c"), (1, "d")]).toDF( + "id", "attr" + ), + spark.createDataFrame([(1, 2), (2, 1), (2, 3)]).toDF("src", "dst"), + ) + with pytest.raises(ValueError): + not_distinct_vertices.validate() + + missing_vertices = GraphFrame( + spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")]).toDF("id", "attr"), + spark.createDataFrame([(1, 2), (2, 1), (2, 3), (1, 4)]).toDF("src", "dst"), + ) + with pytest.raises(ValueError): + missing_vertices.validate() + + +def test_as_undirected(spark: SparkSession) -> None: + # Test without edge attributes + v = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")]).toDF("id", "name") + e = spark.createDataFrame([(1, 2), (2, 3)]).toDF("src", "dst") + g = GraphFrame(v, e) + undirected = g.as_undirected() + + # Check edge count doubled + assert undirected.edges.count() == 2 * g.edges.count() + + # Verify reverse edges exist + edges = undirected.edges.sort("src", "dst").collect() + assert len(edges) == 4 + assert edges[0][0] == 1 + assert edges[0][1] == 2 + assert edges[1][0] == 2 + assert edges[1][1] == 1 + assert edges[2][0] == 2 + assert edges[2][1] == 3 + assert edges[3][0] == 3 + assert edges[3][1] == 2 + + # Test with edge attributes + v2 = spark.createDataFrame([(1, "a"), (2, "b")]).toDF("id", "name") + e2 = spark.createDataFrame([(1, 2, "edge1")]).toDF("src", "dst", "attr") + g2 = GraphFrame(v2, e2) + undirected2 = g2.as_undirected() + + edges2 = undirected2.edges.collect() + assert len(edges2) == 2 + assert any(row[0] == 1 and row[1] == 2 and row[2] == "edge1" for row in edges2) + assert any(row[0] == 2 and row[1] == 1 and row[2] == "edge1" for row in edges2) + + +def test_cache(local_g: GraphFrame) -> None: _ = local_g.cache() _ = local_g.unpersist() -def test_degrees(local_g: GraphFrame): +def 
test_degrees(local_g: GraphFrame) -> None: outDeg = local_g.outDegrees assert set(outDeg.columns) == {"id", "outDegree"} inDeg = local_g.inDegrees @@ -97,13 +155,13 @@ def test_degrees(local_g: GraphFrame): assert set(deg.columns) == {"id", "degree"} -def test_motif_finding(local_g: GraphFrame): +def test_motif_finding(local_g: GraphFrame) -> None: motifs = local_g.find("(a)-[e]->(b)") assert motifs.count() == 3 assert set(motifs.columns) == {"a", "e", "b"} -def test_filterVertices(local_g: GraphFrame): +def test_filterVertices(local_g: GraphFrame) -> None: conditions = ["id < 3", local_g.vertices.id < 3] expected_v = [(1, "A"), (2, "B")] expected_e = [(1, 2, "love"), (2, 1, "hate")] @@ -117,7 +175,7 @@ def test_filterVertices(local_g: GraphFrame): assert set(e2) == set(expected_e) -def test_filterEdges(local_g: GraphFrame): +def test_filterEdges(local_g: GraphFrame) -> None: conditions = ["dst > 2", local_g.edges.dst > 2] expected_v = [(1, "A"), (2, "B"), (3, "C")] expected_e = [(2, 3, "follow")] @@ -131,7 +189,7 @@ def test_filterEdges(local_g: GraphFrame): assert set(e2) == set(expected_e) -def test_dropIsolatedVertices(local_g: GraphFrame): +def test_dropIsolatedVertices(local_g: GraphFrame) -> None: g2 = local_g.filterEdges("dst > 2").dropIsolatedVertices() v2 = g2.vertices.select("id", "name").collect() e2 = g2.edges.select("src", "dst", "action").collect() @@ -143,7 +201,7 @@ def test_dropIsolatedVertices(local_g: GraphFrame): assert set(e2) == set(expected_e) -def test_bfs(local_g: GraphFrame): +def test_bfs(local_g: GraphFrame) -> None: paths = local_g.bfs("name='A'", "name='C'") assert paths is not None assert paths.count() == 1 @@ -159,7 +217,7 @@ def test_bfs(local_g: GraphFrame): assert paths3.count() == 0 -def test_power_iteration_clustering(spark: SparkSession): +def test_power_iteration_clustering(spark: SparkSession) -> None: vertices = [ (1, 0, 0.5), (2, 0, 0.5), @@ -190,7 +248,8 @@ def test_power_iteration_clustering(spark: SparkSession): _ = clusters_df.unpersist() -def test_page_rank(spark: SparkSession): +@pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) +def test_page_rank(spark: SparkSession, args: PregelArguments) -> None: edges = spark.createDataFrame( [ [0, 1], @@ -239,7 +298,7 @@ def test_page_rank(spark: SparkSession): @pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) -def test_pregel_early_stopping(spark: SparkSession, args: PregelArguments): +def test_pregel_early_stopping(spark: SparkSession, args: PregelArguments) -> None: edges = spark.createDataFrame( [ [0, 1], @@ -266,6 +325,9 @@ def test_pregel_early_stopping(spark: SparkSession, args: PregelArguments): pregel = graph.pregel ranks = ( graph.pregel.setMaxIter(5) + .setUseLocalCheckpoints(args.use_local_checkpoints) + .setIntermediateStorageLevel(args.storage_level) + .setCheckpointInterval(args.checkpoint_interval) .setEarlyStopping(True) .setUseLocalCheckpoints(args.use_local_checkpoints) .setIntermediateStorageLevel(args.storage_level) @@ -291,14 +353,14 @@ def test_pregel_early_stopping(spark: SparkSession, args: PregelArguments): _ = ranks.unpersist() -def _hasCols(graph: GraphFrame, vcols: list[str] = [], ecols: list[str] = []): +def _hasCols(graph: GraphFrame, vcols: list[str] = [], ecols: list[str] = []) -> None: for c in vcols: assert c in graph.vertices.columns, f"Vertex DataFrame missing column: {c}" for c in ecols: assert c in graph.edges.columns, f"Edge DataFrame missing column: {c}" -def _df_hasCols(df: DataFrame, vcols: list[str] = []): +def _df_hasCols(df: 
DataFrame, vcols: list[str] = []) -> None: for c in vcols: assert c in df.columns, f"DataFrame missing column: {c}" @@ -311,7 +373,7 @@ def _df_hasCols(df: DataFrame, vcols: list[str] = []): ) def test_connected_components( spark: SparkSession, args: PregelArguments, cc_args: tuple[int, bool] -): +) -> None: v = spark.createDataFrame([(0, "a", "b")], ["id", "vattr", "gender"]) e = spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter("src > 10") v = spark.createDataFrame([(0, "a", "b")], ["id", "vattr", "gender"]) @@ -338,7 +400,7 @@ def test_connected_components( ) def test_connected_components2( spark: SparkSession, args: PregelArguments, cc_args: tuple[int, bool] -): +) -> None: v = spark.createDataFrame([(0, "a0", "b0"), (1, "a1", "b1")], ["id", "A", "B"]) e = spark.createDataFrame([(0, 1, "a01", "b01")], ["src", "dst", "A", "B"]) g = GraphFrame(v, e) @@ -356,7 +418,7 @@ def test_connected_components2( @pytest.mark.parametrize("args", PREGEL_ARGUMENTS, ids=PREGEL_IDS) -def test_shortest_paths(spark: SparkSession, args: PregelArguments): +def test_shortest_paths(spark: SparkSession, args: PregelArguments) -> None: edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)] # Create bidirectional edges. all_edges = [z for (a, b) in edges for z in [(a, b), (b, a)]] @@ -377,7 +439,7 @@ def test_shortest_paths(spark: SparkSession, args: PregelArguments): _ = v2.unpersist() -def test_strongly_connected_components(spark: SparkSession): +def test_strongly_connected_components(spark: SparkSession) -> None: # Simple island test vertices = spark.createDataFrame([(i,) for i in range(1, 6)], ["id"]) edges = spark.createDataFrame([(7, 8)], ["src", "dst"]) @@ -391,7 +453,7 @@ def test_strongly_connected_components(spark: SparkSession): @pytest.mark.parametrize("storage_level", STORAGE_LEVELS, ids=STORAGE_LEVELS_IDS) -def test_triangle_counts(spark: SparkSession, storage_level: StorageLevel): +def test_triangle_counts(spark: SparkSession, storage_level: StorageLevel) -> None: edges = spark.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"]) vertices = spark.createDataFrame([(0,), (1,), (2,)], ["id"]) g = GraphFrame(vertices, edges) From 754ceba54b60b032427985422bb481699e07c5ef Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 5 Oct 2025 12:55:19 +0200 Subject: [PATCH 12/17] update docstrings --- python/graphframes/graphframe.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 710227d3d..c2464cc4c 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -428,12 +428,21 @@ def connectedComponents( See Scala documentation for more details. :param algorithm: connected components algorithm to use (default: "graphframes") - Supported algorithms are "graphframes" and "graphx". + Supported algorithms are "graphframes" and "graphx". :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2) :param broadcastThreshold: broadcast threshold in propagating component assignments - (default: 1000000) + (default: 1000000). Passing -1 disable manual broadcasting and + allows AQE to handle skewed joins. This mode is much faster + and is recommended to use. Default value may be changed to -1 + in the future versions of GraphFrames. 
         :param useLabelsAsComponents: if True, uses the vertex labels as components, otherwise will
-        use longs
+            use longs
+    :param use_local_checkpoints: whether to use local checkpoints (default: False);
+        local checkpoints are faster and do not require setting a
+        persistent checkpointDir; on the other hand, local checkpoints
+        are less reliable and require executors to have sufficiently
+        large local disks.
+    :param storage_level: storage level for both intermediate and final DataFrames.
 
         :return: DataFrame with new vertices column "component"
         """

From 62491a5858679fc4be27654f8dbc4ac2f2e14cb3 Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Sun, 5 Oct 2025 12:56:31 +0200
Subject: [PATCH 13/17] increase driver memory

---
 python/tests/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index 91d83a96d..031632dbb 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -84,7 +84,7 @@ def spark():
             .config("spark.sql.shuffle.partitions", 4)
             .config("spark.checkpoint.dir", tmp_dir)
             .config("spark.jars", f"{core_jar},{connect_jar},{graphx_jar}")
-            .config("spark.driver.memory", "4g")
+            .config("spark.driver.memory", "6g")
         )
 
     if spark_major_version == "3":

From 78325f6b089c12ec1b4d76e89f652155a37d447f Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Sun, 5 Oct 2025 17:12:49 +0200
Subject: [PATCH 14/17] fix 3.5.x spark-connect problem

---
 python/graphframes/graphframe.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py
index c2464cc4c..c640b2f2b 100644
--- a/python/graphframes/graphframe.py
+++ b/python/graphframes/graphframe.py
@@ -718,14 +718,22 @@ def as_undirected(self) -> "GraphFrame":
         edge_attr_columns = [c for c in self.edges.columns if c not in [SRC, DST]]
 
         # Create the undirected edges by duplicating each edge in both directions
-        forward_edges = self.edges.select(
-            F.col(SRC), F.col(DST), F.struct(*edge_attr_columns).alias(EDGE)
-        )
-        backward_edges = self.edges.select(
-            F.col(DST).alias(SRC),
-            F.col(SRC).alias(DST),
-            F.struct(*edge_attr_columns).alias(EDGE),
-        )
+
+        # Spark 3.5.x problem: selecting an empty struct fails on Spark Connect.
+        # TODO: remove this branch after dropping 3.5.x support.
+
+        if edge_attr_columns:
+            forward_edges = self.edges.select(
+                F.col(SRC), F.col(DST), F.struct(*edge_attr_columns).alias(EDGE)
+            )
+            backward_edges = self.edges.select(
+                F.col(DST).alias(SRC),
+                F.col(SRC).alias(DST),
+                F.struct(*edge_attr_columns).alias(EDGE),
+            )
+        else:
+            forward_edges = self.edges.select(F.col(SRC), F.col(DST))
backward_edges = self.edges.select(F.col(DST).alias(SRC), F.col(SRC).alias(DST)) - new_edges = forward_edges.union(backward_edges).select(SRC, DST, EDGE) + new_edges = forward_edges.union(backward_edges).select(SRC, DST) # Preserve additional edge attributes edge_columns = [F.col(EDGE).getField(c).alias(c) for c in edge_attr_columns] From 39cbf617b621164066d82ddfb3e601dcf39dbc23 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 5 Oct 2025 19:29:24 +0200 Subject: [PATCH 16/17] final fixes --- python/graphframes/classic/graphframe.py | 8 +-- python/graphframes/lib/aggregate_messages.py | 8 +-- python/tests/conftest.py | 11 ++- python/tests/test_graphframes.py | 70 ++++++++++++++++++++ 4 files changed, 86 insertions(+), 11 deletions(-) diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index 6b1d761c3..ab9618906 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -188,14 +188,14 @@ def aggregateMessages( if len(aggCol) == 1: if isinstance(aggCol[0], Column): - jdf = builder.aggCol(aggCol[0]._jc) + jdf = builder.agg(aggCol[0]._jc) elif isinstance(aggCol[0], str): - jdf = builder.aggCol(aggCol[0]) + jdf = builder.agg(aggCol[0]) elif len(aggCol) > 1: if all(isinstance(x, Column) for x in aggCol): - jdf = builder.aggCol(aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol])) + jdf = builder.agg(aggCol[0]._jc, _to_seq(self._sc, [x._jc for x in aggCol])) elif all(isinstance(x, str) for x in aggCol): - jdf = builder.aggCol(aggCol[0], _to_seq(self._sc, aggCol[1:])) + jdf = builder.agg(aggCol[0], _to_seq(self._sc, aggCol[1:])) else: raise TypeError( "Multiple agg cols should all be `Column` or `str`, not a mix of them." diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/lib/aggregate_messages.py index fce2cc478..932856636 100644 --- a/python/graphframes/lib/aggregate_messages.py +++ b/python/graphframes/lib/aggregate_messages.py @@ -40,21 +40,21 @@ class AggregateMessages: """Collection of utilities usable with :meth:`graphframes.GraphFrame.aggregateMessages()`.""" @_ClassProperty - def src() -> Column: + def src(cls) -> Column: """Reference for source column, used for specifying messages.""" return F.col("src") @_ClassProperty - def dst() -> Column: + def dst(cls) -> Column: """Reference for destination column, used for specifying messages.""" return F.col("dst") @_ClassProperty - def edge() -> Column: + def edge(cls) -> Column: """Reference for edge column, used for specifying messages.""" return F.col("edge") @_ClassProperty - def msg() -> Column: + def msg(cls) -> Column: """Reference for message column, used for specifying aggregation function.""" return F.col("MSG") diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 031632dbb..782664186 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -5,6 +5,7 @@ import tempfile import warnings +from py4j.java_gateway import JavaObject import pytest from pyspark.sql import SparkSession from pyspark.version import __version__ @@ -108,7 +109,7 @@ def spark(): @pytest.fixture(scope="module") -def local_g(spark): +def local_g(spark: SparkSession): localVertices = [(1, "A"), (2, "B"), (3, "C")] localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] v = spark.createDataFrame(localVertices, ["id", "name"]) @@ -117,11 +118,15 @@ def local_g(spark): @pytest.fixture(scope="module") -def examples(spark): +def examples(spark: SparkSession): if is_remote(): # TODO: We should update tests 
to be able to run all of them on Spark Connect # At the moment the problem is that examples API is py4j based. yield None else: japi = _java_api(spark._sc) - yield japi.examples() + assert japi is not None + examples = japi.examples() + assert examples is not None + assert isinstance(examples, JavaObject) + yield examples diff --git a/python/tests/test_graphframes.py b/python/tests/test_graphframes.py index 74791e742..2e9156174 100644 --- a/python/tests/test_graphframes.py +++ b/python/tests/test_graphframes.py @@ -25,6 +25,8 @@ from graphframes.examples import BeliefPropagation, Graphs from graphframes.graphframe import GraphFrame +from pyspark.sql import is_remote + @dataclass class PregelArguments: @@ -483,3 +485,71 @@ def test_cycles_finding(spark: SparkSession, args: PregelArguments) -> None: collected = res.sort("id").select("found_cycles").collect() assert [row[0] for row in collected] == [[1, 2, 3, 1], [2, 3, 1, 2], [3, 1, 2, 3]] _ = res.unpersist() + + +@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") +def test_svd_plus_plus(examples, spark: SparkSession): + g = _from_java_gf(getattr(examples, "ALSSyntheticData")(), spark) + (v2, cost) = g.svdPlusPlus() + _df_hasCols(v2, vcols=["id", "column1", "column2", "column3", "column4"]) + + +@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") +def test_mutithreaded_sparksession_usage(spark: SparkSession): + # Test that the GraphFrame API works correctly from multiple threads. + localVertices = [(1, "A"), (2, "B"), (3, "C")] + localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] + v = spark.createDataFrame(localVertices, ["id", "name"]) + e = spark.createDataFrame(localEdges, ["src", "dst", "action"]) + + exc = None + + def run_graphframe() -> None: + nonlocal exc + try: + GraphFrame(v, e) + except Exception as _e: + exc = _e + + import threading + + thread = threading.Thread(target=run_graphframe) + thread.start() + thread.join() + assert exc is None, f"Exception was raised in thread: {exc}" + + +@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") +def test_belief_propagation(spark: SparkSession): + # Create a graphical model g of size 3x3. + g = Graphs(spark).gridIsingModel(3) + # Run Belief Propagation (BP) for 5 iterations. + numIter = 5 + results = BeliefPropagation.runBPwithGraphFrames(g, numIter) + # Check that each belief is a valid probability in [0, 1]. + for row in results.vertices.select("belief").collect(): + belief = row["belief"] + assert 0 <= belief <= 1, ( + f"Expected belief to be probability in [0,1], but found {belief}" + ) + + +@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") +def test_graph_friends(spark: SparkSession): + # Construct the graph. + g = Graphs(spark).friends() + # Check that the result is an instance of GraphFrame. + assert isinstance(g, GraphFrame) + + +@pytest.mark.skipif(is_remote(), reason="DISABLE FOR CONNECT") +def test_graph_grid_ising_model(spark: SparkSession): + # Construct a grid Ising model graph. + n = 3 + g = Graphs(spark).gridIsingModel(n) + # Collect the vertex ids + ids = [v["id"] for v in g.vertices.collect()] + # Verify that every expected vertex id appears. 
+ for i in range(n): + for j in range(n): + assert f"{i},{j}" in ids From 366dd1e23d6441a33d926c45f8f8edbfb03f2965 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 5 Oct 2025 19:56:26 +0200 Subject: [PATCH 17/17] fix 3.5.x --- python/tests/test_graphframes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_graphframes.py b/python/tests/test_graphframes.py index 2e9156174..f5aa0eace 100644 --- a/python/tests/test_graphframes.py +++ b/python/tests/test_graphframes.py @@ -25,7 +25,7 @@ from graphframes.examples import BeliefPropagation, Graphs from graphframes.graphframe import GraphFrame -from pyspark.sql import is_remote +from pyspark.sql.utils import is_remote @dataclass
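The parametrized tests in this series exercise the new checkpointing and storage-level knobs across connectedComponents, shortestPaths, triangleCount and detectingCycles. The sketch below shows those same keyword arguments in a standalone script; the SparkSession setup, the toy triangle graph, and the chosen storage levels and intervals are illustrative assumptions only, and the GraphFrames jars are assumed to be on the classpath as in conftest.py.

# Minimal usage sketch of the options added in this patch series.
# Assumptions: local SparkSession, GraphFrames jars available, toy triangle graph.
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

from graphframes.graphframe import GraphFrame

spark = SparkSession.builder.master("local[4]").appName("graphframes-knobs").getOrCreate()

vertices = spark.createDataFrame([(0,), (1,), (2,)], ["id"])
edges = spark.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"])
g = GraphFrame(vertices, edges)

# Triangle counting with an explicit storage level for intermediate results.
triangles = g.triangleCount(storage_level=StorageLevel.MEMORY_AND_DISK)

# Connected components: broadcastThreshold=-1 disables manual broadcasting and lets
# AQE handle skewed joins; local checkpoints avoid needing a checkpoint directory.
components = g.connectedComponents(
    algorithm="graphframes",
    checkpointInterval=2,
    use_local_checkpoints=True,
    storage_level=StorageLevel.MEMORY_AND_DISK,
    broadcastThreshold=-1,
    useLabelsAsComponents=True,
)

# Shortest paths and cycle detection accept the same checkpointing/storage knobs
# (note the snake_case spelling of checkpoint_interval on these two calls).
distances = g.shortestPaths(
    landmarks=[0, 2],
    algorithm="graphframes",
    use_local_checkpoints=True,
    checkpoint_interval=2,
    storage_level=StorageLevel.DISK_ONLY,
)
cycles = g.detectingCycles(
    checkpoint_interval=2,
    use_local_checkpoints=True,
    storage_level=StorageLevel.MEMORY_AND_DISK,
)

triangles.show()
components.show()
distances.show()
cycles.show()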