Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
features:
- |
Added a ``node_list`` argument to :func:`~rustworkx.adjacency_matrix`,
:func:`~rustworkx.graph_adjacency_matrix`, and
:func:`~rustworkx.digraph_adjacency_matrix`. The argument controls the
output matrix row and column order and can be used to build a matrix for
a subset of graph nodes.
5 changes: 4 additions & 1 deletion rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def unweighted_average_shortest_path_length(graph, parallel_threshold=300, disco


@_rustworkx_dispatch
def adjacency_matrix(graph, weight_fn=None, default_weight=1.0, null_value=0.0):
def adjacency_matrix(graph, weight_fn=None, default_weight=1.0, null_value=0.0, node_list=None):
"""Return the adjacency matrix for a graph object

In the case where there are multiple edges between nodes the value in the
Expand Down Expand Up @@ -250,6 +250,9 @@ def adjacency_matrix(graph, weight_fn=None, default_weight=1.0, null_value=0.0):
value. This is the default value in the output matrix and it is used
to indicate the absence of an edge between 2 nodes. By default this is
``0.0``.
:param list node_list: Optional list of node indices used to determine the
row and column order of the output matrix. If fewer than all graph nodes
are provided, only edges between listed nodes are included.

:return: The adjacency matrix for the input dag as a numpy array
:rtype: numpy.ndarray
Expand Down
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def adjacency_matrix(
weight_fn: Callable[[_T], float] | None = ...,
default_weight: float = ...,
null_value: float = ...,
node_list: Sequence[int] | None = ...,
) -> npt.NDArray[np.float64]: ...
def all_simple_paths(
graph: PyGraph | PyDiGraph,
Expand Down
2 changes: 2 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def digraph_adjacency_matrix(
default_weight: float = ...,
null_value: float = ...,
parallel_edge: str = ...,
node_list: Sequence[int] | None = ...,
) -> npt.NDArray[np.float64]: ...
def graph_adjacency_matrix(
graph: PyGraph[_S, _T],
Expand All @@ -295,6 +296,7 @@ def graph_adjacency_matrix(
default_weight: float = ...,
null_value: float = ...,
parallel_edge: str = ...,
node_list: Sequence[int] | None = ...,
) -> npt.NDArray[np.float64]: ...
def cycle_basis(graph: PyGraph, /, root: int | None = ...) -> list[list[int]]: ...
def articulation_points(graph: PyGraph, /) -> set[int]: ...
Expand Down
84 changes: 73 additions & 11 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ mod all_pairs_all_simple_paths;
mod johnson_simple_cycles;
mod subgraphs;

use super::{
InvalidNode, NullGraph, digraph, get_edge_iter_with_weights, graph, score, weight_callable,
};
use super::{InvalidNode, NullGraph, digraph, graph, score, weight_callable};

use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
Expand Down Expand Up @@ -646,6 +644,50 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
}
}

fn adjacency_matrix_index_map<Ty: EdgeType>(
graph: &StablePyGraph<Ty>,
node_list: Option<Vec<usize>>,
) -> PyResult<(usize, Option<HashMap<usize, usize>>)> {
if let Some(nodes) = node_list {
let mut node_map = HashMap::with_capacity(nodes.len());
for (matrix_index, node) in nodes.into_iter().enumerate() {
let node_index = NodeIndex::new(node);
if !graph.contains_node(node_index) {
return Err(InvalidNode::new_err(format!(
"The input index {node} in 'node_list' is not a valid node index"
)));
}
if node_map.insert(node, matrix_index).is_some() {
return Err(PyValueError::new_err(
"node_list contains duplicate node indices",
));
}
}
return Ok((node_map.len(), Some(node_map)));
}

if graph.node_bound() != graph.node_count() {
let mut node_map = HashMap::with_capacity(graph.node_count());
for (matrix_index, node) in graph.node_indices().enumerate() {
node_map.insert(node.index(), matrix_index);
}
return Ok((graph.node_count(), Some(node_map)));
}

Ok((graph.node_count(), None))
}

fn adjacency_matrix_edge_indices(
source: usize,
target: usize,
node_map: &Option<HashMap<usize, usize>>,
) -> Option<(usize, usize)> {
match node_map {
Some(map) => Some((*map.get(&source)?, *map.get(&target)?)),
None => Some((source, target)),
}
}

/// Return the adjacency matrix for a PyDiGraph object
///
/// In the case where there are multiple edges between nodes the value in the
Expand Down Expand Up @@ -676,13 +718,16 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
/// :param String parallel_edge: Optional argument that determines how the function handles parallel edges.
/// ``"min"`` causes the value in the output matrix to be the minimum of the edges' weights, and similar behavior can be expected for ``"max"`` and ``"avg"``.
/// The function defaults to ``"sum"`` behavior, where the value in the output matrix is the sum of all parallel edge weights.
/// :param list node_list: Optional list of node indices used to determine the
/// row and column order of the output matrix. If fewer than all graph nodes
/// are provided, only edges between listed nodes are included.
///
/// :return: The adjacency matrix for the input directed graph as a numpy array
/// :rtype: numpy.ndarray
#[pyfunction]
#[pyo3(
signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge="sum"),
text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge=\"sum\")"
signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge="sum", node_list=None),
text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge=\"sum\", node_list=None)"
)]
pub fn digraph_adjacency_matrix<'py>(
py: Python<'py>,
Expand All @@ -691,11 +736,18 @@ pub fn digraph_adjacency_matrix<'py>(
default_weight: f64,
null_value: f64,
parallel_edge: &str,
node_list: Option<Vec<usize>>,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
let n = graph.node_count();
let (n, node_map) = adjacency_matrix_index_map(&graph.graph, node_list)?;
let mut matrix = Array2::<f64>::from_elem((n, n), null_value);
let mut parallel_edge_count = HashMap::new();
for (i, j, weight) in get_edge_iter_with_weights(&graph.graph) {
for edge in graph.graph.edge_references() {
let Some((i, j)) =
adjacency_matrix_edge_indices(edge.source().index(), edge.target().index(), &node_map)
else {
continue;
};
let weight = edge.weight().clone();
let edge_weight = weight_callable(py, &weight_fn, &weight, default_weight)?;
if matrix[[i, j]] == null_value || (null_value.is_nan() && matrix[[i, j]].is_nan()) {
matrix[[i, j]] = edge_weight;
Expand Down Expand Up @@ -763,13 +815,16 @@ pub fn digraph_adjacency_matrix<'py>(
/// :param String parallel_edge: Optional argument that determines how the function handles parallel edges.
/// ``"min"`` causes the value in the output matrix to be the minimum of the edges' weights, and similar behavior can be expected for ``"max"`` and ``"avg"``.
/// The function defaults to ``"sum"`` behavior, where the value in the output matrix is the sum of all parallel edge weights.
/// :param list node_list: Optional list of node indices used to determine the
/// row and column order of the output matrix. If fewer than all graph nodes
/// are provided, only edges between listed nodes are included.
///
/// :return: The adjacency matrix for the input graph as a numpy array
/// :rtype: numpy.ndarray
#[pyfunction]
#[pyo3(
signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge="sum"),
text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge=\"sum\")"
signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge="sum", node_list=None),
text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge=\"sum\", node_list=None)"
)]
pub fn graph_adjacency_matrix<'py>(
py: Python<'py>,
Expand All @@ -778,11 +833,18 @@ pub fn graph_adjacency_matrix<'py>(
default_weight: f64,
null_value: f64,
parallel_edge: &str,
node_list: Option<Vec<usize>>,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
let n = graph.node_count();
let (n, node_map) = adjacency_matrix_index_map(&graph.graph, node_list)?;
let mut matrix = Array2::<f64>::from_elem((n, n), null_value);
let mut parallel_edge_count = HashMap::new();
for (i, j, weight) in get_edge_iter_with_weights(&graph.graph) {
for edge in graph.graph.edge_references() {
let Some((i, j)) =
adjacency_matrix_edge_indices(edge.source().index(), edge.target().index(), &node_map)
else {
continue;
};
let weight = edge.weight().clone();
let edge_weight = weight_callable(py, &weight_fn, &weight, default_weight)?;
if matrix[[i, j]] == null_value || (null_value.is_nan() && matrix[[i, j]].is_nan()) {
matrix[[i, j]] = edge_weight;
Expand Down
64 changes: 64 additions & 0 deletions tests/digraph/test_adjacency_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,70 @@ def test_digraph_with_index_holes(self):
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(np.array([[0, 1], [0, 0]]), res))

def test_node_list_with_index_holes(self):
graph = rustworkx.PyDiGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
node_c = graph.add_node("c")
node_d = graph.add_node("d")
graph.add_edge(node_a, node_b, 1.0)
graph.add_edge(node_b, node_c, 2.0)
graph.add_edge(node_c, node_d, 3.0)
graph.add_edge(node_a, node_d, 4.0)
graph.remove_node(node_b)

res = rustworkx.digraph_adjacency_matrix(graph, lambda x: float(x))

self.assertTrue(
np.array_equal(
np.array([[0.0, 0.0, 4.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]]),
res,
)
)

res = rustworkx.digraph_adjacency_matrix(
graph, lambda x: float(x), node_list=[node_d, node_a, node_c]
)

self.assertTrue(
np.array_equal(
np.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [3.0, 0.0, 0.0]]),
res,
)
)

def test_node_list_order_and_subset(self):
graph = rustworkx.PyDiGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
node_c = graph.add_node("c")
graph.add_edge(node_a, node_b, 1.0)
graph.add_edge(node_b, node_c, 2.0)
graph.add_edge(node_a, node_c, 3.0)

res = rustworkx.digraph_adjacency_matrix(graph, lambda x: float(x), node_list=[2, 0, 1])

self.assertTrue(
np.array_equal(
np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 1.0], [2.0, 0.0, 0.0]]),
res,
)
)

res = rustworkx.adjacency_matrix(graph, lambda x: float(x), 1.0, 0.0, node_list=[2, 0])

self.assertTrue(np.array_equal(np.array([[0.0, 0.0], [3.0, 0.0]]), res))

def test_node_list_errors(self):
graph = rustworkx.PyDiGraph()
graph.add_node("a")

with self.assertRaises(rustworkx.InvalidNode):
rustworkx.digraph_adjacency_matrix(graph, node_list=[0, 1])

with self.assertRaises(ValueError):
rustworkx.digraph_adjacency_matrix(graph, node_list=[0, 0])

def test_from_adjacency_matrix(self):
input_array = np.array(
[[0.0, 4.0, 0.0], [4.0, 0.0, 4.0], [0.0, 4.0, 0.0]],
Expand Down
64 changes: 64 additions & 0 deletions tests/graph/test_adjacency_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,70 @@ def test_graph_with_index_holes(self):
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(np.array([[0, 1], [1, 0]]), res))

def test_node_list_with_index_holes(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
node_c = graph.add_node("c")
node_d = graph.add_node("d")
graph.add_edge(node_a, node_b, 1.0)
graph.add_edge(node_b, node_c, 2.0)
graph.add_edge(node_c, node_d, 3.0)
graph.add_edge(node_a, node_d, 4.0)
graph.remove_node(node_b)

res = rustworkx.graph_adjacency_matrix(graph, lambda x: float(x))

self.assertTrue(
np.array_equal(
np.array([[0.0, 0.0, 4.0], [0.0, 0.0, 3.0], [4.0, 3.0, 0.0]]),
res,
)
)

res = rustworkx.graph_adjacency_matrix(
graph, lambda x: float(x), node_list=[node_d, node_a, node_c]
)

self.assertTrue(
np.array_equal(
np.array([[0.0, 4.0, 3.0], [4.0, 0.0, 0.0], [3.0, 0.0, 0.0]]),
res,
)
)

def test_node_list_order_and_subset(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
node_c = graph.add_node("c")
graph.add_edge(node_a, node_b, 1.0)
graph.add_edge(node_b, node_c, 2.0)
graph.add_edge(node_a, node_c, 3.0)

res = rustworkx.graph_adjacency_matrix(graph, lambda x: float(x), node_list=[2, 0, 1])

self.assertTrue(
np.array_equal(
np.array([[0.0, 3.0, 2.0], [3.0, 0.0, 1.0], [2.0, 1.0, 0.0]]),
res,
)
)

res = rustworkx.adjacency_matrix(graph, lambda x: float(x), 1.0, 0.0, node_list=[2, 0])

self.assertTrue(np.array_equal(np.array([[0.0, 3.0], [3.0, 0.0]]), res))

def test_node_list_errors(self):
graph = rustworkx.PyGraph()
graph.add_node("a")

with self.assertRaises(rustworkx.InvalidNode):
rustworkx.graph_adjacency_matrix(graph, node_list=[0, 1])

with self.assertRaises(ValueError):
rustworkx.graph_adjacency_matrix(graph, node_list=[0, 0])

def test_from_adjacency_matrix(self):
input_array = np.array(
[[0.0, 4.0, 0.0], [4.0, 0.0, 4.0], [0.0, 4.0, 0.0]],
Expand Down
Loading