10BC0 feat: reduce memory consumption of cycles detection by SemyonSinchenko · Pull Request #731 · graphframes/graphframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions core/src/main/scala/org/graphframes/GraphFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,7 @@ class GraphFrame private (
* large-scale sparse graphs." Proceedings of Simpósio Brasileiro de Pesquisa Operacional
* (SBPO’15) (2015): 1-11.
*
* Returns a DataFrame with ID and cycles, ID are not unique if there are multiple cycles
* starting from this ID. For the case of cycle 1 -> 2 -> 3 -> 1 all the vertices will have the
* same cycle! E.g.: 1 -> [1, 2, 3, 1] 2 -> [2, 3, 1, 2] 3 -> [3, 1, 2, 3]
*
* Deduplication of cycles should be done by the user!
 * Returns a DataFrame with unique cycles.
*
* @return
* an instance of DetectingCycles initialized with the current context
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/scala/org/graphframes/lib/DetectingCycles.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.graphframes.lib

import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.*
import org.apache.spark.sql.types.ArrayType
Expand Down Expand Up @@ -30,7 +31,6 @@ class DetectingCycles private[graphframes] (private val graph: GraphFrame)
filter(col(foundSeqCol), x => size(x) > lit(0)).alias(foundSeqCol))
.filter(size(col(foundSeqCol)) > lit(0))
.select(
col(GraphFrame.ID),
// from vid -> [[cycle1, cycle2, ...]]
// to vid -> [cycle1], vid -> [cycle2], ...
explode(col(foundSeqCol)).alias(foundSeqCol))
Expand Down Expand Up @@ -62,7 +62,10 @@ object DetectingCycles {
// Each vertex stores all the found cycles
val foundSequences = array().cast(ArrayType(ArrayType(vertexDT)))
// Message is simply stored sequences
val sentMessages = when(size(Pregel.src(storedSeqCol)) =!= lit(0), Pregel.src(storedSeqCol))
// Send only sequences if the starting vertex of them is less than the destination
val sentMessages = when(
size(Pregel.src(storedSeqCol)) =!= lit(0),
filter(Pregel.src(storedSeqCol), (x: Column) => x.getItem(0) <= Pregel.dst(GraphFrame.ID)))
.otherwise(lit(null).cast(ArrayType(ArrayType(vertexDT))))
// If the sequence contains the current vertex ID somewhere in the middle, it is
// a previously detected cycle and a sequence should be discarded.
Expand Down
15 changes: 3 additions & 12 deletions core/src/test/scala/org/graphframes/lib/DetectingCyclesSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@ class DetectingCyclesSuite extends SparkFunSuite with GraphFrameTestSparkContext
.createDataFrame(Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 4L), (2L, 5L)))
.toDF("src", "dst"))
val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
assert(res.count() == 3)
assert(res.count() == 1)
@nowarn val collected =
res
.sort(GraphFrame.ID)
.select(DetectingCycles.foundSeqCol)
.collect()
.map(r => r.getAs[mutable.WrappedArray[Long]](0))

assert(collected(0) == Seq(1, 2, 3, 1))
assert(collected(1) == Seq(2, 3, 1, 2))
assert(collected(2) == Seq(3, 1, 2, 3))
res.unpersist()
}

Expand All @@ -53,20 +49,15 @@ class DetectingCyclesSuite extends SparkFunSuite with GraphFrameTestSparkContext
.createDataFrame(Seq((1L, 2L), (2L, 1L), (1L, 3L), (3L, 1L), (2L, 5L), (5L, 1L)))
.toDF("src", "dst"))
val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
assert(res.count() == 7)
assert(res.count() == 3)
@nowarn val collected =
res
.sort(GraphFrame.ID, DetectingCycles.foundSeqCol)
.select(DetectingCycles.foundSeqCol)
.sort(DetectingCycles.foundSeqCol)
.collect()
.map(r => r.getAs[mutable.WrappedArray[Long]](0))
assert(collected(0) == Seq(1, 2, 1))
assert(collected(1) == Seq(1, 2, 5, 1))
assert(collected(2) == Seq(1, 3, 1))
assert(collected(3) == Seq(2, 1, 2))
assert(collected(4) == Seq(2, 5, 1, 2))
assert(collected(5) == Seq(3, 1, 3))
assert(collected(6) == Seq(5, 1, 2, 5))
res.unpersist()
}
}
20 changes: 6 additions & 14 deletions docs/src/04-user-guide/05-traversals.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,10 @@ val res = graph.detectingCycles.setUseLocalCheckpoints(true).run()
res.show(false)

// Output:
// +----+--------------+
// | id | found_cycles |
// +----+--------------+
// |1 |[1, 3, 1] |
// |1 |[1, 2, 1] |
// |1 |[1, 2, 5, 1] |
// |2 |[2, 1, 2] |
// |2 |[2, 5, 1, 2] |
// |3 |[3, 1, 3] |
// |5 |[5, 1, 2, 5] |
// +----+--------------+
// +--------------+
// | found_cycles |
// +--------------+
// |[1, 3, 1] |
// |[1, 2, 1] |
// |[1, 2, 5, 1] |
```

**WARNING:** This algorithm returns all the cycles, and users should handle deduplication of \[1, 2, 1\] and \[2, 1, 2\] (
that is the same cycle)!
6 changes: 1 addition & 5 deletions python/graphframes/graphframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,7 @@ def detectingCycles(
large-scale sparse graphs." Proceedings of Simpósio Brasileiro de Pesquisa Operacional
(SBPO’15) (2015): 1-11.

Returns a DataFrame with ID and cycles, ID are not unique if there are multiple cycles
starting from this ID. For the case of cycle 1 -> 2 -> 3 -> 1 all the vertices will have the
same cycle! E.g.: 1 -> [1, 2, 3, 1] 2 -> [2, 3, 1, 2] 3 -> [3, 1, 2, 3]

Deduplication of cycles should be done by the user!
Returns a DataFrame with unique cycles.

:param checkpoint_interval: Pregel checkpoint interval, default is 2
:param use_local_checkpoints: should local checkpoints be used instead of checkpointDir
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_graphframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,9 @@ def test_cycles_finding(spark: SparkSession, args: PregelArguments) -> None:
use_local_checkpoints=args.use_local_checkpoints,
storage_level=args.storage_level,
)
assert res.count() == 3
assert res.count() == 1
collected = res.sort("id").select("found_cycles").collect()
assert [row[0] for row in collected] == [[1, 2, 3, 1], [2, 3, 1, 2], [3, 1, 2, 3]]
assert collected[0][0] == [1, 2, 3, 1]
_ = res.unpersist()


Expand Down
0