chore: force python code style check in CI by SemyonSinchenko · Pull Request #527 · graphframes/graphframes
Merged
8 changes: 7 additions & 1 deletion .github/workflows/python-ci.yml
@@ -43,7 +43,13 @@ jobs:
         working-directory: ./python
         run: |
           poetry build
-          poetry install
+          poetry install --with dev
+      - name: Code Style
+        working-directory: ./python
+        run: |
+          poetry run python -m black --check graphframes
+          poetry run python -m flake8 graphframes
+          poetry run python -m isort --check graphframes
       - name: Test
         working-directory: ./python
         run: |
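
With the dev group installed, contributors can reproduce the new Code Style gate before pushing. A minimal local runner, sketched below, simply shells out to the same three commands; the script itself is hypothetical and not part of this PR, but the commands are copied verbatim from the workflow step above.

# local_style_check.py -- hypothetical helper, not part of this PR.
# Assumes dev dependencies are installed via `poetry install --with dev`
# and that it is run from the repository's python/ directory.
import subprocess
import sys

CHECKS = [
    ["poetry", "run", "python", "-m", "black", "--check", "graphframes"],
    ["poetry", "run", "python", "-m", "flake8", "graphframes"],
    ["poetry", "run", "python", "-m", "isort", "--check", "graphframes"],
]

for cmd in CHECKS:
    # Fail fast on the first non-zero exit code, as the CI step does.
    if subprocess.run(cmd).returncode != 0:
        sys.exit("style check failed: " + " ".join(cmd))
print("all style checks passed")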
3 changes: 1 addition & 2 deletions python/graphframes/__init__.py
@@ -1,4 +1,3 @@
-
 from .graphframe import GraphFrame
 
-__all__ = ['GraphFrame']
+__all__ = ["GraphFrame"]
1 change: 1 addition & 0 deletions python/graphframes/console.py
@@ -1,4 +1,5 @@
 import click
+
 from graphframes.tutorials import download
 
 
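
The one added line is the blank line isort enforces between import groups: by default it orders standard-library, third-party, and first-party imports, each group separated by a blank line. A small illustration of that convention (assumed isort defaults; the stdlib lines are not from this file):

import os  # standard-library group
import sys

import click  # third-party group

from graphframes.tutorials import download  # first-party group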
3 changes: 1 addition & 2 deletions python/graphframes/examples/__init__.py
@@ -1,5 +1,4 @@
-
 from .belief_propagation import BeliefPropagation
 from .graphs import Graphs
 
-__all__ = ['BeliefPropagation', 'Graphs']
+__all__ = ["BeliefPropagation", "Graphs"]
54 changes: 30 additions & 24 deletions python/graphframes/examples/belief_propagation.py
@@ -18,14 +18,17 @@
 import math
 from typing import Union
 
+from pyspark.sql import SparkSession
+from pyspark.sql import functions as sqlfunctions
+from pyspark.sql import types
+
 # Import subpackage examples here explicitly so that
 # this module can be run directly with spark-submit.
 import graphframes.examples
 from graphframes import GraphFrame
 from graphframes.lib import AggregateMessages as AM
-from pyspark.sql import SparkSession, functions as sqlfunctions, types
 
-__all__ = ['BeliefPropagation']
+__all__ = ["BeliefPropagation"]
 
 
 class BeliefPropagation:
@@ -61,7 +64,7 @@ class BeliefPropagation:
     * Coloring the graph by assigning a color to each vertex such that no neighboring vertices
       share the same color.
     * In each step of BP, update all vertices of a single color. Alternate colors.
-    """
+    """  # noqa: W605
 
     @classmethod
     def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame:
@@ -71,12 +74,12 @@ def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame:
         """
         # choose colors for vertices for BP scheduling
         colorG = cls._colorGraph(g)
-        numColors = colorG.vertices.select('color').distinct().count()
+        numColors = colorG.vertices.select("color").distinct().count()
 
         # TODO: handle vertices without any edges
 
         # initialize vertex beliefs at 0.0
-        gx = GraphFrame(colorG.vertices.withColumn('belief', sqlfunctions.lit(0.0)), colorG.edges)
+        gx = GraphFrame(colorG.vertices.withColumn("belief", sqlfunctions.lit(0.0)), colorG.edges)
 
         # run BP for numIter iterations
         for iter_ in range(numIter):
@@ -85,37 +88,40 @@ def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame:
             # Send messages to vertices of the current color.
             # We may send to source or destination since edges are treated as undirected.
             msgForSrc = sqlfunctions.when(
-                AM.src['color'] == color,
-                AM.edge['b'] * AM.dst['belief'])
+                AM.src["color"] == color, AM.edge["b"] * AM.dst["belief"]
+            )
             msgForDst = sqlfunctions.when(
-                AM.dst['color'] == color,
-                AM.edge['b'] * AM.src['belief'])
+                AM.dst["color"] == color, AM.edge["b"] * AM.src["belief"]
+            )
             # numerically stable sigmoid
             logistic = sqlfunctions.udf(cls._sigmoid, returnType=types.DoubleType())
             aggregates = gx.aggregateMessages(
                 sqlfunctions.sum(AM.msg).alias("aggMess"),
                 sendToSrc=msgForSrc,
-                sendToDst=msgForDst)
+                sendToDst=msgForDst,
+            )
             v = gx.vertices
             # receive messages and update beliefs for vertices of the current color
             newBeliefCol = sqlfunctions.when(
-                (v['color'] == color) & (aggregates['aggMess'].isNotNull()),
-                logistic(aggregates['aggMess'] + v['a'])
-            ).otherwise(v['belief'])  # keep old beliefs for other colors
-            newVertices = (v
-                .join(aggregates, on=(v['id'] == aggregates['id']), how='left_outer')
-                .drop(aggregates['id'])  # drop duplicate ID column (from outer join)
-                .withColumn('newBelief', newBeliefCol)  # compute new beliefs
-                .drop('aggMess')  # drop messages
-                .drop('belief')  # drop old beliefs
-                .withColumnRenamed('newBelief', 'belief')
+                (v["color"] == color) & (aggregates["aggMess"].isNotNull()),
+                logistic(aggregates["aggMess"] + v["a"]),
+            ).otherwise(
+                v["belief"]
+            )  # keep old beliefs for other colors
+            newVertices = (
+                v.join(aggregates, on=(v["id"] == aggregates["id"]), how="left_outer")
+                .drop(aggregates["id"])  # drop duplicate ID column (from outer join)
+                .withColumn("newBelief", newBeliefCol)  # compute new beliefs
+                .drop("aggMess")  # drop messages
+                .drop("belief")  # drop old beliefs
+                .withColumnRenamed("newBelief", "belief")
             )
             # cache new vertices using workaround for SPARK-1334
             cachedNewVertices = AM.getCachedDataFrame(newVertices)
             gx = GraphFrame(cachedNewVertices, gx.edges)
 
         # Drop the "color" column from vertices
-        return GraphFrame(gx.vertices.drop('color'), gx.edges)
+        return GraphFrame(gx.vertices.drop("color"), gx.edges)
 
     @staticmethod
     def _colorGraph(g: GraphFrame) -> GraphFrame:
@@ -132,7 +138,7 @@ def _colorGraph(g: GraphFrame) -> GraphFrame:
         """
 
         colorUDF = sqlfunctions.udf(lambda i, j: (i + j) % 2, returnType=types.IntegerType())
-        v = g.vertices.withColumn('color', colorUDF(sqlfunctions.col('i'), sqlfunctions.col('j')))
+        v = g.vertices.withColumn("color", colorUDF(sqlfunctions.col("i"), sqlfunctions.col("j")))
         return GraphFrame(v, g.edges)
 
     @staticmethod
@@ -164,12 +170,12 @@ def main() -> None:
     results = BeliefPropagation.runBPwithGraphFrames(g, numIter)
 
     # display beliefs
-    beliefs = results.vertices.select('id', 'belief')
+    beliefs = results.vertices.select("id", "belief")
     print("Done with BP. Final beliefs after {} iterations:".format(numIter))
     beliefs.show()
 
     spark.stop()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
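
Most of this file is mechanical black reformatting, but the scheduling idea the code relies on is worth a gloss: _colorGraph 2-colors the vertices so BP can update one color class at a time. A quick pure-Python check, illustrative only and not part of the diff, confirms that the (i + j) % 2 formula used by colorUDF is a proper 2-coloring of a grid:

def color(i: int, j: int) -> int:
    # Same formula the colorUDF in _colorGraph applies to each vertex.
    return (i + j) % 2

n = 4
for i in range(n):
    for j in range(n):
        # Grid neighbors (right and down) must always get the opposite color.
        if j + 1 < n:
            assert color(i, j) != color(i, j + 1)
        if i + 1 < n:
            assert color(i, j) != color(i + 1, j)
print("(i + j) % 2 is a proper 2-coloring of the grid")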
80 changes: 45 additions & 35 deletions python/graphframes/examples/graphs.py
@@ -17,11 +17,12 @@
 
 import itertools
 
-from pyspark.sql import functions as sqlfunctions, SparkSession
+from pyspark.sql import SparkSession
+from pyspark.sql import functions as sqlfunctions
 
 from graphframes import GraphFrame
 
-__all__ = ['Graphs']
+__all__ = ["Graphs"]
 
 
 class Graphs:
@@ -37,24 +38,30 @@ def __init__(self, spark: SparkSession) -> None:
     def friends(self) -> GraphFrame:
         """A GraphFrame of friends in a (fake) social network."""
         # Vertex DataFrame
-        v = self._spark.createDataFrame([
-            ("a", "Alice", 34),
-            ("b", "Bob", 36),
-            ("c", "Charlie", 30),
-            ("d", "David", 29),
-            ("e", "Esther", 32),
-            ("f", "Fanny", 36)
-        ], ["id", "name", "age"])
+        v = self._spark.createDataFrame(
+            [
+                ("a", "Alice", 34),
+                ("b", "Bob", 36),
+                ("c", "Charlie", 30),
+                ("d", "David", 29),
+                ("e", "Esther", 32),
+                ("f", "Fanny", 36),
+            ],
+            ["id", "name", "age"],
+        )
         # Edge DataFrame
-        e = self._spark.createDataFrame([
-            ("a", "b", "friend"),
-            ("b", "c", "follow"),
-            ("c", "b", "follow"),
-            ("f", "c", "follow"),
-            ("e", "f", "follow"),
-            ("e", "d", "friend"),
-            ("d", "a", "friend")
-        ], ["src", "dst", "relationship"])
+        e = self._spark.createDataFrame(
+            [
+                ("a", "b", "friend"),
+                ("b", "c", "follow"),
+                ("c", "b", "follow"),
+                ("f", "c", "follow"),
+                ("e", "f", "follow"),
+                ("e", "d", "friend"),
+                ("d", "a", "friend"),
+            ],
+            ["src", "dst", "relationship"],
+        )
         # Create a GraphFrame
         return GraphFrame(v, e)
 
@@ -83,41 +90,44 @@ def gridIsingModel(self, n: int, vStd: float = 1.0, eStd: float = 1.0) -> GraphFrame:
         and "b". Edges are directed, but they should be treated as undirected in any algorithms
         run on this model. Vertex IDs are of the form "i,j". E.g., vertex "1,3" is in the
         second row and fourth column of the grid.
-        """
+        """  # noqa: W605
         # check param n
         if n < 1:
             raise ValueError(
-                "Grid graph must have size >= 1, but was given invalid value n = {}"
-                .format(n))
+                "Grid graph must have size >= 1, but was given invalid value n = {}".format(n)
+            )
 
         # create coodinates grid
         coordinates = self._spark.createDataFrame(
-            itertools.product(range(n), range(n)),
-            schema=('i', 'j'))
+            itertools.product(range(n), range(n)), schema=("i", "j")
+        )
 
         # create SQL expression for converting coordinates (i,j) to a string ID "i,j"
         # avoid Cartesian join due to SPARK-15425: use generator since n should be small
-        toIDudf = sqlfunctions.udf(lambda i, j: '{},{}'.format(i,j))
+        toIDudf = sqlfunctions.udf(lambda i, j: "{},{}".format(i, j))
 
         # create the vertex DataFrame
         # create SQL expression for converting coordinates (i,j) to a string ID "i,j"
-        vIDcol = toIDudf(sqlfunctions.col('i'), sqlfunctions.col('j'))
+        vIDcol = toIDudf(sqlfunctions.col("i"), sqlfunctions.col("j"))
         # add random parameters generated from a normal distribution
         seed = 12345
-        vertices = (coordinates.withColumn('id', vIDcol)
-                    .withColumn('a', sqlfunctions.randn(seed) * vStd))
+        vertices = coordinates.withColumn("id", vIDcol).withColumn(
+            "a", sqlfunctions.randn(seed) * vStd
+        )
 
         # create the edge DataFrame
         # create SQL expression for converting coordinates (i,j+1) and (i+1,j) to string IDs
-        rightIDcol = toIDudf(sqlfunctions.col('i'), sqlfunctions.col('j') + 1)
-        downIDcol = toIDudf(sqlfunctions.col('i') + 1, sqlfunctions.col('j'))
-        horizontalEdges = (coordinates.filter(sqlfunctions.col('j') != n - 1)
-                           .select(vIDcol.alias('src'), rightIDcol.alias('dst')))
-        verticalEdges = (coordinates.filter(sqlfunctions.col('i') != n - 1)
-                         .select(vIDcol.alias('src'), downIDcol.alias('dst')))
+        rightIDcol = toIDudf(sqlfunctions.col("i"), sqlfunctions.col("j") + 1)
+        downIDcol = toIDudf(sqlfunctions.col("i") + 1, sqlfunctions.col("j"))
+        horizontalEdges = coordinates.filter(sqlfunctions.col("j") != n - 1).select(
+            vIDcol.alias("src"), rightIDcol.alias("dst")
+        )
+        verticalEdges = coordinates.filter(sqlfunctions.col("i") != n - 1).select(
+            vIDcol.alias("src"), downIDcol.alias("dst")
+        )
         allEdges = horizontalEdges.unionAll(verticalEdges)
         # add random parameters from a normal distribution
-        edges = allEdges.withColumn('b', sqlfunctions.randn(seed + 1) * eStd)
+        edges = allEdges.withColumn("b", sqlfunctions.randn(seed + 1) * eStd)
 
         # create the GraphFrame
         g = GraphFrame(vertices, edges)
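
For readers verifying the reformatted constructors, a short usage sketch follows; the API calls are taken from the code above, while the app name and the grid size are illustrative assumptions.

from pyspark.sql import SparkSession

from graphframes.examples import Graphs

spark = SparkSession.builder.appName("graphframes-examples").getOrCreate()

g = Graphs(spark).friends()  # small fake social network
g.vertices.show()  # columns: id, name, age
g.edges.show()  # columns: src, dst, relationship

ising = Graphs(spark).gridIsingModel(3)  # 3x3 grid; vertex IDs like "1,2"
ising.vertices.select("id", "a").show()  # vertex parameters "a"
ising.edges.select("src", "dst", "b").show()  # edge parameters "b"

spark.stop()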