diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 71c1f01..17321e1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,7 @@ jobs: conda install -c conda-forge python-graphblas scipy pandas \ pytest-cov pytest-randomly black flake8-comprehensions flake8-bugbear # matplotlib lxml pygraphviz pydot sympy # Extra networkx deps we don't need yet - pip install git+https://github.com/mriduls/networkx.git@nx-sparse --no-deps + pip install git+https://github.com/jim22k/networkx.git@nx-sparse --no-deps pip install -e . --no-deps - name: Style checks run: | diff --git a/graphblas_algorithms/classes/_utils.py b/graphblas_algorithms/classes/_utils.py index 4c60998..b1f9472 100644 --- a/graphblas_algorithms/classes/_utils.py +++ b/graphblas_algorithms/classes/_utils.py @@ -2,6 +2,7 @@ import numpy as np from graphblas import Matrix, Vector, binary from graphblas.core.matrix import TransposedMatrix +from graphblas.core.utils import ensure_type ################ # Classmethods # @@ -19,7 +20,8 @@ def from_networkx(cls, G, weight=None, dtype=None): def from_graphblas(cls, A, *, key_to_id=None): - # Does not copy! + # Does not copy if A is a Matrix! + A = ensure_type(A, Matrix) if A.nrows != A.ncols: raise ValueError(f"Adjacency matrix must be square; got {A.nrows} x {A.ncols}") rv = cls() diff --git a/graphblas_algorithms/classes/digraph.py b/graphblas_algorithms/classes/digraph.py index 69b85be..7b6890d 100644 --- a/graphblas_algorithms/classes/digraph.py +++ b/graphblas_algorithms/classes/digraph.py @@ -414,8 +414,10 @@ def to_directed_graph(G, weight=None, dtype=None): # We should do some sanity checks here to ensure we're returning a valid directed graph if isinstance(G, DiGraph): return G - if isinstance(G, Matrix): + try: return DiGraph.from_graphblas(G) + except TypeError: + pass try: import networkx as nx @@ -431,9 +433,11 @@ def to_directed_graph(G, weight=None, dtype=None): def to_graph(G, weight=None, dtype=None): if isinstance(G, (DiGraph, ga.Graph)): return G - if isinstance(G, Matrix): + try: # Should we check if it can be undirected? return DiGraph.from_graphblas(G) + except TypeError: + pass try: import networkx as nx diff --git a/graphblas_algorithms/classes/graph.py b/graphblas_algorithms/classes/graph.py index b679fa7..bc4ba24 100644 --- a/graphblas_algorithms/classes/graph.py +++ b/graphblas_algorithms/classes/graph.py @@ -152,8 +152,10 @@ def to_undirected_graph(G, weight=None, dtype=None): # We should do some sanity checks here to ensure we're returning a valid undirected graph if isinstance(G, Graph): return G - if isinstance(G, Matrix): + try: return Graph.from_graphblas(G) + except TypeError: + pass try: import networkx as nx diff --git a/graphblas_algorithms/interface.py b/graphblas_algorithms/interface.py index 6328767..4102d3b 100644 --- a/graphblas_algorithms/interface.py +++ b/graphblas_algorithms/interface.py @@ -68,7 +68,7 @@ class Dispatcher: is_triad = nxapi.triads.is_triad @staticmethod - def convert(graph, weight=None): + def convert_from_nx(graph, weight=None, *, name=None): import networkx as nx from .classes import DiGraph, Graph, MultiDiGraph, MultiGraph @@ -83,6 +83,14 @@ def convert(graph, weight=None): return Graph.from_networkx(graph, weight=weight) raise TypeError(f"Unsupported type of graph: {type(graph)}") + @staticmethod + def convert_to_nx(obj, *, name=None): + from .classes import Graph + + if isinstance(obj, Graph): + obj = obj.to_networkx() + return obj + @staticmethod def on_start_tests(items): skip = [ diff --git a/graphblas_algorithms/nxapi/core.py b/graphblas_algorithms/nxapi/core.py index 52e0f9e..7da6e81 100644 --- a/graphblas_algorithms/nxapi/core.py +++ b/graphblas_algorithms/nxapi/core.py @@ -10,6 +10,4 @@ def k_truss(G, k): G = to_undirected_graph(G, dtype=bool) result = algorithms.k_truss(G, k) - # TODO: don't convert to networkx graph - # We want to be able to pass networkx tests, so we need to improve our graph objects - return result.to_networkx() + return result