Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.options import get_options
from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype
from tracksdata.utils._signal import iter_node_added_events, iter_node_updated_events

if TYPE_CHECKING:
from tracksdata.nodes._mask import Mask
Expand Down Expand Up @@ -415,15 +416,24 @@ def _invalidate_from_attrs(self, attrs: dict) -> None:
if slices is not None:
self._cache.invalidate(time=time, volume_slicing=slices)

def _on_node_added(self, node_id: int, new_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(new_attrs)
def _on_node_added(
self,
node_id: int | Sequence[int],
new_attrs: dict | Sequence[dict],
) -> None:
for _, attrs in iter_node_added_events(node_id, new_attrs):
self._invalidate_from_attrs(attrs)

def _on_node_removed(self, node_id: int, old_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(old_attrs)

def _on_node_updated(self, node_id: int, old_attrs: dict, new_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(old_attrs)
self._invalidate_from_attrs(new_attrs)
def _on_node_updated(
self,
node_id: int | Sequence[int],
old_attrs: dict | Sequence[dict],
new_attrs: dict | Sequence[dict],
) -> None:
for _, old_attr, new_attr in iter_node_updated_events(node_id, old_attrs, new_attrs):
self._invalidate_from_attrs(old_attr)
self._invalidate_from_attrs(new_attr)
64 changes: 64 additions & 0 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,38 @@ def remove_node(self, node_id: int) -> None:
If the node_id does not exist in the graph.
"""

def bulk_remove_nodes(self, node_ids: Sequence[int]) -> None:
"""
Remove multiple nodes from the graph, along with their incident edges.

Existence is validated up-front so the call either removes every node
in `node_ids` or raises without modifying the graph.

Parameters
----------
node_ids : Sequence[int]
The IDs of the nodes to remove.

Raises
------
ValueError
If any node_id does not exist in the graph.
"""
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()
else:
node_ids = list(node_ids)
if len(node_ids) == 0:
return

existing = set(self.node_ids())
missing = [nid for nid in node_ids if nid not in existing]
if missing:
raise ValueError(f"Node {missing[0]} does not exist in the graph.")

for node_id in node_ids:
self.remove_node(node_id)

@abc.abstractmethod
def add_edge(
self,
Expand Down Expand Up @@ -314,6 +346,38 @@ def remove_edge(
If the specified edge does not exist or insufficient identifiers are provided.
"""

def bulk_remove_edges(self, edge_ids: Sequence[int]) -> None:
"""
Remove multiple edges from the graph by their edge IDs.

Existence is validated up-front so the call either removes every edge
in `edge_ids` or raises without modifying the graph.

Parameters
----------
edge_ids : Sequence[int]
The IDs of the edges to remove.

Raises
------
ValueError
If any edge_id does not exist in the graph.
"""
if hasattr(edge_ids, "tolist"):
edge_ids = edge_ids.tolist()
else:
edge_ids = list(edge_ids)
if len(edge_ids) == 0:
return

existing = set(self.edge_ids())
missing = [eid for eid in edge_ids if eid not in existing]
if missing:
raise ValueError(f"Edge {missing[0]} does not exist in the graph.")

for edge_id in edge_ids:
self.remove_edge(edge_id=edge_id)

@overload
def bulk_add_edges(
self,
Expand Down
147 changes: 115 additions & 32 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph, RXFilter
from tracksdata.graph.filters._indexed_filter import IndexRXFilter
from tracksdata.utils._dtypes import AttrSchema
from tracksdata.utils._signal import is_signal_on
from tracksdata.utils._signal import (
emit_node_added_events,
emit_node_updated_events,
is_signal_on,
)


class GraphView(MappedGraphMixin, RustWorkXGraph):
Expand Down Expand Up @@ -390,12 +394,8 @@ def add_node(
)

if self.sync:
with self.node_added.blocked():
node_id = RustWorkXGraph.add_node(
self,
attrs=attrs,
validate_keys=validate_keys,
)
# Local primitive: pure rx_graph + _time_to_nodes, no validation, no signal.
node_id = self._bulk_add_nodes_local([attrs])[0]
self._add_id_mapping(node_id, parent_node_id)
else:
self._out_of_sync = True
Expand All @@ -411,20 +411,20 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
with self._root.node_added.blocked():
parent_node_ids = self._root.bulk_add_nodes(nodes, indices=indices)

# Defensive: drop NODE_ID from emitted/local-stored attrs in case the root
# backend (e.g. older SQL paths) injected it.
emitted_nodes = [
{key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
for node_attrs in nodes
]
if self.sync:
with self.node_added.blocked():
node_ids = RustWorkXGraph.bulk_add_nodes(self, nodes)
node_ids = self._bulk_add_nodes_local(emitted_nodes)
self._add_id_mappings(list(zip(node_ids, parent_node_ids, strict=True)))
else:
self._out_of_sync = True

if is_signal_on(self._root.node_added):
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
self._root.node_added.emit(node_id, node_attrs)

if is_signal_on(self.node_added):
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
self.node_added.emit(node_id, node_attrs)
emit_node_added_events(self._root.node_added, zip(parent_node_ids, emitted_nodes, strict=True))
emit_node_added_events(self.node_added, zip(parent_node_ids, emitted_nodes, strict=True))

return parent_node_ids

Expand Down Expand Up @@ -462,11 +462,9 @@ def remove_node(self, node_id: int) -> None:
self._root.remove_node(node_id)

if self.sync:
# Get the local node ID and remove from local graph
# Local primitive: pure rx_graph + _time_to_nodes, no signal.
local_node_id = self._external_to_local[node_id]

with self.node_removed.blocked():
super().remove_node(local_node_id)
self._bulk_remove_nodes_local([local_node_id])

# Remove the node mapping
self._remove_id_mapping(external_id=node_id)
Expand All @@ -490,6 +488,61 @@ def remove_node(self, node_id: int) -> None:
if view_signal_on:
self.node_removed.emit(node_id, old_attrs)

def bulk_remove_nodes(self, node_ids: Sequence[int]) -> None:
"""
Remove multiple nodes from both the view and the root graph.

Parameters
----------
node_ids : Sequence[int]
External IDs of the nodes to remove.

Raises
------
ValueError
If any node_id does not exist in the graph.
"""
if hasattr(node_ids, "tolist"):
node_ids = node_ids.tolist()
else:
node_ids = list(node_ids)
if len(node_ids) == 0:
return

missing = [nid for nid in node_ids if nid not in self._external_to_local]
if missing:
raise ValueError(f"Node {missing[0]} does not exist in the graph.")

view_signal_on = is_signal_on(self.node_removed)
root_signal_on = is_signal_on(self._root.node_removed)
old_attrs_per_node: dict[int, dict[str, Any]] = {}
if view_signal_on or root_signal_on:
for nid in node_ids:
old_attrs_per_node[nid] = self.nodes[nid].to_dict()

with self._root.node_removed.blocked():
self._root.bulk_remove_nodes(node_ids)

if self.sync:
local_ids = [self._external_to_local[nid] for nid in node_ids]
self._bulk_remove_nodes_local(local_ids)
for nid in node_ids:
self._remove_id_mapping(external_id=nid)

edge_indices = set(self.rx_graph.edge_indices())
for local_edge_id in list(self._edge_map_to_root.keys()):
if local_edge_id not in edge_indices:
del self._edge_map_to_root[local_edge_id]
else:
self._out_of_sync = True

if root_signal_on:
for nid in node_ids:
self._root.node_removed.emit(nid, old_attrs_per_node[nid])
if view_signal_on:
for nid in node_ids:
self.node_removed.emit(nid, old_attrs_per_node[nid])

def add_edge(
self,
source_id: int,
Expand Down Expand Up @@ -576,6 +629,40 @@ def remove_edge(
else:
self._out_of_sync = True

def bulk_remove_edges(self, edge_ids: Sequence[int]) -> None:
"""
Remove multiple edges from both the root and (if present) the view.

Parameters
----------
edge_ids : Sequence[int]
Root edge IDs to remove.

Raises
------
ValueError
If any edge_id does not exist in the root graph.
"""
if hasattr(edge_ids, "tolist"):
edge_ids = edge_ids.tolist()
else:
edge_ids = list(edge_ids)
if len(edge_ids) == 0:
return

self._root.bulk_remove_edges(edge_ids)

if self.sync:
edge_map = self.rx_graph.edge_index_map()
for root_eid in edge_ids:
if root_eid in self._edge_map_from_root:
local_edge_id = self._edge_map_from_root[root_eid]
src, tgt, _ = edge_map[local_edge_id]
self.rx_graph.remove_edge(src, tgt)
del self._edge_map_to_root[local_edge_id]
else:
self._out_of_sync = True

def _get_neighbors(
self,
neighbors_func: Callable[[rx.PyDiGraph, int], rx.NodeIndices],
Expand Down Expand Up @@ -747,19 +834,15 @@ def update_node_attrs(
)
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
if root_signal_on:
for node_id in node_ids:
self._root.node_updated.emit(
node_id,
old_attrs_by_id[node_id],
new_attrs_by_id[node_id],
)
emit_node_updated_events(
self._root.node_updated,
((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in node_ids),
)
if view_signal_on:
for node_id in node_ids:
self.node_updated.emit(
node_id,
old_attrs_by_id[node_id],
new_attrs_by_id[node_id],
)
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in node_ids),
)

def update_edge_attrs(
self,
Expand Down
6 changes: 5 additions & 1 deletion src/tracksdata/graph/_mapped_graph_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ def _add_id_mappings(self, mappings: Sequence[tuple[int, int]]) -> None:
mappings : Sequence[tuple[int, int]]
Sequence of (local_id, external_id) pairs
"""
self._local_to_external.putall(mappings)
try:
self._local_to_external.putall(mappings)
except bidict.ValueDuplicationError as e:
# Match the single-add path: an external_id collision is the user-facing "key" duplication.
raise bidict.KeyDuplicationError(e.args[0]) from e

def _remove_id_mapping(
self,
Expand Down
Loading
Loading