diff --git a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
index 11ef58b1c..6d90ed346 100644
--- a/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
+++ b/core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -185,7 +185,12 @@ object ConnectedComponents extends Logging {
 
   private def runGraphX(graph: GraphFrame, maxIter: Int): DataFrame = {
     val components = graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX, maxIter)
-    GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices
+    val res =
+      GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices
+    res.persist(StorageLevel.MEMORY_AND_DISK_SER)
+    res.count()
+    components.unpersist()
+    res
   }
 
   private def run(
diff --git a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
index f17b82420..ef1e123d8 100644
--- a/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
+++ b/core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
+import org.apache.spark.storage.StorageLevel
 import org.graphframes.GraphFrame
 import org.graphframes.WithAlgorithmChoice
 import org.graphframes.WithCheckpointInterval
@@ -67,7 +68,11 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
 private object LabelPropagation {
   private def runInGraphX(graph: GraphFrame, maxIter: Int): DataFrame = {
     val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter)
-    GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(LABEL_ID)).vertices
+    val res = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(LABEL_ID)).vertices
+    res.persist(StorageLevel.MEMORY_AND_DISK_SER)
+    res.count()
+    gx.unpersist()
+    res
   }
 
   private def keyWithMaxValue(column: Column): Column = {
diff --git a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
index cc972d8fb..bd4df94e4 100644
--- a/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
+++ b/core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.functions.transform_values
 import org.apache.spark.sql.functions.when
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.MapType
+import org.apache.spark.storage.StorageLevel
 import org.graphframes.GraphFrame
 import org.graphframes.GraphFrame.quote
 import org.graphframes.GraphFramesUnreachableException
@@ -109,7 +110,11 @@ private object ShortestPaths extends Logging {
       transform_keys(col(DISTANCE_ID), (longId: Column, _) => longIdToLandmarkColumn(longId))
     }
     val cols = graph.vertices.columns.map(quote).map(col) :+ distanceCol.as(DISTANCE_ID)
-    g.vertices.select(cols.toSeq: _*)
+    val res = g.vertices.select(cols.toSeq: _*)
+    res.persist(StorageLevel.MEMORY_AND_DISK_SER)
+    res.count()
+    gx.unpersist()
+    res
   }
 
   private def runInGraphFrames(
diff --git a/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
index 0f52ae7ca..2ee678f74 100644
--- a/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
+++ b/core/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
@@ -19,6 +19,7 @@ package org.graphframes.lib
 
 import org.apache.spark.graphframes.graphx.{lib => graphxlib}
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.storage.StorageLevel
 import org.graphframes.GraphFrame
 import org.graphframes.WithMaxIter
 
@@ -42,7 +43,11 @@ class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame)
 private object StronglyConnectedComponents {
   private def run(graph: GraphFrame, numIter: Int): DataFrame = {
     val gx = graphxlib.StronglyConnectedComponents.run(graph.cachedTopologyGraphX, numIter)
-    GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(COMPONENT_ID)).vertices
+    val res = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(COMPONENT_ID)).vertices
+    res.persist(StorageLevel.MEMORY_AND_DISK_SER)
+    res.count()
+    gx.unpersist()
+    res
   }
 
   private[graphframes] val COMPONENT_ID = "component"
diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Pregel.scala
index d28e55370..d77fd13ed 100644
--- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/Pregel.scala
@@ -135,8 +135,9 @@ object Pregel extends Logging {
 
     // compute the messages
    var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
+    // It is sufficient to checkpoint only the graph itself.
     val messageCheckpointer =
-      new PeriodicRDDCheckpointer[(VertexId, A)](checkpointInterval, graph.vertices.sparkContext)
+      new PeriodicRDDCheckpointer[(VertexId, A)](-1, graph.vertices.sparkContext)
     messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
 
     var isActiveMessagesNonEmpty = !messages.isEmpty()
diff --git a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala
index 7f8d79ca8..856cd112b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphframes/graphx/impl/GraphImpl.scala
@@ -200,8 +200,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
       mergeMsg: (A, A) => A,
       tripletFields: TripletFields,
       activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
-
-    vertices.cache()
     // For each vertex, replicate its attribute only to partitions where it is
     // in the relevant position in an edge.
     replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
@@ -267,7 +265,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
     // null if not
     if (eq != null) {
-      vertices.cache()
       // updateF preserves type, so we can use incremental replication
       val newVerts = vertices.leftJoin(other)(updateF).cache()
       val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts)
@@ -337,14 +334,10 @@ object GraphImpl {
   def apply[VD: ClassTag, ED: ClassTag](
       vertices: VertexRDD[VD],
       edges: EdgeRDD[ED]): GraphImpl[VD, ED] = {
-
-    vertices.cache()
-
     // Convert the vertex partitions in edges to the correct type
     val newEdges = edges
       .asInstanceOf[EdgeRDDImpl[ED, _]]
       .mapEdgePartitions((_, part) => part.withoutVertexAttributes[VD]())
-      .cache()
     GraphImpl.fromExistingRDDs(vertices, newEdges)
   }
 
@@ -369,7 +362,7 @@ object GraphImpl {
      defaultVertexAttr: VD,
       edgeStorageLevel: StorageLevel,
       vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = {
-    val edgesCached = edges.withTargetStorageLevel(edgeStorageLevel).cache()
+    val edgesCached = edges.withTargetStorageLevel(edgeStorageLevel)
     val vertices = VertexRDD
       .fromEdges(edgesCached, edgesCached.partitions.length, defaultVertexAttr)
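
Each of the four algorithm wrappers above applies the same materialize-then-release pattern: persist the converted result, force it with an action, and only then unpersist the GraphX graph it was derived from. A minimal standalone sketch of that pattern follows; the helper name and signature are illustrative only, not part of the patch:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel

// Hypothetical helper, not from the patch: eagerly materialize a result
// DataFrame before releasing the cached upstream data it was computed from,
// so the result never has to be recomputed against an uncached lineage.
def materializeThenRelease(result: DataFrame)(releaseUpstream: () => Unit): DataFrame = {
  result.persist(StorageLevel.MEMORY_AND_DISK_SER) // serialized cache keeps the footprint small
  result.count() // an action forces full evaluation while the upstream cache is still alive
  releaseUpstream() // safe now: `result` no longer depends on the upstream cache
  result
}
```

For example, `materializeThenRelease(vertices)(() => gx.unpersist())` would mirror what each `run*` method in the diff does inline.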