Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/networkx/utils/backends.py: 26%
392 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-20 07:00 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-20 07:00 +0000
1"""
2Code to support various backends in a plugin dispatch architecture.
4Create a Dispatcher
5-------------------
To be a valid backend, a package must register an entry_point
of ``networkx.backends`` with a key pointing to the handler.
10For example::
12 entry_points={'networkx.backends': 'sparse = networkx_backend_sparse'}
14The backend must create a Graph-like object which contains an attribute
15``__networkx_backend__`` with a value of the entry point name.
17Continuing the example above::
19 class WrappedSparse:
20 __networkx_backend__ = "sparse"
21 ...
23When a dispatchable NetworkX algorithm encounters a Graph-like object
24with a ``__networkx_backend__`` attribute, it will look for the associated
25dispatch object in the entry_points, load it, and dispatch the work to it.
28Testing
29-------
30To assist in validating the backend algorithm implementations, if an
31environment variable ``NETWORKX_TEST_BACKEND`` is set to a registered
32backend key, the dispatch machinery will automatically convert regular
33networkx Graphs and DiGraphs to the backend equivalent by calling
34``<backend dispatcher>.convert_from_nx(G, edge_attrs=edge_attrs, name=name)``.
The ``NETWORKX_FALLBACK_TO_NX`` environment variable controls whether tests
use networkx graphs for algorithms not implemented by the backend; falling
back is enabled unless the variable is set to a value other than ``"true"``.
38The arguments to ``convert_from_nx`` are:
40- ``G`` : networkx Graph
41- ``edge_attrs`` : dict, optional
42 Dict that maps edge attributes to default values if missing in ``G``.
43 If None, then no edge attributes will be converted and default may be 1.
- ``node_attrs`` : dict, optional
    Dict that maps node attributes to default values if missing in ``G``.
    If None, then no node attributes will be converted.
47- ``preserve_edge_attrs`` : bool
48 Whether to preserve all edge attributes.
49- ``preserve_node_attrs`` : bool
50 Whether to preserve all node attributes.
51- ``preserve_graph_attrs`` : bool
52 Whether to preserve all graph attributes.
53- ``preserve_all_attrs`` : bool
54 Whether to preserve all graph, node, and edge attributes.
55- ``name`` : str
56 The name of the algorithm.
57- ``graph_name`` : str
58 The name of the graph argument being converted.
60The converted object is then passed to the backend implementation of
61the algorithm. The result is then passed to
62``<backend dispatcher>.convert_to_nx(result, name=name)`` to convert back
63to a form expected by the NetworkX tests.
65By defining ``convert_from_nx`` and ``convert_to_nx`` methods and setting
66the environment variable, NetworkX will automatically route tests on
67dispatchable algorithms to the backend, allowing the full networkx test
68suite to be run against the backend implementation.
70Example pytest invocation::
72 NETWORKX_TEST_BACKEND=sparse pytest --pyargs networkx
When falling back to networkx is disabled, dispatchable algorithms which
are not implemented by the backend will cause a ``pytest.xfail()``, giving
some indication that not all tests are working, while avoiding causing an
explicit failure.
78If a backend only partially implements some algorithms, it can define
79a ``can_run(name, args, kwargs)`` function that returns True or False
80indicating whether it can run the algorithm with the given arguments.
82A special ``on_start_tests(items)`` function may be defined by the backend.
83It will be called with the list of NetworkX tests discovered. Each item
is a test object that can be marked as xfail if the backend does not support
the test using ``item.add_marker(pytest.mark.xfail(reason=...))``.
86"""
87import inspect
88import os
89import sys
90import warnings
91from functools import partial
92from importlib.metadata import entry_points
94from ..exception import NetworkXNotImplemented
# Only the ``_dispatch`` decorator is exported by ``from ... import *``.
__all__ = ["_dispatch"]
99def _get_backends(group, *, load_and_call=False):
100 if sys.version_info < (3, 10):
101 eps = entry_points()
102 if group not in eps:
103 return {}
104 items = eps[group]
105 else:
106 items = entry_points(group=group)
107 rv = {}
108 for ep in items:
109 if ep.name in rv:
110 warnings.warn(
111 f"networkx backend defined more than once: {ep.name}",
112 RuntimeWarning,
113 stacklevel=2,
114 )
115 elif load_and_call:
116 try:
117 rv[ep.name] = ep.load()()
118 except Exception as exc:
119 warnings.warn(
120 f"Error encountered when loading info for backend {ep.name}: {exc}",
121 RuntimeWarning,
122 stacklevel=2,
123 )
124 else:
125 rv[ep.name] = ep
126 # nx-loopback backend is only available when testing (added in conftest.py)
127 rv.pop("nx-loopback", None)
128 return rv
# "plugin" is being renamed to "backend"; backends get a release cycle to
# update. The legacy "plugin" groups are read first so that entries from the
# newer "backend" groups overwrite them.
backends, backend_info = (
    _get_backends("networkx.plugins"),
    _get_backends("networkx.plugin_info", load_and_call=True),
)
backends |= _get_backends("networkx.backends")
backend_info |= _get_backends("networkx.backend_info", load_and_call=True)

# Backend modules are imported lazily and cached here by ``_load_backend``.
_loaded_backends = {}  # type: ignore[var-annotated]
def _load_backend(backend_name):
    """Return the loaded backend object for ``backend_name``, caching it.

    The entry point in ``backends`` is loaded only on the first request. A
    ``KeyError`` propagates if ``backend_name`` is not registered; nothing is
    cached if loading fails.
    """
    try:
        return _loaded_backends[backend_name]
    except KeyError:
        pass
    loaded = backends[backend_name].load()
    _loaded_backends[backend_name] = loaded
    return loaded
149_registered_algorithms = {}
class _dispatch:
    """Dispatches to a backend algorithm based on input graph types.

    Parameters
    ----------
    func : function
        The function to wrap. When ``_dispatch`` is used with arguments, as in
        ``@_dispatch(...)``, ``func`` is supplied later by the decorator call.

    name : str, optional
        The name of the algorithm to use for dispatching. If not provided,
        the name of ``func`` will be used. ``name`` is useful to avoid name
        conflicts, as all dispatched algorithms live in a single namespace.

    graphs : str or dict or None, default "G"
        If a string, the parameter name of the graph, which must be the first
        argument of the wrapped function. If more than one graph is required
        for the algorithm (or if the graph is not the first argument), provide
        a dict of parameter name to argument position for each graph argument.
        For example, ``@_dispatch(graphs={"G": 0, "auxiliary?": 4})``
        indicates the 0th parameter ``G`` of the function is a required graph,
        and the 4th parameter ``auxiliary`` is an optional graph.
        To indicate an argument is a list of graphs, do e.g. ``"[graphs]"``.
        Use ``graphs=None`` if *no* arguments are NetworkX graphs such as for
        graph generators, readers, and conversion functions.

    edge_attrs : str or dict, optional
        ``edge_attrs`` holds information about edge attribute arguments
        and default values for those edge attributes.
        If a string, ``edge_attrs`` holds the function argument name that
        indicates a single edge attribute to include in the converted graph.
        The default value for this attribute is 1. To indicate that an argument
        is a list of attributes (all with default value 1), use e.g. ``"[attrs]"``.
        If a dict, ``edge_attrs`` holds a dict keyed by argument names, with
        values that are either the default value or, if a string, the argument
        name that indicates the default value.

    node_attrs : str or dict, optional
        Like ``edge_attrs``, but for node attributes.

    preserve_edge_attrs : bool or str or dict, optional
        For bool, whether to preserve all edge attributes.
        For str, the parameter name that may indicate (with ``True`` or a
        callable argument) whether all edge attributes should be preserved
        when converting.
        For dict of ``{graph_name: {attr: default}}``, indicate pre-determined
        edge attributes (and defaults) to preserve for input graphs.

    preserve_node_attrs : bool or str or dict, optional
        Like ``preserve_edge_attrs``, but for node attributes.

    preserve_graph_attrs : bool or set
        For bool, whether to preserve all graph attributes.
        For set, which input graph arguments to preserve graph attributes.

    preserve_all_attrs : bool
        Whether to preserve all edge, node and graph attributes.
        This overrides all the other preserve_*_attrs.

    Per-graph settings (dict-valued ``preserve_*_attrs`` or a set-valued
    ``preserve_graph_attrs``) apply to list-of-graphs arguments as well.
    """

    # Allow any of the following decorator forms:
    #  - @_dispatch
    #  - @_dispatch()
    #  - @_dispatch(name="override_name")
    #  - @_dispatch(graphs="graph")
    #  - @_dispatch(edge_attrs="weight")
    #  - @_dispatch(graphs={"G": 0, "H": 1}, edge_attrs={"weight": "default"})

    # These class attributes are currently used to allow backends to run networkx tests.
    # For example: `PYTHONPATH=. pytest --backend graphblas --fallback-to-nx`
    # Future work: add configuration to control these
    _is_testing = False
    # Falling back is enabled unless the variable is set to something other
    # than "true" (case-insensitive).
    _fallback_to_nx = (
        os.environ.get("NETWORKX_FALLBACK_TO_NX", "true").strip().lower() == "true"
    )
    # Comma-separated backend names from the environment, blanks dropped.
    _automatic_backends = [
        x.strip()
        for x in os.environ.get("NETWORKX_AUTOMATIC_BACKENDS", "").split(",")
        if x.strip()
    ]

    def __new__(
        cls,
        func=None,
        *,
        name=None,
        graphs="G",
        edge_attrs=None,
        node_attrs=None,
        preserve_edge_attrs=False,
        preserve_node_attrs=False,
        preserve_graph_attrs=False,
        preserve_all_attrs=False,
    ):
        if func is None:
            # Used as ``@_dispatch(...)``: return a decorator awaiting ``func``.
            return partial(
                _dispatch,
                name=name,
                graphs=graphs,
                edge_attrs=edge_attrs,
                node_attrs=node_attrs,
                preserve_edge_attrs=preserve_edge_attrs,
                preserve_node_attrs=preserve_node_attrs,
                preserve_graph_attrs=preserve_graph_attrs,
                preserve_all_attrs=preserve_all_attrs,
            )
        if isinstance(func, str):
            raise TypeError("'name' and 'graphs' must be passed by keyword") from None
        # If name not provided, use the name of the function
        if name is None:
            name = func.__name__

        self = object.__new__(cls)

        # standard function-wrapping stuff
        # __annotations__ not used
        self.__name__ = func.__name__
        # __doc__ is not copied here; it is computed lazily by the property below.
        self.__defaults__ = func.__defaults__
        # We "magically" add `backend=` keyword argument to allow backend to be specified
        if func.__kwdefaults__:
            self.__kwdefaults__ = {**func.__kwdefaults__, "backend": None}
        else:
            self.__kwdefaults__ = {"backend": None}
        self.__module__ = func.__module__
        self.__qualname__ = func.__qualname__
        self.__dict__.update(func.__dict__)
        self.__wrapped__ = func

        # Supplement docstring with backend info; compute and cache when needed
        self._orig_doc = func.__doc__
        self._cached_doc = None

        self.orig_func = func
        self.name = name
        self.edge_attrs = edge_attrs
        self.node_attrs = node_attrs
        # ``preserve_all_attrs`` overrides the individual ``preserve_*_attrs``
        # settings, as documented; falsy settings are normalized to ``False``.
        self.preserve_edge_attrs = preserve_all_attrs or preserve_edge_attrs or False
        self.preserve_node_attrs = preserve_all_attrs or preserve_node_attrs or False
        self.preserve_graph_attrs = preserve_all_attrs or preserve_graph_attrs or False

        if edge_attrs is not None and not isinstance(edge_attrs, (str, dict)):
            raise TypeError(
                f"Bad type for edge_attrs: {type(edge_attrs)}. Expected str or dict."
            ) from None
        if node_attrs is not None and not isinstance(node_attrs, (str, dict)):
            raise TypeError(
                f"Bad type for node_attrs: {type(node_attrs)}. Expected str or dict."
            ) from None
        if not isinstance(self.preserve_edge_attrs, (bool, str, dict)):
            raise TypeError(
                f"Bad type for preserve_edge_attrs: {type(self.preserve_edge_attrs)}."
                " Expected bool, str, or dict."
            ) from None
        if not isinstance(self.preserve_node_attrs, (bool, str, dict)):
            raise TypeError(
                f"Bad type for preserve_node_attrs: {type(self.preserve_node_attrs)}."
                " Expected bool, str, or dict."
            ) from None
        if not isinstance(self.preserve_graph_attrs, (bool, set)):
            raise TypeError(
                f"Bad type for preserve_graph_attrs: {type(self.preserve_graph_attrs)}."
                " Expected bool or set."
            ) from None

        if isinstance(graphs, str):
            graphs = {graphs: 0}
        elif graphs is None:
            pass
        elif not isinstance(graphs, dict):
            raise TypeError(
                f"Bad type for graphs: {type(graphs)}. Expected str or dict."
            ) from None
        elif len(graphs) == 0:
            raise KeyError("'graphs' must contain at least one variable name") from None

        # Strip the markers from the graph parameter names: a trailing "?"
        # marks an optional graph and "[name]" marks a list of graphs.
        self.optional_graphs = set()
        self.list_graphs = set()
        self.graphs = {}
        for gname, pos in (graphs or {}).items():
            if gname[-1] == "?":
                gname = gname[:-1]
                self.optional_graphs.add(gname)
            elif gname[-1] == "]":
                gname = gname[1:-1]
                self.list_graphs.add(gname)
            self.graphs[gname] = pos

        # Compute and cache the signature on-demand
        self._sig = None

        # Which backends implement this function?
        self.backends = {
            backend
            for backend, info in backend_info.items()
            if "functions" in info and name in info["functions"]
        }

        if name in _registered_algorithms:
            raise KeyError(
                f"Algorithm already exists in dispatch registry: {name}"
            ) from None
        _registered_algorithms[name] = self
        return self

    @property
    def __doc__(self):
        if (rv := self._cached_doc) is not None:
            return rv
        rv = self._cached_doc = self._make_doc()
        return rv

    @__doc__.setter
    def __doc__(self, val):
        self._orig_doc = val
        self._cached_doc = None

    @property
    def __signature__(self):
        if self._sig is None:
            sig = inspect.signature(self.orig_func)
            # `backend` is now a reserved argument used by dispatching.
            # assert "backend" not in sig.parameters
            if not any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
            ):
                sig = sig.replace(
                    parameters=[
                        *sig.parameters.values(),
                        inspect.Parameter(
                            "backend", inspect.Parameter.KEYWORD_ONLY, default=None
                        ),
                        inspect.Parameter(
                            "backend_kwargs", inspect.Parameter.VAR_KEYWORD
                        ),
                    ]
                )
            else:
                # A VAR_KEYWORD parameter is always last, so insert before it.
                *parameters, var_keyword = sig.parameters.values()
                sig = sig.replace(
                    parameters=[
                        *parameters,
                        inspect.Parameter(
                            "backend", inspect.Parameter.KEYWORD_ONLY, default=None
                        ),
                        var_keyword,
                    ]
                )
            self._sig = sig
        return self._sig

    def __call__(self, /, *args, backend=None, **kwargs):
        if not backends:
            # Fast path if no backends are installed
            return self.orig_func(*args, **kwargs)

        # Use `backend_name` in this function instead of `backend`
        backend_name = backend
        if backend_name is not None and backend_name not in backends:
            raise ImportError(f"Unable to load backend: {backend_name}")

        graphs_resolved = {}
        for gname, pos in self.graphs.items():
            if pos < len(args):
                if gname in kwargs:
                    raise TypeError(f"{self.name}() got multiple values for {gname!r}")
                val = args[pos]
            elif gname in kwargs:
                val = kwargs[gname]
            elif gname not in self.optional_graphs:
                raise TypeError(
                    f"{self.name}() missing required graph argument: {gname}"
                )
            else:
                continue
            if val is None:
                if gname not in self.optional_graphs:
                    raise TypeError(
                        f"{self.name}() required graph argument {gname!r} is None; must be a graph"
                    )
            else:
                graphs_resolved[gname] = val

        if self._is_testing and self._automatic_backends and backend_name is None:
            # Special path if we are running networkx tests with a backend.
            return self._convert_and_call_for_tests(
                self._automatic_backends[0],
                args,
                kwargs,
                fallback_to_nx=self._fallback_to_nx,
            )

        # Check if any graph comes from a backend
        if self.list_graphs:
            # Make sure we don't lose values by consuming an iterator
            args = list(args)
            for gname in self.list_graphs & graphs_resolved.keys():
                val = list(graphs_resolved[gname])
                graphs_resolved[gname] = val
                if gname in kwargs:
                    kwargs[gname] = val
                else:
                    args[self.graphs[gname]] = val

            has_backends = any(
                hasattr(g, "__networkx_backend__") or hasattr(g, "__networkx_plugin__")
                if gname not in self.list_graphs
                else any(
                    hasattr(g2, "__networkx_backend__")
                    or hasattr(g2, "__networkx_plugin__")
                    for g2 in g
                )
                for gname, g in graphs_resolved.items()
            )
            if has_backends:
                graph_backend_names = {
                    self._graph_backend_name(g)
                    for gname, g in graphs_resolved.items()
                    if gname not in self.list_graphs
                }
                for gname in self.list_graphs & graphs_resolved.keys():
                    graph_backend_names.update(
                        self._graph_backend_name(g) for g in graphs_resolved[gname]
                    )
        else:
            has_backends = any(
                hasattr(g, "__networkx_backend__") or hasattr(g, "__networkx_plugin__")
                for g in graphs_resolved.values()
            )
            if has_backends:
                graph_backend_names = {
                    self._graph_backend_name(g) for g in graphs_resolved.values()
                }
        if has_backends:
            # Dispatchable graphs found! Dispatch to backend function.
            # We don't handle calls with different backend graphs yet,
            # but we may be able to convert additional networkx graphs.
            backend_names = graph_backend_names - {"networkx"}
            if len(backend_names) != 1:
                # Future work: convert between backends and run if multiple backends found
                raise TypeError(
                    f"{self.name}() graphs must all be from the same backend, found {backend_names}"
                )
            [graph_backend_name] = backend_names
            if backend_name is not None and backend_name != graph_backend_name:
                # Future work: convert between backends to `backend_name` backend
                raise TypeError(
                    f"{self.name}() is unable to convert graph from backend {graph_backend_name!r} "
                    f"to the specified backend {backend_name!r}."
                )
            if graph_backend_name not in backends:
                raise ImportError(f"Unable to load backend: {graph_backend_name}")
            if (
                "networkx" in graph_backend_names
                and graph_backend_name not in self._automatic_backends
            ):
                # Not configured to convert networkx graphs to this backend
                raise TypeError(
                    f"Unable to convert inputs and run {self.name}. "
                    f"{self.name}() has networkx and {graph_backend_name} graphs, but NetworkX is not "
                    f"configured to automatically convert graphs from networkx to {graph_backend_name}."
                )
            backend = _load_backend(graph_backend_name)
            if hasattr(backend, self.name):
                if "networkx" in graph_backend_names:
                    # We need to convert networkx graphs to backend graphs
                    return self._convert_and_call(
                        graph_backend_name,
                        args,
                        kwargs,
                        fallback_to_nx=self._fallback_to_nx,
                    )
                # All graphs are backend graphs--no need to convert!
                return getattr(backend, self.name)(*args, **kwargs)
            # Future work: try to convert and run with other backends in self._automatic_backends
            raise NetworkXNotImplemented(
                f"'{self.name}' not implemented by {graph_backend_name}"
            )

        # If backend was explicitly given by the user, so we need to use it no matter what
        if backend_name is not None:
            return self._convert_and_call(
                backend_name, args, kwargs, fallback_to_nx=False
            )

        # Only networkx graphs; try to convert and run with a backend with automatic
        # conversion, but don't do this by default for graph generators or loaders.
        if self.graphs:
            for backend_name in self._automatic_backends:
                if self._can_backend_run(backend_name, *args, **kwargs):
                    return self._convert_and_call(
                        backend_name,
                        args,
                        kwargs,
                        fallback_to_nx=self._fallback_to_nx,
                    )
        # Default: run with networkx on networkx inputs
        return self.orig_func(*args, **kwargs)

    def _can_backend_run(self, backend_name, /, *args, **kwargs):
        """Can the specified backend run this algorithms with these arguments?"""
        backend = _load_backend(backend_name)
        return hasattr(backend, self.name) and (
            not hasattr(backend, "can_run") or backend.can_run(self.name, args, kwargs)
        )

    @staticmethod
    def _graph_backend_name(graph):
        """Return the backend name a graph advertises, defaulting to "networkx".

        ``__networkx_backend__`` takes precedence over the legacy
        ``__networkx_plugin__`` attribute.
        """
        return getattr(
            graph,
            "__networkx_backend__",
            getattr(graph, "__networkx_plugin__", "networkx"),
        )

    @staticmethod
    def _graph_conversion_options(
        gname,
        edge_attrs,
        node_attrs,
        preserve_edge_attrs,
        preserve_node_attrs,
        preserve_graph_attrs,
    ):
        """Resolve the ``convert_from_nx`` keyword options for graph ``gname``.

        A dict-valued ``preserve_edge_attrs``/``preserve_node_attrs`` maps graph
        argument names to ``{attr: default}`` to preserve (falling back to the
        shared ``edge_attrs``/``node_attrs``), and a set-valued
        ``preserve_graph_attrs`` names the graph arguments whose graph
        attributes are kept. Bool values apply to every graph unchanged.
        """
        if isinstance(preserve_edge_attrs, dict):
            preserve_edges = False
            edges = preserve_edge_attrs.get(gname, edge_attrs)
        else:
            preserve_edges = preserve_edge_attrs
            edges = edge_attrs
        if isinstance(preserve_node_attrs, dict):
            preserve_nodes = False
            nodes = preserve_node_attrs.get(gname, node_attrs)
        else:
            preserve_nodes = preserve_node_attrs
            nodes = node_attrs
        if isinstance(preserve_graph_attrs, set):
            preserve_graph = gname in preserve_graph_attrs
        else:
            preserve_graph = preserve_graph_attrs
        return {
            "edge_attrs": edges,
            "node_attrs": nodes,
            "preserve_edge_attrs": preserve_edges,
            "preserve_node_attrs": preserve_nodes,
            "preserve_graph_attrs": preserve_graph,
        }

    def _convert_arguments(self, backend_name, args, kwargs):
        """Convert graph arguments to the specified backend.

        Returns
        -------
        args tuple and kwargs dict
        """
        bound = self.__signature__.bind(*args, **kwargs)
        bound.apply_defaults()
        if not self.graphs:
            bound_kwargs = bound.kwargs
            del bound_kwargs["backend"]
            return bound.args, bound_kwargs
        # Convert graphs into backend graph-like object
        # Include the edge and/or node labels if provided to the algorithm
        preserve_edge_attrs = self.preserve_edge_attrs
        edge_attrs = self.edge_attrs
        if preserve_edge_attrs is False:
            # e.g. `preserve_edge_attrs=False`
            pass
        elif preserve_edge_attrs is True:
            # e.g. `preserve_edge_attrs=True`
            edge_attrs = None
        elif isinstance(preserve_edge_attrs, str):
            if bound.arguments[preserve_edge_attrs] is True or callable(
                bound.arguments[preserve_edge_attrs]
            ):
                # e.g. `preserve_edge_attrs="attr"` and `func(attr=True)`
                # e.g. `preserve_edge_attrs="attr"` and `func(attr=myfunc)`
                preserve_edge_attrs = True
                edge_attrs = None
            elif bound.arguments[preserve_edge_attrs] is False and (
                isinstance(edge_attrs, str)
                and edge_attrs == preserve_edge_attrs
                or isinstance(edge_attrs, dict)
                and preserve_edge_attrs in edge_attrs
            ):
                # e.g. `preserve_edge_attrs="attr"` and `func(attr=False)`
                # Treat `False` argument as meaning "preserve_edge_data=False"
                # and not `False` as the edge attribute to use.
                preserve_edge_attrs = False
                edge_attrs = None
            else:
                # e.g. `preserve_edge_attrs="attr"` and `func(attr="weight")`
                preserve_edge_attrs = False
        # Else: e.g. `preserve_edge_attrs={"G": {"weight": 1}}`

        if edge_attrs is None:
            # May have been set to None above b/c all attributes are preserved
            pass
        elif isinstance(edge_attrs, str):
            if edge_attrs[0] == "[":
                # e.g. `edge_attrs="[edge_attributes]"` (argument of list of attributes)
                # e.g. `func(edge_attributes=["foo", "bar"])`
                edge_attrs = {
                    edge_attr: 1 for edge_attr in bound.arguments[edge_attrs[1:-1]]
                }
            elif callable(bound.arguments[edge_attrs]):
                # e.g. `edge_attrs="weight"` and `func(weight=myfunc)`
                preserve_edge_attrs = True
                edge_attrs = None
            elif bound.arguments[edge_attrs] is not None:
                # e.g. `edge_attrs="weight"` and `func(weight="foo")` (default of 1)
                edge_attrs = {bound.arguments[edge_attrs]: 1}
            elif self.name == "to_numpy_array" and hasattr(
                bound.arguments["dtype"], "names"
            ):
                # Custom handling: attributes may be obtained from `dtype`
                edge_attrs = {
                    edge_attr: 1 for edge_attr in bound.arguments["dtype"].names
                }
            else:
                # e.g. `edge_attrs="weight"` and `func(weight=None)`
                edge_attrs = None
        else:
            # e.g. `edge_attrs={"attr": "default"}` and `func(attr="foo", default=7)`
            # e.g. `edge_attrs={"attr": 0}` and `func(attr="foo")`
            edge_attrs = {
                edge_attr: bound.arguments.get(val, 1) if isinstance(val, str) else val
                for key, val in edge_attrs.items()
                if (edge_attr := bound.arguments[key]) is not None
            }

        preserve_node_attrs = self.preserve_node_attrs
        node_attrs = self.node_attrs
        if preserve_node_attrs is False:
            # e.g. `preserve_node_attrs=False`
            pass
        elif preserve_node_attrs is True:
            # e.g. `preserve_node_attrs=True`
            node_attrs = None
        elif isinstance(preserve_node_attrs, str):
            if bound.arguments[preserve_node_attrs] is True or callable(
                bound.arguments[preserve_node_attrs]
            ):
                # e.g. `preserve_node_attrs="attr"` and `func(attr=True)`
                # e.g. `preserve_node_attrs="attr"` and `func(attr=myfunc)`
                preserve_node_attrs = True
                node_attrs = None
            elif bound.arguments[preserve_node_attrs] is False and (
                isinstance(node_attrs, str)
                and node_attrs == preserve_node_attrs
                or isinstance(node_attrs, dict)
                and preserve_node_attrs in node_attrs
            ):
                # e.g. `preserve_node_attrs="attr"` and `func(attr=False)`
                # Treat `False` argument as meaning "preserve_node_data=False"
                # and not `False` as the node attribute to use. Is this used?
                preserve_node_attrs = False
                node_attrs = None
            else:
                # e.g. `preserve_node_attrs="attr"` and `func(attr="weight")`
                preserve_node_attrs = False
        # Else: e.g. `preserve_node_attrs={"G": {"pos": None}}`

        if node_attrs is None:
            # May have been set to None above b/c all attributes are preserved
            pass
        elif isinstance(node_attrs, str):
            if node_attrs[0] == "[":
                # e.g. `node_attrs="[node_attributes]"` (argument of list of attributes)
                # e.g. `func(node_attributes=["foo", "bar"])`
                node_attrs = {
                    node_attr: None for node_attr in bound.arguments[node_attrs[1:-1]]
                }
            elif callable(bound.arguments[node_attrs]):
                # e.g. `node_attrs="weight"` and `func(weight=myfunc)`
                preserve_node_attrs = True
                node_attrs = None
            elif bound.arguments[node_attrs] is not None:
                # e.g. `node_attrs="weight"` and `func(weight="foo")`
                node_attrs = {bound.arguments[node_attrs]: None}
            else:
                # e.g. `node_attrs="weight"` and `func(weight=None)`
                node_attrs = None
        else:
            # e.g. `node_attrs={"attr": "default"}` and `func(attr="foo", default=7)`
            # e.g. `node_attrs={"attr": 0}` and `func(attr="foo")`
            node_attrs = {
                node_attr: bound.arguments.get(val) if isinstance(val, str) else val
                for key, val in node_attrs.items()
                if (node_attr := bound.arguments[key]) is not None
            }

        preserve_graph_attrs = self.preserve_graph_attrs

        # It should be safe to assume that we either have networkx graphs or backend graphs.
        # Future work: allow conversions between backends.
        backend = _load_backend(backend_name)
        for gname in self.graphs:
            # Per-graph (dict/set) preserve settings are resolved identically for
            # single-graph and list-of-graphs arguments.
            options = self._graph_conversion_options(
                gname,
                edge_attrs,
                node_attrs,
                preserve_edge_attrs,
                preserve_node_attrs,
                preserve_graph_attrs,
            )
            if gname in self.list_graphs:
                bound.arguments[gname] = [
                    backend.convert_from_nx(
                        g, **options, name=self.name, graph_name=gname
                    )
                    if self._graph_backend_name(g) == "networkx"
                    else g
                    for g in bound.arguments[gname]
                ]
            else:
                graph = bound.arguments[gname]
                if graph is None:
                    if gname in self.optional_graphs:
                        continue
                    raise TypeError(
                        f"Missing required graph argument `{gname}` in {self.name} function"
                    )
                if self._graph_backend_name(graph) == "networkx":
                    bound.arguments[gname] = backend.convert_from_nx(
                        graph, **options, name=self.name, graph_name=gname
                    )
        bound_kwargs = bound.kwargs
        del bound_kwargs["backend"]
        return bound.args, bound_kwargs

    def _convert_and_call(self, backend_name, args, kwargs, *, fallback_to_nx=False):
        """Call this dispatchable function with a backend, converting graphs if necessary."""
        backend = _load_backend(backend_name)
        if not self._can_backend_run(backend_name, *args, **kwargs):
            if fallback_to_nx:
                return self.orig_func(*args, **kwargs)
            msg = f"'{self.name}' not implemented by {backend_name}"
            if hasattr(backend, self.name):
                msg += " with the given arguments"
            raise RuntimeError(msg)

        try:
            converted_args, converted_kwargs = self._convert_arguments(
                backend_name, args, kwargs
            )
            result = getattr(backend, self.name)(*converted_args, **converted_kwargs)
        except (NotImplementedError, NetworkXNotImplemented):
            if fallback_to_nx:
                return self.orig_func(*args, **kwargs)
            raise

        return result

    def _convert_and_call_for_tests(
        self, backend_name, args, kwargs, *, fallback_to_nx=False
    ):
        """Call this dispatchable function with a backend; for use with testing."""
        backend = _load_backend(backend_name)
        if not self._can_backend_run(backend_name, *args, **kwargs):
            if fallback_to_nx or not self.graphs:
                return self.orig_func(*args, **kwargs)

            import pytest

            msg = f"'{self.name}' not implemented by {backend_name}"
            if hasattr(backend, self.name):
                msg += " with the given arguments"
            pytest.xfail(msg)

        try:
            converted_args, converted_kwargs = self._convert_arguments(
                backend_name, args, kwargs
            )
            result = getattr(backend, self.name)(*converted_args, **converted_kwargs)
        except (NotImplementedError, NetworkXNotImplemented) as exc:
            if fallback_to_nx:
                return self.orig_func(*args, **kwargs)
            import pytest

            pytest.xfail(
                exc.args[0] if exc.args else f"{self.name} raised {type(exc).__name__}"
            )

        if self.name in {
            "edmonds_karp_core",
            "barycenter",
            "contracted_nodes",
            "stochastic_graph",
            "relabel_nodes",
        }:
            # Special-case algorithms that mutate input graphs
            bound = self.__signature__.bind(*converted_args, **converted_kwargs)
            bound.apply_defaults()
            bound2 = self.__signature__.bind(*args, **kwargs)
            bound2.apply_defaults()
            if self.name == "edmonds_karp_core":
                R1 = backend.convert_to_nx(bound.arguments["R"])
                R2 = bound2.arguments["R"]
                for k, v in R1.edges.items():
                    R2.edges[k]["flow"] = v["flow"]
            elif self.name == "barycenter" and bound.arguments["attr"] is not None:
                G1 = backend.convert_to_nx(bound.arguments["G"])
                G2 = bound2.arguments["G"]
                attr = bound.arguments["attr"]
                for k, v in G1.nodes.items():
                    G2.nodes[k][attr] = v[attr]
            elif self.name == "contracted_nodes" and not bound.arguments["copy"]:
                # Edges and nodes changed; node "contraction" and edge "weight" attrs
                G1 = backend.convert_to_nx(bound.arguments["G"])
                G2 = bound2.arguments["G"]
                G2.__dict__.update(G1.__dict__)
            elif self.name == "stochastic_graph" and not bound.arguments["copy"]:
                G1 = backend.convert_to_nx(bound.arguments["G"])
                G2 = bound2.arguments["G"]
                for k, v in G1.edges.items():
                    G2.edges[k]["weight"] = v["weight"]
            elif self.name == "relabel_nodes" and not bound.arguments["copy"]:
                G1 = backend.convert_to_nx(bound.arguments["G"])
                G2 = bound2.arguments["G"]
                if G1 is G2:
                    return G2
                G2._node.clear()
                G2._node.update(G1._node)
                G2._adj.clear()
                G2._adj.update(G1._adj)
                if hasattr(G1, "_pred") and hasattr(G2, "_pred"):
                    G2._pred.clear()
                    G2._pred.update(G1._pred)
                if hasattr(G1, "_succ") and hasattr(G2, "_succ"):
                    G2._succ.clear()
                    G2._succ.update(G1._succ)
                return G2

        return backend.convert_to_nx(result, name=self.name)

    def _make_doc(self):
        """Return the wrapped docstring extended with a "Backends" section.

        The section lists every backend in ``self.backends`` with its
        ``short_summary`` and, from ``backend_info``, any per-function
        ``extra_docstring`` and ``extra_parameters``. The original docstring
        is returned unchanged when no backend implements this function.
        """
        if not self.backends:
            return self._orig_doc
        lines = [
            "Backends",
            "--------",
        ]
        for backend in sorted(self.backends):
            info = backend_info[backend]
            if "short_summary" in info:
                lines.append(f"{backend} : {info['short_summary']}")
            else:
                lines.append(backend)
            if "functions" not in info or self.name not in info["functions"]:
                lines.append("")
                continue

            func_info = info["functions"][self.name]
            if "extra_docstring" in func_info:
                lines.extend(
                    f"  {line}" if line else line
                    for line in func_info["extra_docstring"].split("\n")
                )
                add_gap = True
            else:
                add_gap = False
            if "extra_parameters" in func_info:
                if add_gap:
                    lines.append("")
                lines.append("  Extra parameters:")
                extra_parameters = func_info["extra_parameters"]
                for param in sorted(extra_parameters):
                    lines.append(f"    {param}")
                    if desc := extra_parameters[param]:
                        lines.append(f"      {desc}")
                    lines.append("")
            else:
                lines.append("")

        lines.pop()  # Remove last empty line
        to_add = "\n    ".join(lines)
        # The wrapped function may have no docstring (e.g. under ``python -OO``);
        # still report the backends instead of failing on ``None.rstrip()``.
        orig_doc = (self._orig_doc or "").rstrip()
        if not orig_doc:
            return f"    {to_add}"
        return f"{orig_doc}\n\n    {to_add}"

    def __reduce__(self):
        """Allow this object to be serialized with pickle.

        This uses the global registry `_registered_algorithms` to deserialize.
        """
        return _restore_dispatch, (self.name,)
def _restore_dispatch(name):
    """Return the dispatchable registered under ``name``.

    This is the unpickling counterpart of ``_dispatch.__reduce__``; a
    ``KeyError`` propagates when no dispatchable has that name.
    """
    registry = _registered_algorithms
    return registry[name]
if os.environ.get("_NETWORKX_BUILDING_DOCS_"):
    # Sphinx renders plain Python functions better than dispatch objects, so
    # while building docs, decorated functions are returned as-is but carry the
    # dispatcher's backend-augmented ``__doc__``. Signatures then omit e.g.
    # `*, backend=None, **backend_kwargs`, which is acceptable for the docs.
    _orig_dispatch = _dispatch

    def _dispatch(func=None, **kwargs):  # type: ignore[no-redef]
        if func is not None:
            func.__doc__ = _orig_dispatch(func, **kwargs).__doc__
            return func
        # Used as ``@_dispatch(...)``: wait for the function to decorate.
        return partial(_dispatch, **kwargs)