diff --git a/.gitignore b/.gitignore index 286169766..c9e2bda9e 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,7 @@ spark-* # Zed .zed + +# Emacs +.dir-locals.el +*~ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd1b12d03..8a1c2099e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,15 +21,14 @@ repos: - id: scalafmt name: scalafmt - entry: build/sbt scalafmtCheckAll + entry: build/sbt scalafmtAll language: system types: [scala] pass_filenames: false - id: scalafix name: scalafix - entry: build/sbt "scalafixAll --check" + entry: build/sbt scalafixAll language: system types: [scala] pass_filenames: false - diff --git a/NOTICE b/NOTICE index 01612246c..1a8074a45 100644 --- a/NOTICE +++ b/NOTICE @@ -8,3 +8,17 @@ Copyright 2014-2025 The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). + +This product includes wiki-Vote dataset from SNAP collection for testing purposes only. +Citation: + J. Leskovec, D. Huttenlocher, J. Kleinberg. Signed Networks in Social Media. CHI 2010. + J. Leskovec, D. Huttenlocher, J. Kleinberg. Predicting Positive and Negative Links in Online Social Networks. WWW 2010. + +SNAP Datasets: +@misc{snapnets, + author = {Jure Leskovec and Andrej Krevl}, + title = {{SNAP Datasets}: {Stanford} Large Network Dataset Collection}, + howpublished = {\url{http://snap.stanford.edu/data}}, + month = jun, + year = 2014 +} diff --git a/build.sbt b/build.sbt index 13561823f..0f19b9154 100644 --- a/build.sbt +++ b/build.sbt @@ -108,7 +108,6 @@ lazy val commonSetting = Seq( ScalacOptions.warnUnusedImports, ScalacOptions.warnUnusedParams, ScalacOptions.warnUnusedPrivates, - ScalacOptions.warnUnusedNoWarn, ScalacOptions.source3, ScalacOptions.fatalWarnings), tpolecatExcludeOptions ++= Set(ScalacOptions.warnNonUnitStatement), diff --git a/core/src/main/scala/org/apache/spark/sql/graphframes/expressions/ReservoirSamplingAgg.scala b/core/src/main/scala/org/apache/spark/sql/graphframes/expressions/ReservoirSamplingAgg.scala new file mode 100644 index 000000000..8a587ffa2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/sql/graphframes/expressions/ReservoirSamplingAgg.scala @@ -0,0 +1,118 @@ +package org.apache.spark.sql.graphframes.expressions + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator + +import scala.reflect.runtime.universe.TypeTag + +import collection.mutable.ArrayBuffer + +case class Reservoir[T](seq: ArrayBuffer[T], elements: Int) extends Serializable + +case class ReservoirSamplingAgg[T: TypeTag](size: Int) + extends Aggregator[T, Reservoir[T], Seq[T]] + with Serializable { + + override def zero: Reservoir[T] = Reservoir[T](ArrayBuffer.empty, 0) + + override def reduce(b: Reservoir[T], a: T): Reservoir[T] = { + if (b.seq.size < size) { + Reservoir(b.seq += a, b.elements + 1) + } else { + val j = java.util.concurrent.ThreadLocalRandom.current().nextInt(b.elements + 1) + if (j < size) { + b.seq(j) = a + } + Reservoir(b.seq, b.elements + 1) + } + } + + private def mergeFull(left: Reservoir[T], right: Reservoir[T]): Reservoir[T] = { + val total_cnt = left.elements + right.elements + val rng = java.util.concurrent.ThreadLocalRandom.current() + val pLeft = left.elements.toDouble / total_cnt.toDouble + + var newSeq = ArrayBuffer.empty[T] + val leftCloned = left.seq.clone() + val 
rightCloned = right.seq.clone() + for (_ <- (1 to size)) { + if (rng.nextDouble() <= pLeft) { + newSeq = newSeq += leftCloned.remove(rng.nextInt(leftCloned.size)) + } else { + newSeq = newSeq += rightCloned.remove(rng.nextInt(rightCloned.size)) + } + } + + Reservoir(newSeq, total_cnt) + } + + private def mergeTwoPartial(left: Reservoir[T], right: Reservoir[T]): Reservoir[T] = { + val total_cnt = left.elements + right.elements + val rng = java.util.concurrent.ThreadLocalRandom.current() + if (total_cnt <= size) { + Reservoir(left.seq ++ right.seq, total_cnt) + } else { + val currElements = left.seq ++ right.seq.slice(0, size - left.elements) + var currSize = size + 1 + + for (i <- ((size - left.elements) until right.elements)) { + val j = rng.nextInt(currSize) + if (j < size) { + currElements(j) = right.seq(i) + } + currSize += 1 + } + + Reservoir(currElements, total_cnt) + } + } + + private def mergePartialRight(left: Reservoir[T], right: Reservoir[T]): Reservoir[T] = { + val total_cnt = left.elements + right.elements + val pLeft = left.elements.toDouble / total_cnt.toDouble + val currElements = ArrayBuffer.empty[T] + val rng = java.util.concurrent.ThreadLocalRandom.current() + + // TODO: I'm not actually sure + // that we need to clone it. + // Does Spark handle it by itself? + // Is there any chance the reference is shared between tasks? + val clonedLeft = left.seq.clone() + val clonedRight = right.seq.clone() + for (_ <- (1 to size)) { + if ((clonedRight.isEmpty) || (rng.nextDouble() <= pLeft)) { + val idx = rng.nextInt(clonedLeft.size) + currElements += clonedLeft.remove(idx) + } else { + val idx = rng.nextInt(clonedRight.size) + currElements += clonedRight.remove(idx) + } + } + + Reservoir(currElements, total_cnt) + } + + override def merge(b1: Reservoir[T], b2: Reservoir[T]): Reservoir[T] = { + val (left, right) = if (b1.seq.size > b2.seq.size) { + (b1, b2) + } else { + (b2, b1) + } + + if (left.elements < size) { + mergeTwoPartial(left, right) + } else if (right.elements < size) { + mergePartialRight(left, right) + } else { + mergeFull(left, right) + } + } + + override def finish(reduction: Reservoir[T]): Seq[T] = reduction.seq.toSeq + + override def bufferEncoder: Encoder[Reservoir[T]] = Encoders.product + + override def outputEncoder: Encoder[Seq[T]] = ExpressionEncoder[Seq[T]]() +} diff --git a/core/src/main/scala/org/graphframes/embeddings/Hash2Vec.scala b/core/src/main/scala/org/graphframes/embeddings/Hash2Vec.scala new file mode 100644 index 000000000..2a157e9dc --- /dev/null +++ b/core/src/main/scala/org/graphframes/embeddings/Hash2Vec.scala @@ -0,0 +1,251 @@ +package org.graphframes.embeddings + +import org.apache.spark.ml.linalg.SQLDataTypes.VectorType +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.stat.Summarizer +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.ByteType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.ShortType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.hash.Murmur3_x86_32.* +import org.apache.spark.unsafe.types.UTF8String +import org.graphframes.GraphFramesUnsupportedVertexTypeException +import org.graphframes.rw.RandomWalkBase + +import scala.annotation.nowarn 
+import scala.jdk.CollectionConverters.* +import scala.reflect.ClassTag + +/** + * Implementation of Hash2Vec, an efficient word embedding technique using feature hashing. Based + * on: Argerich, Luis, Joaquín Torré Zaffaroni, and Matías J. Cano. "Hash2vec, feature hashing for + * word embeddings." arXiv preprint arXiv:1608.08940 (2016). + * + * Produces embeddings for elements in sequences using a hash-based approach to avoid storing a + * vocabulary. Uses MurmurHash3 for hashing elements to embedding indices and signs. + * + * Output DataFrame has columns "id" (element identifier, same type as sequence elements) and + * "vector" (dense vector of doubles, summed across all occurrences). + * + * Tradeoffs: Higher numPartitions reduces local state and memory per partition but increases + * aggregation and merging overhead across partitions. Larger embeddingsDim provides richer + * representations but consumes more memory. Seeds control hashing for reproducibility. + */ +class Hash2Vec extends Serializable { + private def decayGaussian(d: Int, sigma: Double): Double = { + math.exp(-(d * d) / (sigma * sigma)) + } + private val possibleDecayFunctions: Seq[String] = Seq("gaussian", "constant") + + private var contextSize: Int = 5 + private var numPartitions: Int = 5 + private var embeddingsDim: Int = 256 + private var sequenceCol: String = RandomWalkBase.rwColName + private var decayFunction: String = "gaussian" + private var gaussianSigma: Double = 1.0 + private var hashingSeed: Int = 42 + private var signHashingSeed: Int = 18 + + /** + * Sets the context window size around each element to consider during training. Larger values + * incorporate more distant elements but increase computation time. Default: 5. + */ + def setContextSize(value: Int): this.type = { + contextSize = value + this + } + + /** + * Sets the number of partitions for RDDs to parallelize computation. More partitions distribute + * workload and reduce memory per partition but complicate merging across partitions. Default: + * 5. + */ + def setNumPartitions(value: Int): this.type = { + numPartitions = value + this + } + + /** + * Sets the dimensionality of the dense embedding vectors. Larger dimensions allow richer + * representations but require more memory. Corresponds to the hash table size. Default: 256. + */ + def setEmbeddingsDim(value: Int): this.type = { + embeddingsDim = value + this + } + + /** + * Sets the column name containing sequences of elements (as arrays). Default: "random_walk". + */ + def setSequenceCol(value: String): this.type = { + sequenceCol = value + this + } + + /** + * Sets the decay function used to weight context elements by distance. Supported values: + * "gaussian", "constant". Default: "gaussian". + */ + def setDecayFunction(value: String): this.type = { + val sep = ", " + require( + possibleDecayFunctions.contains(value), + s"supported functions: ${possibleDecayFunctions.mkString(sep)}") + decayFunction = value + this + } + + /** + * Sets the sigma parameter for Gaussian decay weighting. Smaller values decay weights faster + * with distance. Default: 1.0. + */ + def setGaussianSigma(value: Double): this.type = { + gaussianSigma = value + this + } + + /** + * Sets the seed for hashing elements to embedding indices. Used for reproducibility of + * embeddings. Default: 42. + */ + def setHashingSeed(value: Int): this.type = { + hashingSeed = value + this + } + + /** + * Sets the seed for hashing elements to determine the sign of contributions. Used for + * reproducibility of embeddings. 
Default: 18. + */ + def setSignHashSeed(value: Int): this.type = { + signHashingSeed = value + this + } + + private def nonNegativeMod(x: Int, mod: Int): Int = { + val rawMod = x % mod + rawMod + (if (rawMod < 0) mod else 0) + } + + private var valueHash: (Any) => Int = _ + private var signHash: (Any) => Int = _ + private var weightFunction: (Int) => Double = _ + + private def hashFunc(term: Any, seed: Int): Int = { + term match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b.toInt, seed) + case s: Short => hashInt(s.toInt, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case s: String => + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + case _ => + throw new GraphFramesUnsupportedVertexTypeException( + "Hash2Vec with murmur3 algorithm does not " + + s"support type ${term.getClass.getCanonicalName} of input data.") + } + } + + /** + * Runs the Hash2Vec algorithm on the input DataFrame containing sequences. The specified + * sequenceCol must contain arrays of elements (string or numeric). Produces a DataFrame with + * "id" (element ID, same type as elements) and "vector" (embedding vector, VectorType). + * Embeddings are summed across all partitions and occurrences. + */ + def run(data: DataFrame): DataFrame = { + val spark = data.sparkSession + require(data.schema(sequenceCol).dataType.isInstanceOf[ArrayType], "sequence should be array") + val elDataType = data.schema(sequenceCol).dataType.asInstanceOf[ArrayType].elementType + + weightFunction = decayFunction match { + case "gaussian" => (d: Int) => decayGaussian(d, gaussianSigma) + case "constant" => (_: Int) => 1.0 + case _ => throw new RuntimeException(s"unsupported decay function $decayFunction") + } + + valueHash = (el: Any) => nonNegativeMod(hashFunc(el, hashingSeed), embeddingsDim) + signHash = (el: Any) => nonNegativeMod(hashFunc(el, signHashingSeed), 2) + + val (rowRDD, schema) = elDataType match { + case _: StringType => + ( + runTyped[String](data).map(f => Row(f._1, Vectors.dense(f._2))), + StructType(Seq(StructField("id", StringType), StructField("vector", VectorType)))) + case _: ByteType => + ( + runTyped[Byte](data).map(f => Row(f._1, Vectors.dense(f._2))), + StructType(Seq(StructField("id", ByteType), StructField("vector", VectorType)))) + case _: ShortType => + ( + runTyped[Short](data).map(f => Row(f._1, Vectors.dense(f._2))), + StructType(Seq(StructField("id", ShortType), StructField("vector", VectorType)))) + case _: IntegerType => + ( + runTyped[Int](data).map(f => Row(f._1, Vectors.dense(f._2))), + StructType(Seq(StructField("id", IntegerType), StructField("vector", VectorType)))) + case _: LongType => + ( + runTyped[Long](data).map(f => Row(f._1, Vectors.dense(f._2))), + StructType(Seq(StructField("id", LongType), StructField("vector", VectorType)))) + case _ => + throw new GraphFramesUnsupportedVertexTypeException( + s"Hash2Vec supports only string or numeric element types but got ${elDataType.toString()}") + } + + spark.createDataFrame(rowRDD, schema).groupBy("id").agg(Summarizer.sum(col("vector"))) + } + + @nowarn + private def runTyped[T: ClassTag](data: DataFrame): RDD[(T, Array[Double])] = { + data + .select(col(sequenceCol)) + .rdd + .map(_.getAs[Seq[T]](0)) + .repartition(numPartitions) + 
.mapPartitions(processPartition[T]) + } + + private def processPartition[T](iter: Iterator[Seq[T]]): Iterator[(T, Array[Double])] = { + val localVocab = new java.util.concurrent.ConcurrentHashMap[T, Array[Double]]() + + for (seq <- iter) { + val currentSeqSize = seq.length + for (idx <- (0 until currentSeqSize)) { + val currentWord = seq(idx) + if (!localVocab.containsKey(currentWord)) { + localVocab.put(currentWord, Array.fill(embeddingsDim)(0.0)) + } + val context = ((idx - contextSize) to (idx + contextSize)).filter(i => + (i >= 0) && (i < currentSeqSize) && (i != idx)) + for (cIdx <- context) { + val word = seq(cIdx) + val weight = weightFunction(math.abs(cIdx - idx)) + val sign = 2.0 * signHash(word) - 1.0 + val embeddingIdx = valueHash(word) + + val currentEmbedding = localVocab.get(currentWord) + currentEmbedding(embeddingIdx) += sign * weight + } + } + } + + localVocab + .entrySet() + .asScala + .map(entry => (entry.getKey(), entry.getValue())) + .iterator + } +} diff --git a/core/src/main/scala/org/graphframes/examples/EmbeddingsExample.scala b/core/src/main/scala/org/graphframes/examples/EmbeddingsExample.scala new file mode 100644 index 000000000..a4cd1ae74 --- /dev/null +++ b/core/src/main/scala/org/graphframes/examples/EmbeddingsExample.scala @@ -0,0 +1,74 @@ +package org.graphframes.examples + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel +import org.graphframes.GraphFrame +import org.graphframes.embeddings.Hash2Vec +import org.graphframes.rw.RandomWalkWithRestart + +import java.nio.file.* + +object EmbeddingsExample { + def main(args: Array[String]): Unit = { + if (args.length == 0) { + throw new RuntimeException("expected one arg") + } + + val filePath = Paths.get(args(0)) + val sparkConf = new SparkConf() + .setMaster("local[*]") + .setAppName("GraphFramesBenchmarks") + .set("spark.sql.shuffle.partitions", s"${Runtime.getRuntime.availableProcessors() * 2}") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val spark = SparkSession.builder().config(sparkConf).getOrCreate() + val context = spark.sparkContext + context.setLogLevel("ERROR") + context.setCheckpointDir("/tmp/graphframes-checkpoints") + + val edges = spark.read + .format("csv") + .option("header", "false") + .option("delimiter", " ") + .schema(StructType(Seq(StructField("src", LongType), StructField("dst", LongType)))) + .load(filePath.toString()) + .persist(StorageLevel.MEMORY_AND_DISK_SER) + println() + println(s"Read edges: ${edges.count()}") + + val vertices = + edges + .select(col("src").alias("id")) + .union(edges.select(col("dst").alias("id"))) + .distinct() + .persist(StorageLevel.MEMORY_AND_DISK_SER) + println(s"Read vertices: ${vertices.count()}") + + println("Run random walks...") + val graph = GraphFrame(vertices, edges) + val rwBuilder = + new RandomWalkWithRestart() + .onGraph(graph) + .setRestartProbability(0.2) + .setGlobalSeed(42) + .setTemporaryPrefix("rw-test-data") + val walks = rwBuilder.run().persist(StorageLevel.MEMORY_AND_DISK_SER) + + println(s"Generated ${walks.count()} random walks") + println("Checkpointing walks") + + // manual checkpointing + walks.write.mode("overwrite").format("parquet").save("rw-test") + val checkpointedWalks = spark.read.parquet("rw-test") + + println("Learn embeddings") + + val 
embeddings = new Hash2Vec().setEmbeddingsDim(512).run(checkpointedWalks) + embeddings.write.mode("overwrite").format("parquet").save("embeddings") + } +} diff --git a/core/src/main/scala/org/graphframes/examples/RWExample.scala b/core/src/main/scala/org/graphframes/examples/RWExample.scala new file mode 100644 index 000000000..363751899 --- /dev/null +++ b/core/src/main/scala/org/graphframes/examples/RWExample.scala @@ -0,0 +1,62 @@ +package org.graphframes.examples + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel +import org.graphframes.GraphFrame +import org.graphframes.rw.RandomWalkWithRestart + +import java.nio.file.* + +object RWExample { + def main(args: Array[String]): Unit = { + if (args.length == 0) { + throw new RuntimeException("expected one arg") + } + + val filePath = Paths.get(args(0)) + val sparkConf = new SparkConf() + .setMaster("local[*]") + .setAppName("GraphFramesBenchmarks") + .set("spark.sql.shuffle.partitions", s"${Runtime.getRuntime.availableProcessors() * 2}") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val spark = SparkSession.builder().config(sparkConf).getOrCreate() + val context = spark.sparkContext + context.setLogLevel("ERROR") + context.setCheckpointDir("/tmp/graphframes-checkpoints") + + val edges = spark.read + .format("csv") + .option("header", "false") + .option("delimiter", " ") + .schema(StructType(Seq(StructField("src", LongType), StructField("dst", LongType)))) + .load(filePath.toString()) + .persist(StorageLevel.MEMORY_AND_DISK_SER) + println() + println(s"Read edges: ${edges.count()}") + + val vertices = + edges + .select(col("src").alias("id")) + .union(edges.select(col("dst").alias("id"))) + .distinct() + .persist(StorageLevel.MEMORY_AND_DISK_SER) + println(s"Read vertices: ${vertices.count()}") + + val graph = GraphFrame(vertices, edges) + val rwBuilder = + new RandomWalkWithRestart() + .onGraph(graph) + .setRestartProbability(0.2) + .setGlobalSeed(42) + .setTemporaryPrefix("rw-test-data") + val walks = rwBuilder.run() + + walks.write.mode("overwrite").format("parquet").save("rw-test") + } +} diff --git a/core/src/main/scala/org/graphframes/exceptions.scala b/core/src/main/scala/org/graphframes/exceptions.scala index cb7d76289..0584a618e 100644 --- a/core/src/main/scala/org/graphframes/exceptions.scala +++ b/core/src/main/scala/org/graphframes/exceptions.scala @@ -43,3 +43,7 @@ class InvalidPropertyGroupException(message: String) extends Exception(message) * A descriptive error message providing details about why the graph operation is invalid. 
*/ class InvalidGraphException(message: String) extends Exception(message) + +class GraphFramesW2VException(message: String) extends Exception(message) + +class GraphFramesUnsupportedVertexTypeException(message: String) extends Exception(message) diff --git a/core/src/main/scala/org/graphframes/rw/RandomWalkBase.scala b/core/src/main/scala/org/graphframes/rw/RandomWalkBase.scala new file mode 100644 index 000000000..223468762 --- /dev/null +++ b/core/src/main/scala/org/graphframes/rw/RandomWalkBase.scala @@ -0,0 +1,314 @@ +package org.graphframes.rw + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.functions.array_union +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.udaf +import org.apache.spark.sql.graphframes.expressions.ReservoirSamplingAgg +import org.apache.spark.sql.types.ByteType +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.ShortType +import org.apache.spark.sql.types.StringType +import org.graphframes.GraphFrame +import org.graphframes.GraphFramesUnsupportedVertexTypeException +import org.graphframes.Logging +import org.graphframes.WithIntermediateStorageLevel + +import scala.util.Random + +/** + * Base trait for implementing random walk algorithms on graph data. Provides common functionality + * for generating random walks across a graph structure. + */ +trait RandomWalkBase extends Serializable with Logging with WithIntermediateStorageLevel { + + /** Maximum number of neighbors to consider per vertex during random walks. */ + protected var maxNbrs: Int = 50 + + /** GraphFrame on which random walks are performed. */ + protected var graph: GraphFrame = null + + /** Number of random walks to generate per node. */ + protected var numWalksPerNode: Int = 5 + + /** Size of each batch in the random walk process. */ + protected var batchSize: Int = 10 + + /** Number of batches to run in the random walk process. */ + protected var numBatches: Int = 5 + + /** Whether to respect edge direction in the graph (true for directed graphs). */ + protected var useEdgeDirection: Boolean = false + + /** Global random seed for reproducibility. */ + protected var globalSeed: Long = 42L + + /** Optional prefix for temporary storage during random walks. */ + protected var temporaryPrefix: Option[String] = None + + /** Unique identifier for the current random walk run. */ + protected var runID: String = "" + + /** + * Sets the graph to perform random walks on. + * + * @param graph + * the GraphFrame to run random walks on + * @return + * this RandomWalkBase instance for chaining + */ + def onGraph(graph: GraphFrame): this.type = { + this.graph = graph + this + } + + /** + * Sets the temporary prefix for storing intermediate results. + * + * @param value + * the prefix string + * @return + * this RandomWalkBase instance for chaining + */ + def setTemporaryPrefix(value: String): this.type = { + temporaryPrefix = Some(value) + this + } + + /** + * Sets the maximum number of neighbors per vertex. + * + * @param value + * the max number of neighbors + * @return + * this RandomWalkBase instance for chaining + */ + def setMaxNbrsPerVertex(value: Int): this.type = { + maxNbrs = value + this + } + + /** + * Sets the number of walks per node. 
+ * + * @param value + * number of walks + * @return + * this RandomWalkBase instance for chaining + */ + def setNumWalksPerNode(value: Int): this.type = { + numWalksPerNode = value + this + } + + /** + * Sets the batch size. + * + * @param value + * batch size + * @return + * this RandomWalkBase instance for chaining + */ + def setBatchSize(value: Int): this.type = { + batchSize = value + this + } + + /** + * Sets the number of batches. + * + * @param value + * number of batches + * @return + * this RandomWalkBase instance for chaining + */ + def setNumBatches(value: Int): this.type = { + numBatches = value + this + } + + /** + * Sets whether to use edge direction. + * + * @param value + * true if the graph is directed + * @return + * this RandomWalkBase instance for chaining + */ + def setUseEdgeDirection(value: Boolean): this.type = { + useEdgeDirection = value + this + } + + /** + * Sets the global random seed. + * + * @param value + * the seed value + * @return + * this RandomWalkBase instance for chaining + */ + def setGlobalSeed(value: Long): this.type = { + globalSeed = value + this + } + + /** + * Generates a temporary path for a given iteration. + * + * @param iter + * iteration number + * @return + * path string + */ + private def iterationTmpPath(iter: Int): String = if (temporaryPrefix.get.endsWith("/")) { + s"${temporaryPrefix.get}${runID}_batch_${iter}" + } else { + s"${temporaryPrefix.get}/${runID}_batch_${iter}" + } + + /** + * Executes the random walk algorithm on the set graph. + * + * @return + * DataFrame containing the random walks + */ + def run(): DataFrame = { + if (graph == null) { + throw new IllegalArgumentException("Graph is not set") + } + if (temporaryPrefix.isEmpty) { + throw new IllegalArgumentException("Temporary prefix is required for random walks.") + } + runID = java.util.UUID.randomUUID().toString + logInfo(s"Starting random walk with runID: $runID") + val iterationsRng = new Random() + iterationsRng.setSeed(globalSeed) + val spark = graph.vertices.sparkSession + + for (i <- 1 to numBatches) { + logInfo(s"Starting batch $i of $numBatches") + val iterSeed = iterationsRng.nextLong() + val preparedGraph = prepareGraph() + val prevIterationDF = if (i == 1) { None } + else { + Some(spark.read.parquet(iterationTmpPath(i - 1))) + } + val iterationResult: DataFrame = runIter(preparedGraph, prevIterationDF, iterSeed) + iterationResult.write.parquet(iterationTmpPath(i)) + } + + logInfo("Finished all batches, merging results.") + var result = spark.read.parquet(iterationTmpPath(1)) + + for (i <- 2 to numBatches) { + val tmpDF = spark.read + .parquet(iterationTmpPath(i)) + .withColumnRenamed(RandomWalkBase.rwColName, "toMerge") + result = result + .join(tmpDF, Seq(RandomWalkBase.walkIdCol)) + .select( + col(RandomWalkBase.walkIdCol), + array_union(col(RandomWalkBase.rwColName), col("toMerge")) + .alias(RandomWalkBase.rwColName)) + } + result = result.persist(intermediateStorageLevel) + + val cnt = result.count() + resultIsPersistent() + logInfo(s"$cnt random walks are returned") + result + } + + /** + * Prepares the graph for random walk by limiting neighbors and handling direction. 
+ * + * @return + * prepared GraphFrame + */ + protected def prepareGraph(): GraphFrame = { + val preAggs = if (useEdgeDirection) { + graph.edges + .select(col(GraphFrame.SRC), col(GraphFrame.DST)) + .groupBy(col(GraphFrame.SRC).alias(GraphFrame.ID)) + } else { + graph.edges + .select(GraphFrame.SRC, GraphFrame.DST) + .union(graph.edges.select(GraphFrame.DST, GraphFrame.SRC)) + .distinct() + .groupBy(col(GraphFrame.SRC).alias(GraphFrame.ID)) + } + + val vertices = graph.vertices.schema(GraphFrame.ID).dataType match { + case StringType => + preAggs.agg( + udaf(ReservoirSamplingAgg[java.lang.String](maxNbrs), Encoders.STRING) + .apply(col(GraphFrame.DST)) + .alias(RandomWalkBase.nbrsColName)) + case ShortType => + preAggs.agg( + udaf(ReservoirSamplingAgg[java.lang.Short](maxNbrs), Encoders.SHORT) + .apply(col(GraphFrame.DST)) + .alias(RandomWalkBase.nbrsColName)) + case ByteType => + preAggs.agg( + udaf(ReservoirSamplingAgg[java.lang.Byte](maxNbrs), Encoders.BYTE) + .apply(col(GraphFrame.DST)) + .alias(RandomWalkBase.nbrsColName)) + case IntegerType => + preAggs.agg( + udaf(ReservoirSamplingAgg[java.lang.Integer](maxNbrs), Encoders.INT) + .apply(col(GraphFrame.DST)) + .alias(RandomWalkBase.nbrsColName)) + case LongType => + preAggs.agg( + udaf(ReservoirSamplingAgg[java.lang.Long](maxNbrs), Encoders.LONG) + .apply(col(GraphFrame.DST)) + .alias(RandomWalkBase.nbrsColName)) + case _ => throw new GraphFramesUnsupportedVertexTypeException("unsupported vertex type") + } + + val edges = graph.edges + .select(GraphFrame.SRC, GraphFrame.DST) + .join(vertices, col(GraphFrame.SRC) === col(GraphFrame.ID)) + .drop(GraphFrame.ID) + .join(vertices, col(GraphFrame.DST) === col(GraphFrame.ID)) + .drop(GraphFrame.ID) + + GraphFrame(vertices, edges) + } + + /** + * Runs a single iteration of the random walk. + * + * @param graph + * prepared graph + * @param prevIterationDF + * DataFrame from previous iteration (if any) + * @param iterSeed + * seed for this iteration + * @return + * DataFrame result of this iteration + */ + protected def runIter( + graph: GraphFrame, + prevIterationDF: Option[DataFrame], + iterSeed: Long): DataFrame +} + +object RandomWalkBase { + + /** Column name for the random walk array. */ + val rwColName: String = "random_walk" + + /** Column name for the unique walk ID. */ + val walkIdCol: String = "random_walk_uuid" + + /** Column name for neighbors list. */ + val nbrsColName: String = "random_walk_nbrs" + + /** Column name for the current visiting vertex. */ + val currVisitingVertexColName: String = "random_walk_curr_vertex" +} diff --git a/core/src/main/scala/org/graphframes/rw/RandomWalkWithRestart.scala b/core/src/main/scala/org/graphframes/rw/RandomWalkWithRestart.scala new file mode 100644 index 000000000..01c7ec801 --- /dev/null +++ b/core/src/main/scala/org/graphframes/rw/RandomWalkWithRestart.scala @@ -0,0 +1,80 @@ +package org.graphframes.rw + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.* +import org.graphframes.GraphFrame + +/** + * An implementation of random walk with restart. At each step of the walk, there is a probability + * (defined by restartProbability) to reset the walk to the original starting node, otherwise the + * walk continues to a random neighbor. + */ +/** + * An implementation of random walk with restart. At each step of the walk, there is a probability + * (defined by restartProbability) to reset the walk to the original starting node, otherwise the + * walk continues to a random neighbor. 
+ */ +class RandomWalkWithRestart extends RandomWalkBase { + + /** The probability of restarting the walk at each step (resets to starting node). */ + private var restartProbability: Double = 0.1 + + /** + * Sets the restart probability for the random walk. + * + * @param value + * the probability value (between 0.0 and 1.0) + * @return + * this RandomWalkWithRestart instance for chaining + */ + def setRestartProbability(value: Double): this.type = { + restartProbability = value + this + } + + override protected def runIter( + graph: GraphFrame, + prevIterationDF: Option[DataFrame], + iterSeed: Long): DataFrame = { + val neighbors = graph.vertices.select(col(GraphFrame.ID), col(RandomWalkBase.nbrsColName)) + var walks = if (prevIterationDF.isEmpty) { + graph.vertices.select( + col(GraphFrame.ID).alias("startingNode"), + col(GraphFrame.ID).alias(RandomWalkBase.currVisitingVertexColName), + explode( + when( + array_size(col(RandomWalkBase.nbrsColName)) > lit(0), + array((0 until numWalksPerNode).map(_ => uuid()): _*)).otherwise(array())) + .alias(RandomWalkBase.walkIdCol), + array(col(GraphFrame.ID)).alias(RandomWalkBase.rwColName)) + } else { + prevIterationDF.get.select( + col("startingNode"), + col(RandomWalkBase.currVisitingVertexColName), + col(RandomWalkBase.walkIdCol), + array(col(RandomWalkBase.currVisitingVertexColName)).alias(RandomWalkBase.rwColName)) + } + + for (_ <- (0 until batchSize)) { + walks = walks + .join( + neighbors, + col(GraphFrame.ID) === col(RandomWalkBase.currVisitingVertexColName), + "left") + .withColumn("doRestart", rand() <= lit(restartProbability)) + .withColumn( + "nextNode", + when(col("doRestart"), col("startingNode")).otherwise( + element_at(shuffle(col(RandomWalkBase.nbrsColName)), 1))) + .select( + col(RandomWalkBase.walkIdCol), + col("startingNode"), + col("nextNode").alias(RandomWalkBase.currVisitingVertexColName), + array_append( + col(RandomWalkBase.rwColName), + col(RandomWalkBase.currVisitingVertexColName)).alias(RandomWalkBase.rwColName)) + } + + walks + } +} diff --git a/core/src/test/scala/org/apache/spark/sql/graphframes/expressions/ReservoirSamplingAggSuite.scala b/core/src/test/scala/org/apache/spark/sql/graphframes/expressions/ReservoirSamplingAggSuite.scala new file mode 100644 index 000000000..01c1ad120 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/sql/graphframes/expressions/ReservoirSamplingAggSuite.scala @@ -0,0 +1,141 @@ +package org.apache.spark.sql.graphframes.expressions + +import org.scalatest.funsuite.AnyFunSuite + +import scala.collection.mutable.ArrayBuffer + +class ReservoirSamplingAggSuite extends AnyFunSuite { + + test("zero returns empty reservoir with zero elements") { + val agg = new ReservoirSamplingAgg[Int](3) + val res = agg.zero + assert(res.seq.isEmpty) + assert(res.elements == 0) + } + + test("reduce adds elements when below size") { + val agg = new ReservoirSamplingAgg[Int](3) + var res = agg.zero + res = agg.reduce(res, 1) + assert(res.seq == ArrayBuffer(1)) + assert(res.elements == 1) + res = agg.reduce(res, 2) + assert(res.seq == ArrayBuffer(1, 2)) + assert(res.elements == 2) + res = agg.reduce(res, 3) + assert(res.seq == ArrayBuffer(1, 2, 3)) + assert(res.elements == 3) + } + + test("reduce replaces randomly when at size with fixed seed") { + val agg = new ReservoirSamplingAgg[Int](2) + // Create a full reservoir with fixed rng + var res = agg.zero + res = agg.reduce(res, 1) + res = agg.reduce(res, 2) + // Now res.seq = [1,2], elements=2 + val fixedRes = res.copy() + + // Add third element + val res3 
= agg.reduce(fixedRes, 3) + assert(res3.elements == 3) + assert(res3.seq.length == 2) + } + + test("merge two empty reservoirs") { + val agg = new ReservoirSamplingAgg[Int](5) + val r1 = agg.zero + val r2 = agg.zero + val merged = agg.merge(r1, r2) + assert(merged.elements == 0) + assert(merged.seq.length == 0) + } + + test("merge two partial reservoirs below size") { + val agg = new ReservoirSamplingAgg[Int](5) + var r1 = agg.zero + r1 = agg.reduce(r1, 1) + var r2 = agg.zero + r2 = agg.reduce(r2, 2) + val merged = agg.merge(r1, r2) + assert(merged.elements == 2) + assert(merged.seq.toSet == Set(1, 2)) + } + + test("merge partial and full reservoirs with fixed seed") { + val agg = new ReservoirSamplingAgg[Int](1) + var left = agg.zero + left = agg.reduce(left, 10) + left = left.copy() + + var right = agg.zero + right = agg.reduce(right, 20) + val merged = agg.merge(left, right) + assert(merged.elements == 2) + assert(merged.seq.length == 1) + } + + test("merge two full reservoirs with fixed seed") { + val agg = new ReservoirSamplingAgg[Int](2) + val r1 = Reservoir(ArrayBuffer(1, 2), 5) + val r2 = Reservoir(ArrayBuffer(3, 4), 5) + val merged = agg.merge(r1, r2) + assert(merged.elements == 10) + assert(merged.seq.length == 2) + } + + test("finish returns the sequence") { + val agg = new ReservoirSamplingAgg[Int](3) + var res = agg.zero + res = agg.reduce(res, 1) + res = agg.reduce(res, 2) + val seq = agg.finish(res) + assert(seq == Seq(1, 2)) + } + + test("uniformity of sampling") { + // WARNING! + // This test is slightly non-deterministic: + // in a very rare case (roughly 1 in 50) it may fail, + // so just re-run it. + val numElements = 5000 + val numSamples = 5000 + val sampleSize = 500 + val sequence = (1 to numElements).toArray + + // Count frequencies of each element across all samples + val frequencyMap = scala.collection.mutable.Map[Int, Int]() + for (i <- 0 until numElements) { + frequencyMap += (i + 1 -> 0) + } + + // Perform multiple samplings + for (_ <- 0 until numSamples) { + val agg = new ReservoirSamplingAgg[Int](sampleSize) + var res = agg.zero + + // Fill reservoir with all elements in sequence + for (element <- sequence) { + res = agg.reduce(res, element) + } + + // Collect sampled elements + val sampled = agg.finish(res) + for (element <- sampled) { + frequencyMap(element) += 1 + } + } + + // Check uniformity - each element should be sampled roughly the same number of times + val expectedFreq = numSamples * sampleSize.toDouble / numElements + val tolerance = 0.2 // 20% tolerance + val minExpected = expectedFreq * (1 - tolerance) + val maxExpected = expectedFreq * (1 + tolerance) + + for ((element, count) <- frequencyMap) { + assert( + count >= minExpected && count <= maxExpected, + s"Element $element was sampled $count times, expected between $minExpected and $maxExpected") + } + } +}
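Example usage (editorial sketch, not part of the diff): RandomWalkBase.prepareGraph registers ReservoirSamplingAgg through Spark's udaf helper; the snippet below shows the same pattern standalone, assuming a local SparkSession named spark and a hypothetical toy edges DataFrame, so it illustrates the API rather than reproducing code from this PR.

import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.sql.functions.{col, udaf}
import org.apache.spark.sql.graphframes.expressions.ReservoirSamplingAgg

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Toy edge list: (src, dst) pairs; column names mirror GraphFrame.SRC / GraphFrame.DST.
val edges = Seq((1L, 2L), (1L, 3L), (1L, 4L), (2L, 3L)).toDF("src", "dst")

// Wrap the aggregator as a UDAF, keeping at most 2 uniformly sampled neighbors per vertex,
// the same way prepareGraph does for LongType vertex ids.
val sampleNbrs = udaf(ReservoirSamplingAgg[java.lang.Long](2), Encoders.LONG)

edges
  .groupBy(col("src").alias("id"))
  .agg(sampleNbrs(col("dst")).alias("sampled_nbrs"))
  .show(false)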