diff --git a/core/src/main/scala/org/graphframes/GraphFrame.scala b/core/src/main/scala/org/graphframes/GraphFrame.scala index b601ed862..6b7d1fae6 100644 --- a/core/src/main/scala/org/graphframes/GraphFrame.scala +++ b/core/src/main/scala/org/graphframes/GraphFrame.scala @@ -465,7 +465,6 @@ class GraphFrame private ( */ def find(pattern: String): DataFrame = { val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]-(>?)\((\w+)\)""".r - val UndirectedPattern = """\((\w+)\)-\[(\w*)\]-\((\w+)\)""".r pattern match { case VarLengthPattern(src, name, min, max, direction, dst) => @@ -500,17 +499,6 @@ class GraphFrame private ( val ret = (out ++ in).reduce((a, b) => a.unionByName(b, allowMissingColumns = true)) ret.orderBy("_hop", "_direction") - case UndirectedPattern(src, name, dst) => - val out: DataFrame = findAugmentedPatterns(s"($src)-[$name]->($dst)") - .withColumn("_pattern", lit(s"($src)-[$name]->($dst)")) - .withColumn("_direction", lit("out")) - val in: DataFrame = findAugmentedPatterns(s"($src)<-[$name]-($dst)") - .withColumn("_pattern", lit(s"($src)<-[$name]-($dst)")) - .withColumn("_direction", lit("in")) - - val ret = out.unionByName(in) - ret.orderBy("_direction") - case _ => findAugmentedPatterns(pattern) } @@ -1203,6 +1191,16 @@ object GraphFrame extends Serializable with Logging { private def eSrcId(name: String): String = prefixWithName(name, SRC) private def eDstId(name: String): String = prefixWithName(name, DST) + private def maybeUnion(aOpt: Option[DataFrame], bOpt: Option[DataFrame]): Option[DataFrame] = { + (aOpt, bOpt) match { + case (Some(a), Some(b)) => + Some(a.unionByName(b, allowMissingColumns = true).orderBy("_direction")) + case (Some(a), None) => Some(a) + case (None, Some(b)) => Some(b) + case (None, None) => None + } + } + private def maybeCrossJoin(aOpt: Option[DataFrame], b: DataFrame): DataFrame = { aOpt match { case Some(a) => a.crossJoin(b) @@ -1227,6 +1225,8 @@ object GraphFrame extends Serializable with Logging { private def seen1(v: NamedVertex, pattern: Pattern): Boolean = pattern match { case Negation(edge) => seen1(v, edge) + case UndirectedEdge(edge) => + seen1(v, edge) case AnonymousEdge(src, dst) => seen1(v, src) || seen1(v, dst) case NamedEdge(_, src, dst) => @@ -1271,6 +1271,57 @@ object GraphFrame extends Serializable with Logging { (Some(maybeCrossJoin(prev, nestV(name))), prevNames :+ name) } + case UndirectedEdge(edge) => + val srcName: String = edge match { + case NamedEdge(_, NamedVertex(n), _) => n + case AnonymousEdge(NamedVertex(n), _) => n + case _ => "" + } + val dstName: String = edge match { + case NamedEdge(_, _, NamedVertex(n)) => n + case AnonymousEdge(_, NamedVertex(n)) => n + case _ => "" + } + val edgeName: String = edge match { + case NamedEdge(n, _, _) => n + case _ => "" + } + + val patternStr: String = s"($srcName)-[$edgeName]->($dstName)" + val reversedPatternStr: String = s"($srcName)<-[$edgeName]-($dstName)" + + val reversedEdge: Pattern = { + edge match { + case e: NamedEdge => + e.copy(src = e.dst, dst = e.src) + case e: AnonymousEdge => + e.copy(src = e.dst, dst = e.src) + case _ => edge + } + } + + val (dfIn, _) = findIncremental(gf, prevPatterns, prev, prevNames, reversedEdge) + val (dfOut, names) = findIncremental(gf, prevPatterns, prev, prevNames, edge) + + val df1 = dfIn match { + case Some(d) => + Some( + d.withColumn("_pattern", lit(reversedPatternStr)) + .withColumn("_direction", lit("in"))) + case None => None + } + + val df2 = dfOut match { + case Some(d) => + Some( + d.withColumn("_pattern", lit(patternStr)) + .withColumn("_direction", lit("out"))) + case None => None + } + + val df = maybeUnion(df1, df2) + (df, names :+ "_pattern" :+ "_direction") + case NamedEdge(name, AnonymousVertex, AnonymousVertex) => val eRen = nestE(name) (Some(maybeCrossJoin(prev, eRen)), prevNames :+ name) @@ -1376,6 +1427,7 @@ object GraphFrame extends Serializable with Logging { prev match { case Some(p) => val (df, names) = findIncremental(gf, prevPatterns, Some(p), prevNames, edge) + // TODO: _pattern. _direction columns should be ignored if it is impacting (df.map(result => p.except(result)), names) case None => throw new InvalidPatternException diff --git a/core/src/main/scala/org/graphframes/pattern/patterns.scala b/core/src/main/scala/org/graphframes/pattern/patterns.scala index 846f12803..9e08a5107 100644 --- a/core/src/main/scala/org/graphframes/pattern/patterns.scala +++ b/core/src/main/scala/org/graphframes/pattern/patterns.scala @@ -31,13 +31,15 @@ private[graphframes] object PatternParser extends RegexParsers { private val anonymousVertex: Parser[Vertex] = "" ^^ { _ => AnonymousVertex } private val vertex: Parser[Vertex] = "(" ~> (vertexName | anonymousVertex) <~ ")" private val namedEdge: Parser[Edge] = - vertex ~ "-" ~ "[" ~ "[a-zA-Z0-9_]+".r ~ "]" ~ "->" ~ vertex ^^ { + vertex ~ "-" ~ "[" ~ "[a-zA-Z0-9_]+".r ~ "]" ~ ("->" | "-") ~ vertex ^^ { case src ~ "-" ~ "[" ~ name ~ "]" ~ "->" ~ dst => NamedEdge(name, src, dst) + case src ~ "-" ~ "[" ~ name ~ "]" ~ "-" ~ dst => UndirectedEdge(NamedEdge(name, src, dst)) case _ => throw new GraphFramesUnreachableException() } val anonymousEdge: Parser[Edge] = - vertex ~ "-" ~ "[" ~ "]" ~ "->" ~ vertex ^^ { + vertex ~ "-" ~ "[" ~ "]" ~ ("->" | "-") ~ vertex ^^ { case src ~ "-" ~ "[" ~ "]" ~ "->" ~ dst => AnonymousEdge(src, dst) + case src ~ "-" ~ "[" ~ "]" ~ "-" ~ dst => UndirectedEdge(AnonymousEdge(src, dst)) case _ => throw new GraphFramesUnreachableException() } private val edge: Parser[Edge] = namedEdge | anonymousEdge @@ -157,6 +159,8 @@ private[graphframes] object Pattern { case AnonymousEdge(src, dst) => addVertex(src) addVertex(dst) + case UndirectedEdge(edge) => + addEdge(edge) } patterns.foreach { @@ -171,6 +175,15 @@ private[graphframes] object Pattern { "Motif finding does not support completely " + "anonymous negated edges !()-[]->(). Users can check for 0 edges in the graph " + "using the edges DataFrame.") + case e @ UndirectedEdge(edge) => + edge match { + case AnonymousEdge(AnonymousVertex, AnonymousVertex) => + throw new InvalidParseException( + "Motif finding does not support completely " + + "anonymous negated edges !()-[]-(). Users can check for the existence of edges in the " + + "graph using the edges DataFrame.") + case _ => addEdge(e) + } case e @ AnonymousEdge(_, _) => addEdge(e) } @@ -179,6 +192,15 @@ private[graphframes] object Pattern { "Motif finding does not support completely " + "anonymous edges ()-[]->(). Users can check for the existence of edges in the " + "graph using the edges DataFrame.") + case e @ UndirectedEdge(edge) => + edge match { + case AnonymousEdge(AnonymousVertex, AnonymousVertex) => + throw new InvalidParseException( + "Motif finding does not support completely " + + "anonymous edges ()-[]-(). Users can check for the existence of edges in the " + + "graph using the edges DataFrame.") + case _ => addEdge(e) + } case e @ AnonymousEdge(_, _) => addEdge(e) case e @ NamedEdge(_, _, _) => @@ -220,6 +242,10 @@ private[graphframes] object Pattern { def findNamedElementsHelper(pattern: Pattern): Unit = pattern match { case Negation(child) => findNamedElementsHelper(child) + case UndirectedEdge(child) => + findNamedElementsHelper(child) + elementSet += "_pattern" + elementSet += "_direction" case AnonymousVertex => // pass case NamedVertex(name) => if (!elementSet.contains(name)) { @@ -252,6 +278,8 @@ private[graphframes] case class NamedVertex(name: String) extends Vertex private[graphframes] sealed trait Edge extends Pattern +private[graphframes] case class UndirectedEdge(edge: Edge) extends Edge + private[graphframes] case class AnonymousEdge(src: Vertex, dst: Vertex) extends Edge private[graphframes] case class NamedEdge(name: String, src: Vertex, dst: Vertex) extends Edge diff --git a/core/src/test/scala/org/graphframes/PatternMatchSuite.scala b/core/src/test/scala/org/graphframes/PatternMatchSuite.scala index 5a91c4fe7..9fc4979d6 100644 --- a/core/src/test/scala/org/graphframes/PatternMatchSuite.scala +++ b/core/src/test/scala/org/graphframes/PatternMatchSuite.scala @@ -675,6 +675,43 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext { compareResultToExpected(res, expected) } + test("undirected information column") { + val res1 = g + .find("(u)-[e1]-(v)") + .where("u.id == 0") + .select("_pattern", "_direction") + .collect() + .toSet + + val expected1 = Set(Row("(u)<-[e1]-(v)", "in"), Row("(u)-[e1]->(v)", "out")) + + compareResultToExpected(res1, expected1) + + val res2 = g + .find("(u)-[]-(v)") + .where("u.id == 0") + .select("_pattern", "_direction") + .collect() + .toSet + + val expected2 = Set(Row("(u)<-[]-(v)", "in"), Row("(u)-[]->(v)", "out")) + + compareResultToExpected(res2, expected2) + } + + test("undirected edge within a chain") { + val res = g + .find("(u)-[]-(v);(v)-[]->(k)") + .where("u.id == 0") + .select("u.id", "v.id", "k.id") + .collect() + .toSet + + val expected = Set(Row(0L, 1L, 2L), Row(0L, 1L, 0L), Row(0L, 2L, 0L), Row(0L, 2L, 3L)) + + compareResultToExpected(res, expected) + } + test("undirected with edge name") { val res = g .find("(u)-[e]-(v)") diff --git a/core/src/test/scala/org/graphframes/pattern/PatternSuite.scala b/core/src/test/scala/org/graphframes/pattern/PatternSuite.scala index 1aa93e061..f8bc918c0 100644 --- a/core/src/test/scala/org/graphframes/pattern/PatternSuite.scala +++ b/core/src/test/scala/org/graphframes/pattern/PatternSuite.scala @@ -95,6 +95,18 @@ class PatternSuite extends SparkFunSuite { AnonymousEdge(NamedVertex("_v9"), NamedVertex("v")))) } + test("good parses - undirected pattern") { + assert( + Pattern.parse("(u)-[e]-(v)") === + Seq(UndirectedEdge(NamedEdge("e", NamedVertex("u"), NamedVertex("v"))))) + + assert( + Pattern.parse("(u)-[e]-(v);(v)-[]-(k)") === + Seq( + UndirectedEdge(NamedEdge("e", NamedVertex("u"), NamedVertex("v"))), + UndirectedEdge(AnonymousEdge(NamedVertex("v"), NamedVertex("k"))))) + } + test("rewrite incomming edges") { assert(Pattern.rewriteIncomingEdges("(u)<-[e]-(v);") === "(v)-[e]->(u)") assert(Pattern.rewriteIncomingEdges("!(u)<-[e]-(v);") === "!(v)-[e]->(u)") @@ -172,6 +184,17 @@ class PatternSuite extends SparkFunSuite { Pattern.parse("!()-[]->()") } } + withClue("Failed to catch parse error with completely anonymous undirected edge ()-[]-()") { + intercept[InvalidParseException] { + Pattern.parse("()-[]-()") + } + } + withClue( + "Failed to catch parse error with completely anonymous negated and undirected edge !()-[]-()") { + intercept[InvalidParseException] { + Pattern.parse("!()-[]-()") + } + } withClue("Failed to catch parse error with reused element name") { intercept[InvalidParseException] { Pattern.parse("(a)-[]->(b); ()-[a]->()") diff --git a/docs/src/04-user-guide/04-motif-finding.md b/docs/src/04-user-guide/04-motif-finding.md index 01cc93f78..951770907 100644 --- a/docs/src/04-user-guide/04-motif-finding.md +++ b/docs/src/04-user-guide/04-motif-finding.md @@ -43,11 +43,11 @@ DSL for expressing structural patterns: Restrictions: -* Motifs are not allowed to contain edges without any named elements: `"()-[]->()"` and `"!()-[]->()"` are prohibited terms. +* Motifs are not allowed to contain edges without any named elements: `"()-[]->()"`, `"!()-[]->()"`, `"()-[]-()"`, and `"!()-[]-()"` are prohibited terms. * Motifs are not allowed to contain named edges within negated terms (since these named edges would never appear within results). E.g., `"!(a)-[ab]->(b)"` is invalid, but `"!(a)-[]->(b)"` is valid. * Negation is not supported for the variable-length pattern, bidirectional pattern and undirected pattern: `"!(a)-[*1..3]->(b)"`, `"!(a)<-[]->(b)"` and `"!(a)-[]-(b)"` are not allowed. * Unbounded length patten is not supported: `"(a)-[*..3]->(b)"` and `"(a)-[*1..]->(b)"` are not allowed. -* You cannot join additional edges with the variable length pattern: `"(a)-[*1..3]-(b);(b)-[]-(c)"`is not valid. +* You cannot join additional edges with quantified length patterns: `"(a)-[*3]->(b);(b)-[]->(c)"` and `"(a)-[*1..3]->(b);(b)-[]->(c)"` are not allowed. More complex queries, such as queries which operate on vertex or edge attributes, can be expressed by applying filters to the result `DataFrame`.