import itertools
import logging
from copy import deepcopy
from typing import Any, Dict, Iterator, Optional, Set
import networkx as nx
import numpy as np
from .freezable import Freezable
from .graph_spec import GraphSpec
from .roi import Roi
logger = logging.getLogger(__name__)
class Node(Freezable):
    """
    A structure representing each node in a Graph.

    All node data is kept in a single ``attrs`` dictionary. Attribute access
    is redirected to it: names without a double underscore are read from and
    written to ``attrs`` (``node.color`` is ``node.attrs["color"]``), while
    names containing ``__`` (name-mangled private fields, dunder lookups)
    are delegated to the base class.

    Args:

        id (``int``):

            A unique identifier for this Node

        location (``np.ndarray``):

            A numpy array containing a nodes location

        Optional attrs (``dict``, str -> ``Any``):

            A dictionary containing a mapping from attribute to value.
            Used to store any extra attributes associated with the
            Node such as color, size, etc. The dictionary is stored as
            given (it is not copied) and the ``id``, ``location`` and
            ``temporary`` entries are written into it.

        Optional temporary (bool):

            A tag to mark a node as temporary. Some operations such
            as `trim` might make new nodes that are just byproducts
            of viewing the data with a limited scope. These nodes
            are only guaranteed to have an id different from those
            in the same Graph, but may have conflicts if you request
            multiple graphs from the same source with different rois.
    """

    def __init__(
        self,
        id: int,
        location: np.ndarray,
        temporary: bool = False,
        attrs: Optional[Dict[str, Any]] = None,
    ):
        # The mangled name contains "__", so this assignment goes through the
        # base class instead of being stored inside the attrs dictionary.
        self.__attrs = attrs if attrs is not None else {}
        self.attrs["id"] = id
        self.location = location
        # purpose is to keep track of nodes that were created during
        # processing and do not have a corresponding node in the original source
        self.attrs["temporary"] = temporary
        self.freeze()

    def __getattr__(self, attr):
        """
        Fallback lookup for names not found the normal way.

        Names without ``__`` are looked up in ``attrs``. A missing name
        raises ``AttributeError`` (not ``KeyError``) so that ``hasattr`` and
        ``getattr(node, name, default)`` behave as expected.
        """
        if "__" in attr:
            # NOTE(review): unless the base class defines ``__getattr__``,
            # this call itself fails with AttributeError - confirm intended.
            return super().__getattr__(attr)
        try:
            return self.attrs[attr]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None

    def __setattr__(self, attr, value):
        """
        Store names without ``__`` in ``attrs``; delegate the rest to the
        base class.
        """
        if "__" not in attr:
            self.attrs[attr] = value
        else:
            super().__setattr__(attr, value)

    @property
    def location(self):
        """The ``location`` entry of ``attrs``."""
        location = self.attrs["location"]
        return location

    @location.setter
    def location(self, new_location):
        # Because ``__setattr__`` is overridden above, ``node.location = x``
        # is written straight into ``attrs`` and does not pass through this
        # setter; it only runs when invoked explicitly via the property.
        assert isinstance(new_location, np.ndarray)
        self.attrs["location"] = new_location

    @property
    def id(self):
        """The ``id`` entry of ``attrs``."""
        return self.attrs["id"]

    @property
    def original_id(self):
        """The node id, or ``None`` if the node is marked temporary."""
        return self.id if not self.temporary else None

    @property
    def temporary(self):
        """The ``temporary`` entry of ``attrs``."""
        return self.attrs["temporary"]

    @property
    def attrs(self):
        """The dictionary holding every attribute of this node."""
        return self.__attrs

    @property
    def all(self):
        """Alias of :attr:`attrs`."""
        return self.attrs

    @classmethod
    def from_attrs(cls, attrs: Dict[str, Any]):
        """
        Build a Node on top of an existing attribute dictionary.

        ``attrs`` must contain ``"id"`` and ``"location"``; ``"temporary"``
        defaults to ``False``. The dictionary itself becomes the node's
        ``attrs``, so changes made through the node are visible in it.
        """
        node_id = attrs["id"]
        location = attrs["location"]
        temporary = attrs.get("temporary", False)
        return cls(id=node_id, location=location, temporary=temporary, attrs=attrs)

    def __str__(self):
        return f"Node({self.temporary}) ({self.id}) at ({self.location})"

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        # Nodes are identified by id only; other attributes are ignored.
        return isinstance(other, Node) and self.id == other.id

    def __hash__(self):
        return hash(self.id)
class Edge(Freezable):
    """
    A structure representing edges in a graph.

    Args:

        u (``int``)

            The id of the 'u' node of this edge

        v (``int``)

            the id of the `v` node of this edge

        Optional attrs (``dict``, str -> ``Any``):

            Extra attributes of this edge. The dictionary is stored as
            given (it is not copied); an empty one is created if omitted.
    """

    def __init__(self, u: int, v: int, attrs: Optional[Dict[str, Any]] = None):
        self.__u = u
        self.__v = v
        self.__attrs = attrs if attrs is not None else {}
        self.freeze()

    @property
    def u(self):
        """Id of the first endpoint."""
        return self.__u

    @property
    def v(self):
        """Id of the second endpoint."""
        return self.__v

    @property
    def attrs(self):
        """The attribute dictionary of this edge."""
        return self.__attrs

    @property
    def all(self):
        """Alias of :attr:`attrs`."""
        return self.__attrs

    @classmethod
    def from_attrs(cls, attrs: Dict[str, Any]):
        """
        Build an Edge from a dictionary holding ``"u"`` and ``"v"`` entries.

        The dictionary itself (including its ``"u"``/``"v"`` entries) becomes
        the edge's ``attrs``.
        """
        u = attrs["u"]
        v = attrs["v"]
        return cls(u, v, attrs=attrs)

    def __iter__(self):
        # Allows unpacking: ``u, v = edge``.
        return iter([self.u, self.v])

    def __str__(self):
        return f"({self.u}, {self.v})"

    def __repr__(self):
        return f"({self.u}, {self.v})"

    def __eq__(self, other):
        # Comparing with a non-Edge (e.g. ``None``) must not raise; returning
        # NotImplemented lets Python fall back to its default comparison.
        if not isinstance(other, Edge):
            return NotImplemented
        return self.u == other.u and self.v == other.v

    def __hash__(self):
        # Consistent with ``__eq__``: direction matters, attributes do not.
        return hash((self.u, self.v))

    def directed_eq(self, other):
        """True if both edges have the same ``u`` and the same ``v``."""
        return self.u == other.u and self.v == other.v

    def undirected_eq(self, other):
        """True if both edges join the same pair of ids, in either order."""
        return {self.u, self.v} == {other.u, other.v}
class Graph(Freezable):
    """A structure containing a list of :class:`Node`, a list of :class:`Edge`,
    and a specification describing the data.

    The data is stored in a networkx graph (``nx.DiGraph`` or ``nx.Graph``);
    :class:`Node` and :class:`Edge` objects are created on demand from it.

    Args:

        nodes (``iterator``, :class:`Node`):

            An iterator containing Vertices.

        edges (``iterator``, :class:`Edge`):

            An iterator containing Edges.

        spec (:class:`GraphSpec`):

            A spec describing the data.
    """

    def __init__(self, nodes: Iterator[Node], edges: Iterator[Edge], spec: GraphSpec):
        self.__spec = spec
        self.__graph = self.create_graph(nodes, edges)

    @property
    def spec(self):
        """The :class:`GraphSpec` describing this graph."""
        return self.__spec

    @spec.setter
    def spec(self, new_spec):
        self.__spec = new_spec

    @property
    def directed(self):
        """
        Whether the graph is directed: the spec's ``directed`` value if it is
        set, otherwise whether the underlying networkx graph is directed.
        """
        return (
            self.spec.directed
            if self.spec.directed is not None
            else self.__graph.is_directed()
        )

    def create_graph(self, nodes: Iterator[Node], edges: Iterator[Edge]):
        """
        Build the underlying networkx graph from nodes and edges.

        A directed graph is created unless the spec explicitly says
        ``directed=False``. Every node location is cast to the spec's dtype.
        """
        if self.__spec.directed is None:
            logger.debug(
                "Trying to create a Graph without specifying directionality. Using default Directed!"
            )
            graph = nx.DiGraph()
        elif self.__spec.directed:
            graph = nx.DiGraph()
        else:
            graph = nx.Graph()

        # Convert and collect in a single pass: ``nodes`` may be a one-shot
        # iterator (e.g. a generator) that cannot be traversed twice.
        node_entries = []
        for node in nodes:
            node.location = node.location.astype(self.spec.dtype)
            node_entries.append((node.id, node.all))

        graph.add_nodes_from(node_entries)
        graph.add_edges_from([(e.u, e.v, e.all) for e in edges])
        return graph

    @property
    def nodes(self):
        """
        Generate a :class:`Node` for every node of the graph.

        Each Node wraps the attribute dictionary stored in the graph, so
        changes made through the Node are reflected in the graph.

        Raises:

            Exception: if a node location's dtype is not a sub-dtype of the
            spec's dtype.
        """
        for node_id, node_attrs in self.__graph.nodes.items():
            # Nodes that entered the graph without an "id" attribute get one
            # from their networkx key.
            if "id" not in node_attrs:
                node_attrs["id"] = node_id
            v = Node.from_attrs(node_attrs)
            if not np.issubdtype(v.location.dtype, self.spec.dtype):
                raise Exception(
                    f"expected location to have dtype {self.spec.dtype} but it had {v.location.dtype}"
                )
            yield v

    def num_vertices(self):
        """Number of nodes in the graph."""
        return self.__graph.number_of_nodes()

    def num_edges(self):
        """Number of edges in the graph."""
        return self.__graph.number_of_edges()

    @property
    def edges(self):
        """Generate an :class:`Edge` for every edge of the graph."""
        for (u, v), attrs in self.__graph.edges.items():
            yield Edge(u, v, attrs)

    def neighbors(self, node):
        """
        Generate the neighbors of ``node``. For a directed graph these are
        its successors followed by its predecessors.
        """
        if self.directed:
            for neighbor in self.__graph.successors(node.id):
                yield Node.from_attrs(self.__graph.nodes[neighbor])
            for neighbor in self.__graph.predecessors(node.id):
                yield Node.from_attrs(self.__graph.nodes[neighbor])
        else:
            for neighbor in self.__graph.neighbors(node.id):
                yield Node.from_attrs(self.__graph.nodes[neighbor])

    def __str__(self):
        parts = ["Vertices:\n"]
        parts.extend(f"{node}\n" for node in self.nodes)
        parts.append("Edges:\n")
        parts.extend(f"{edge}\n" for edge in self.edges)
        return "".join(parts)

    def __repr__(self):
        return str(self)

    def node(self, id: int):
        """
        Get node with a specific id
        """
        attrs = self.__graph.nodes[id]
        return Node.from_attrs(attrs)

    def edge(self, id: tuple[int, int]):
        """
        Get specific edge

        ``id`` is the ``(u, v)`` pair of node ids. Stored edge attributes
        normally carry no ``"u"``/``"v"`` entries, so the endpoints are taken
        from ``id`` unless the attributes provide them.
        """
        attrs = self.__graph.edges[id]
        u, v = id
        return Edge(attrs.get("u", u), attrs.get("v", v), attrs=attrs)

    def contains(self, node_id: int):
        """True if a node with this id is in the graph."""
        return node_id in self.__graph.nodes

    def remove_node(self, node: Node, retain_connectivity=False):
        """
        Remove a node.

        retain_connectivity: preserve removed nodes neighboring edges.
        Given graph: a->b->c, removing `b` without retain_connectivity
        would leave us with two connected components, {'a'} and {'c'}.
        removing 'b' with retain_connectivity flag set to True would
        leave us with the graph: a->c, and only one connected component
        {a, c}, thus preserving the connectivity of 'a' and 'c'
        """
        if retain_connectivity:
            # networkx hands back one-shot iterators; materialize them so the
            # successors can be paired with *every* predecessor and so that
            # adding edges cannot disturb the iteration.
            predecessors = list(self.predecessors(node))
            successors = list(self.successors(node))

            for pred_id in predecessors:
                for succ_id in successors:
                    if pred_id != succ_id:
                        self.add_edge(Edge(pred_id, succ_id))
        self.__graph.remove_node(node.id)

    def add_node(self, node: Node):
        """
        Adds a node to the graph.
        If a node exists with the same id as the node you are adding,
        its attributes will be overwritten.
        """
        node.location = node.location.astype(self.spec.dtype)
        self.__graph.add_node(node.id, **node.all)

    def remove_edge(self, edge: Edge):
        """
        Remove an edge from the graph.
        """
        self.__graph.remove_edge(edge.u, edge.v)

    def add_edge(self, edge: Edge):
        """
        Adds an edge to the graph.
        If an edge exists with the same u and v, its attributes
        will be overwritten.
        """
        self.__graph.add_edge(edge.u, edge.v, **edge.all)

    def copy(self):
        """Return a deep copy of this graph."""
        return deepcopy(self)

    def crop(self, roi: Roi):
        """
        Return a copy of self without the nodes that are not contained in
        `roi`, except for "dangling" nodes. This means that if there are nodes
        A, B s.t. there is an edge (A, B) and A is contained in `roi` but B is
        not, the edge (A, B) is considered contained in the `roi` and thus
        node B will be kept as a "dangling" node.

        Note there is a helper function `trim` that will remove B and replace it with
        a node at the intersection of the edge (A, B) and the bounding box of `roi`.

        Args:

            roi (:class:`Roi`):

                ROI in world units to crop to.
        """
        cropped = self.copy()

        contained_nodes = {v.id for v in cropped.nodes if roi.contains(v.location)}
        # Edges with at least one endpoint inside the roi.
        all_contained_edges = {
            e
            for e in cropped.edges
            if e.u in contained_nodes or e.v in contained_nodes
        }
        contained_edge_nodes = set(itertools.chain.from_iterable(all_contained_edges))
        all_nodes = contained_edge_nodes | contained_nodes

        # Materialize before removing: removal changes the graph being iterated.
        for node in list(cropped.nodes):
            if node.id not in all_nodes:
                cropped.remove_node(node)
        # Drops edges that join two kept nodes which are both outside the roi.
        for edge in list(cropped.edges):
            if edge not in all_contained_edges:
                cropped.remove_edge(edge)

        cropped.spec.roi = roi

        return cropped

    def shift(self, offset):
        """Add ``offset`` to the location of every node, in place."""
        for node in self.nodes:
            # Each Node wraps the attribute dictionary held by the graph, so
            # the updated location is the one stored in the graph.
            node.location += offset

    def new_graph(self):
        """
        Return an empty networkx graph with the same directedness as self.
        """
        # ``directed`` is a property, not a method.
        if self.directed:
            return nx.DiGraph()
        else:
            return nx.Graph()

    def trim(self, roi: Roi):
        """
        Create a copy of self and replace "dangling" nodes with contained nodes.

        A "dangling" node is defined by: Let A, B be nodes s.t. there exists an
        edge (A, B) and A is contained in `roi` but B is not. Edge (A, B) is considered
        contained, and thus B is kept as a "dangling" node.

        Every node of the result is asserted to lie inside `roi`.
        NOTE(review): nodes outside `roi` that are not dangling are not removed
        here, so this appears to expect a graph already cropped to `roi` - confirm.
        """
        trimmed = self.copy()

        contained_nodes = {v.id for v in trimmed.nodes if roi.contains(v.location)}

        all_contained_edges = {
            e
            for e in trimmed.edges
            if e.u in contained_nodes or e.v in contained_nodes
        }

        fully_contained_edges = {
            e
            for e in all_contained_edges
            if e.u in contained_nodes and e.v in contained_nodes
        }

        # Edges with exactly one endpoint inside the roi.
        partially_contained_edges = all_contained_edges - fully_contained_edges

        contained_edge_nodes = set(itertools.chain.from_iterable(all_contained_edges))
        all_nodes = contained_edge_nodes | contained_nodes

        # Ids for the replacement nodes start above the largest id in use
        # among contained and dangling nodes.
        next_node = 0 if len(all_nodes) == 0 else max(all_nodes) + 1

        trimmed._handle_boundaries(
            partially_contained_edges,
            contained_nodes,
            roi,
            node_id=itertools.count(next_node),
        )

        for node in trimmed.nodes:
            assert roi.contains(
                node.location
            ), f"Failed to properly contain node {node.id} at {node.location}"

        return trimmed

    def _handle_boundaries(
        self,
        crossing_edges: Iterator[Edge],
        contained_nodes: Set[int],
        roi: Roi,
        node_id: Iterator[int],
    ):
        """
        Replace the outside endpoint of each crossing edge with a temporary
        node located where the edge meets the boundary of ``roi``.

        Args:

            crossing_edges: edges with one endpoint in ``contained_nodes``
                and one outside.

            contained_nodes: ids of the nodes inside ``roi``.

            roi: the region being trimmed to.

            node_id: source of fresh ids for the new nodes.
        """
        nodes_to_remove = set()
        for e in crossing_edges:
            u, v = self.node(e.u), self.node(e.v)
            u_in = u.id in contained_nodes
            v_in, v_out = (u, v) if u_in else (v, u)
            in_location, out_location = (v_in.location, v_out.location)
            new_location = self._roi_intercept(in_location, out_location, roi)
            # If the intercept coincides with the inside node, no new node is
            # needed; the outside node is still removed below.
            if not all(np.isclose(new_location, in_location)):
                # use deepcopy because modifying this node should not modify original
                new_attrs = deepcopy(v_out.attrs)
                new_attrs["id"] = next(node_id)
                new_attrs["location"] = new_location
                new_attrs["temporary"] = True
                new_v = Node.from_attrs(new_attrs)
                # Keep the direction of the original edge.
                new_e = Edge(
                    u=v_in.id if u_in else new_v.id, v=new_v.id if u_in else v_in.id
                )
                self.add_node(new_v)
                self.add_edge(new_e)
            nodes_to_remove.add(v_out)
        for node in nodes_to_remove:
            self.remove_node(node)

    def _roi_intercept(
        self, inside: np.ndarray, outside: np.ndarray, bb: Roi
    ) -> np.ndarray:
        """
        Given two points, one inside a bounding box and one outside,
        get the intercept between the line and the bounding box.
        """

        offset = outside - inside
        distance = np.linalg.norm(offset)
        assert not np.isclose(distance, 0), "Inside and Outside are the same location"
        direction = offset / distance

        # `offset` can be 0 on some but not all axes leaving a 0 in the denominator.
        # `inside` can be on the bounding box, leaving a 0 in the numerator.
        # `x/0` throws a division warning, `0/0` throws an invalid warning (both are fine here)
        with np.errstate(divide="ignore", invalid="ignore"):
            # Per axis: fraction of the way from `inside` to `outside` at
            # which the line reaches the lower / upper face of the box.
            bb_x = np.asarray(
                [
                    (np.asarray(bb.begin) - inside) / offset,
                    (np.asarray(bb.end) - inside) / offset,
                ],
                dtype=self.spec.dtype,
            )

        with np.errstate(invalid="ignore"):
            # Smallest crossing fraction that lies on the segment itself.
            s = np.min(bb_x[np.logical_and((bb_x >= 0), (bb_x <= 1))])

        new_location = inside + s * distance * direction
        upper = np.array(bb.end, dtype=self.spec.dtype)
        # Clamp into the box, with the upper bound pulled in by one relative
        # epsilon - presumably because the roi end is exclusive; confirm
        # against Roi.contains.
        new_location = np.clip(
            new_location, bb.begin, upper - upper * np.finfo(self.spec.dtype).eps
        )
        return new_location

    def merge(self, other, copy_from_self=False, copy=False):
        """
        Disabled: always raises ``NotImplementedError``.

        The intended behavior was to merge this graph with another one whose
        Roi contains, or is contained in, this graph's Roi, with ``other``
        overwriting nodes and edges of ``self`` (or vice versa when
        ``copy_from_self`` is set), optionally working on a copy.

        Raises:

            NotImplementedError: always.
        """

        # It is unclear how to merge points in all cases. Consider a 10x10 graph,
        # you crop out a 5x5 area, do a shift augment, and attempt to merge.
        # What does that mean? specs have changed. It should be a new key.
        raise NotImplementedError("Merge function should not be used!")

    def to_nx_graph(self):
        """
        returns a pure networkx graph containing data from
        this Graph.
        """
        return deepcopy(self.__graph)

    @classmethod
    def from_nx_graph(cls, graph, spec):
        """
        Create a gunpowder graph from a networkx graph.
        The network graph is expected to have a "location"
        attribute for each node. If it is a subclass of a networkx
        graph with extra functionality, this may not work.

        If ``spec.directed`` is ``None`` it is set from ``graph``. The given
        networkx graph is used directly, not copied.
        """
        if spec.directed is None:
            spec.directed = graph.is_directed()
        g = cls([], [], spec)
        g.__graph = graph
        return g

    def relabel_connected_components(self):
        """
        create a new attribute "component" for each node
        in this Graph
        """
        for i, wcc in enumerate(self.connected_components):
            for node in wcc:
                self.__graph.nodes[node]["component"] = i

    @property
    def connected_components(self):
        """
        The connected components of the graph (weakly connected components
        for a directed graph), as produced by networkx.
        """
        if not self.directed:
            return nx.connected_components(self.__graph)
        else:
            return nx.weakly_connected_components(self.__graph)

    def in_degree(self):
        """In-degree view of the underlying networkx graph."""
        return self.__graph.in_degree()

    def successors(self, node):
        """
        Ids of the successors of ``node``; for an undirected graph, the ids
        of all its neighbors.
        """
        if self.directed:
            return self.__graph.successors(node.id)
        else:
            return self.__graph.neighbors(node.id)

    def predecessors(self, node):
        """
        Ids of the predecessors of ``node``; for an undirected graph, the ids
        of all its neighbors.
        """
        if self.directed:
            return self.__graph.predecessors(node.id)
        else:
            return self.__graph.neighbors(node.id)
class GraphKey(Freezable):
    """A key to identify graphs in requests, batches, and across
    nodes.

    Used as key in :class:`BatchRequest` and :class:`Batch` to retrieve specs
    or graphs.

    Creating a key also registers it as an attribute of :class:`GraphKeys`
    under its identifier.

    Args:

        identifier (``string``):

            A unique, human readable identifier for this graph key. Will be
            used in log messages and to look up graphs in requests and batches.
            Should be upper case (like ``CENTER_GRAPH``). The identifier is
            unique: Two graph keys with the same identifier will refer to the
            same graph.
    """

    def __init__(self, identifier):
        self.identifier = identifier
        # Computed once here; ``__hash__`` returns this stored value.
        self.hash = hash(identifier)
        self.freeze()

        logger.debug("Registering graph type %s", self)
        setattr(GraphKeys, self.identifier, self)

    def __eq__(self, other):
        # Anything exposing the same ``identifier`` compares equal.
        if not hasattr(other, "identifier"):
            return False
        return self.identifier == other.identifier

    def __hash__(self):
        return self.hash

    def __repr__(self):
        return self.identifier
class GraphKeys:
    """Convenience access to all created :class:`GraphKey`s. A key generated
    with::

        centers = GraphKey('CENTER_GRAPH')

    can be retrieved as::

        GraphKeys.CENTER_GRAPH
    """