6 changes: 3 additions & 3 deletions pychunkedgraph/graph/cache.py
@@ -79,7 +79,7 @@ def cross_edges_decorated(node_id):
return cross_edges_decorated(node_id)

def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None):
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
node_ids = np.asarray(node_ids, dtype=NODE_ID)
if not node_ids.size:
return node_ids
mask = np.in1d(node_ids, np.fromiter(self.parents_cache.keys(), dtype=NODE_ID))
@@ -93,7 +93,7 @@ def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None)

def children_multiple(self, node_ids: np.ndarray, *, flatten=False):
result = {}
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
node_ids = np.asarray(node_ids, dtype=NODE_ID)
if not node_ids.size:
return result
mask = np.in1d(node_ids, np.fromiter(self.children_cache.keys(), dtype=NODE_ID))
@@ -111,7 +111,7 @@ def cross_chunk_edges_multiple(
self, node_ids: np.ndarray, *, time_stamp: datetime = None
):
result = {}
node_ids = np.array(node_ids, dtype=NODE_ID, copy=False)
node_ids = np.asarray(node_ids, dtype=NODE_ID)
if not node_ids.size:
return result
mask = np.in1d(
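Note on the hunks above (and several like them below): np.array(..., copy=False) is replaced with np.asarray. A minimal sketch of the behavioral difference, assuming the motivation is NumPy 2.x compatibility, where copy=False raises instead of meaning "copy only if needed"; NODE_ID is assumed to stand in for basetypes.NODE_ID.

import numpy as np

NODE_ID = np.uint64  # assumed stand-in for pychunkedgraph's basetypes.NODE_ID

node_ids = [1, 2, 3]  # plain list: building an array from it requires a copy

# np.array(node_ids, dtype=NODE_ID, copy=False)  # ValueError on NumPy >= 2.0
arr = np.asarray(node_ids, dtype=NODE_ID)        # copies only because it must

# when the input is already an ndarray of the right dtype, asarray is a no-op
assert np.asarray(arr, dtype=NODE_ID) is arr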
5 changes: 3 additions & 2 deletions pychunkedgraph/graph/chunkedgraph.py
@@ -1,5 +1,4 @@
# pylint: disable=invalid-name, missing-docstring, too-many-lines, import-outside-toplevel, unsupported-binary-operation

import time
import typing
import datetime
@@ -765,6 +764,7 @@ def add_edges(
source_coords: typing.Sequence[int] = None,
sink_coords: typing.Sequence[int] = None,
allow_same_segment_merge: typing.Optional[bool] = False,
stitch_mode: typing.Optional[bool] = False,
) -> operation.GraphEditOperation.Result:
"""
Adds an edge to the chunkedgraph
@@ -781,6 +781,7 @@ def add_edges(
source_coords=source_coords,
sink_coords=sink_coords,
allow_same_segment_merge=allow_same_segment_merge,
stitch_mode=stitch_mode,
).execute()

def remove_edges(
@@ -911,7 +912,7 @@ def get_chunk_coordinates_multiple(self, node_or_chunk_ids: typing.Sequence):
node_or_chunk_ids, dtype=basetypes.NODE_ID, copy=False
)
layers = self.get_chunk_layers(node_or_chunk_ids)
assert np.all(layers == layers[0]), "All IDs must have the same layer."
assert len(layers) == 0 or np.all(layers == layers[0]), "All IDs must have the same layer."
return chunk_utils.get_chunk_coordinates_multiple(self.meta, node_or_chunk_ids)

def get_chunk_id(
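Note: the relaxed assertion in get_chunk_coordinates_multiple avoids indexing layers[0] on empty input. A small sketch of the short-circuit, with layers standing in for the array returned by get_chunk_layers.

import numpy as np

layers = np.array([], dtype=int)  # e.g. node_or_chunk_ids was empty
assert len(layers) == 0 or np.all(layers == layers[0])  # passes; layers[0] is never evaluated

layers = np.array([2, 2, 2])
assert len(layers) == 0 or np.all(layers == layers[0])  # homogeneous layers still pass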
4 changes: 2 additions & 2 deletions pychunkedgraph/graph/chunks/utils.py
@@ -98,7 +98,7 @@ def get_chunk_coordinates_multiple(meta, ids: np.ndarray) -> np.ndarray:
y_offset = x_offset - bits_per_dim
z_offset = y_offset - bits_per_dim

ids = np.array(ids, dtype=int, copy=False)
ids = np.asarray(ids, dtype=int)
X = ids >> x_offset & 2**bits_per_dim - 1
Y = ids >> y_offset & 2**bits_per_dim - 1
Z = ids >> z_offset & 2**bits_per_dim - 1
@@ -152,7 +152,7 @@ def get_chunk_ids_from_node_ids(meta, ids: Iterable[np.uint64]) -> np.ndarray:
bits_per_dims = np.array([meta.bitmasks[l] for l in get_chunk_layers(meta, ids)])
offsets = 64 - meta.graph_config.LAYER_ID_BITS - 3 * bits_per_dims

ids = np.array(ids, dtype=int, copy=False)
ids = np.asarray(ids, dtype=int)
cids1 = np.array((ids >> offsets) << offsets, dtype=np.uint64)
# cids2 = np.vectorize(get_chunk_id)(meta, ids)
# assert np.all(cids1 == cids2)
10 changes: 4 additions & 6 deletions pychunkedgraph/graph/edges/__init__.py
@@ -54,22 +54,20 @@ def __init__(
affinities: Optional[np.ndarray] = None,
areas: Optional[np.ndarray] = None,
):
self.node_ids1 = np.array(node_ids1, dtype=basetypes.NODE_ID, copy=False)
self.node_ids2 = np.array(node_ids2, dtype=basetypes.NODE_ID, copy=False)
self.node_ids1 = np.array(node_ids1, dtype=basetypes.NODE_ID)
self.node_ids2 = np.array(node_ids2, dtype=basetypes.NODE_ID)
assert self.node_ids1.size == self.node_ids2.size

self._as_pairs = None

if affinities is not None and len(affinities) > 0:
self._affinities = np.array(
affinities, dtype=basetypes.EDGE_AFFINITY, copy=False
)
self._affinities = np.array(affinities, dtype=basetypes.EDGE_AFFINITY)
assert self.node_ids1.size == self._affinities.size
else:
self._affinities = np.full(len(self.node_ids1), DEFAULT_AFFINITY)

if areas is not None and len(areas) > 0:
self._areas = np.array(areas, dtype=basetypes.EDGE_AREA, copy=False)
self._areas = np.array(areas, dtype=basetypes.EDGE_AREA)
assert self.node_ids1.size == self._areas.size
else:
self._areas = np.full(len(self.node_ids1), DEFAULT_AREA)
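Note: unlike the asarray changes above, the Edges constructor keeps np.array and only drops copy=False, so these attributes now always hold their own copy of the caller's data (an assumption about intent; the diff itself only removes the flag). A sketch of the distinction:

import numpy as np

NODE_ID = np.uint64  # assumed basetypes.NODE_ID
source = np.array([10, 11, 12], dtype=NODE_ID)

owned = np.array(source, dtype=NODE_ID)    # np.array copies by default
alias = np.asarray(source, dtype=NODE_ID)  # asarray returns the same array

source[0] = 99
print(owned[0])  # 10 -> independent buffer
print(alias[0])  # 99 -> no copy was made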
57 changes: 37 additions & 20 deletions pychunkedgraph/graph/edits.py
@@ -68,7 +68,9 @@ def _analyze_affected_edges(


def _get_relevant_components(edges: np.ndarray, supervoxels: np.ndarray) -> Tuple:
edges = np.concatenate([edges, np.vstack([supervoxels, supervoxels]).T])
edges = np.concatenate([edges, np.vstack([supervoxels, supervoxels]).T]).astype(
basetypes.NODE_ID
)
graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True)
ccs = flatgraph.connected_components(graph)
relevant_ccs = []
@@ -107,8 +109,10 @@ def merge_preprocess(
active_edges.append(active)
inactive_edges.append(inactive)

relevant_ccs = _get_relevant_components(np.concatenate(active_edges), supervoxels)
inactive = np.concatenate(inactive_edges)
relevant_ccs = _get_relevant_components(
np.concatenate(active_edges).astype(basetypes.NODE_ID), supervoxels
)
inactive = np.concatenate(inactive_edges).astype(basetypes.NODE_ID)
_inactive = [types.empty_2d]
# source to sink edges
source_mask = np.in1d(inactive[:, 0], relevant_ccs[0])
@@ -119,7 +123,7 @@ def merge_preprocess(
sink_mask = np.in1d(inactive[:, 1], relevant_ccs[0])
source_mask = np.in1d(inactive[:, 0], relevant_ccs[1])
_inactive.append(inactive[source_mask & sink_mask])
_inactive = np.concatenate(_inactive)
_inactive = np.concatenate(_inactive).astype(basetypes.NODE_ID)
return np.unique(_inactive, axis=0) if _inactive.size else types.empty_2d


@@ -187,14 +191,15 @@ def add_edges(
time_stamp: datetime.datetime = None,
parent_ts: datetime.datetime = None,
allow_same_segment_merge=False,
stitch_mode: bool = False,
):
edges, l2_cross_edges_d = _analyze_affected_edges(
cg, atomic_edges, parent_ts=parent_ts
)
l2ids = np.unique(edges)
if not allow_same_segment_merge:
if not allow_same_segment_merge and not stitch_mode:
roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)
assert np.unique(roots).size == 2, "L2 IDs must belong to different roots."
assert np.unique(roots).size >= 2, "L2 IDs must belong to different roots."

new_old_id_d = defaultdict(set)
old_new_id_d = defaultdict(set)
@@ -217,7 +222,9 @@

# update cache
# map parent to new merged children and vice versa
merged_children = np.concatenate([atomic_children_d[l2id] for l2id in l2ids_])
merged_children = np.concatenate(
[atomic_children_d[l2id] for l2id in l2ids_]
).astype(basetypes.NODE_ID)
cg.cache.children_cache[new_id] = merged_children
cache_utils.update(cg.cache.parents_cache, merged_children, new_id)

@@ -244,6 +251,7 @@
operation_id=operation_id,
time_stamp=time_stamp,
parent_ts=parent_ts,
stitch_mode=stitch_mode,
)

new_roots = create_parents.run()
@@ -285,9 +293,8 @@ def _split_l2_agglomeration(
cross_edges = cross_edges[~in2d(cross_edges, removed_edges)]
isolated_ids = agg.supervoxels[~np.in1d(agg.supervoxels, chunk_edges)]
isolated_edges = np.column_stack((isolated_ids, isolated_ids))
graph, _, _, graph_ids = flatgraph.build_gt_graph(
np.concatenate([chunk_edges, isolated_edges]), make_directed=True
)
_edges = np.concatenate([chunk_edges, isolated_edges]).astype(basetypes.NODE_ID)
graph, _, _, graph_ids = flatgraph.build_gt_graph(_edges, make_directed=True)
return flatgraph.connected_components(graph), graph_ids, cross_edges


@@ -331,7 +338,7 @@ def remove_edges(
old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts)
chunk_id_map = dict(zip(l2ids.tolist(), cg.get_chunk_ids_from_node_ids(l2ids)))

removed_edges = np.concatenate([atomic_edges, atomic_edges[:, ::-1]], axis=0)
removed_edges = np.concatenate([atomic_edges, atomic_edges[:, ::-1]], axis=0).astype(basetypes.NODE_ID)
new_l2_ids = []
for id_ in l2ids:
agg = l2id_agglomeration_d[id_]
@@ -387,11 +394,11 @@ def _get_flipped_ids(id_map, node_ids):
returns old or new ids according to the map
"""
ids = [
np.array(list(id_map[id_]), dtype=basetypes.NODE_ID, copy=False)
np.asarray(list(id_map[id_]), dtype=basetypes.NODE_ID)
for id_ in node_ids
]
ids.append(types.empty_1d) # concatenate needs at least one array
return np.concatenate(ids)
return np.concatenate(ids).astype(basetypes.NODE_ID)


def _get_descendants(cg, new_id):
@@ -443,7 +450,7 @@ def _update_neighbor_cross_edges_single(
edges = fastremap.remap(edges, node_map, preserve_missing_labels=True)
if layer == counterpart_layer:
reverse_edge = np.array([counterpart, new_id], dtype=basetypes.NODE_ID)
edges = np.concatenate([edges, [reverse_edge]])
edges = np.concatenate([edges, [reverse_edge]]).astype(basetypes.NODE_ID)
descendants = _get_descendants(cg, new_id)
mask = np.isin(edges[:, 1], descendants)
if np.any(mask):
@@ -510,6 +517,7 @@ def __init__(
old_new_id_d: Dict[np.uint64, Set[np.uint64]] = None,
old_hierarchy_d: Dict[np.uint64, Dict[int, np.uint64]] = None,
parent_ts: datetime.datetime = None,
stitch_mode: bool = False,
):
self.cg = cg
self.new_entries = []
@@ -521,6 +529,7 @@ def __init__(
self._operation_id = operation_id
self._time_stamp = time_stamp
self._last_successful_ts = parent_ts
self.stitch_mode = stitch_mode

def _update_id_lineage(
self,
@@ -552,7 +561,7 @@ def _get_connected_components(self, node_ids: np.ndarray, layer: int):
for id_ in node_ids:
edges_ = cross_edges_d[id_].get(layer, types.empty_2d)
cx_edges.append(edges_)
cx_edges = np.concatenate([*cx_edges, np.vstack([node_ids, node_ids]).T])
cx_edges = np.concatenate([*cx_edges, np.vstack([node_ids, node_ids]).T]).astype(basetypes.NODE_ID)
graph, _, _, graph_ids = flatgraph.build_gt_graph(cx_edges, make_directed=True)
return flatgraph.connected_components(graph), graph_ids

@@ -568,7 +577,7 @@ def _get_layer_node_ids(
mask = np.in1d(siblings, old_ids)
node_ids = np.concatenate(
[_get_flipped_ids(self._old_new_id_d, old_ids), siblings[~mask], new_ids]
)
).astype(basetypes.NODE_ID)
node_ids = np.unique(node_ids)
layer_mask = self.cg.get_chunk_layers(node_ids) == layer
return node_ids[layer_mask]
@@ -635,10 +644,16 @@ def _create_new_parents(self, layer: int):
if len(cx_edges_d[cc_ids[0]].get(l, types.empty_2d)) > 0:
parent_layer = l
break
parent = self.cg.id_client.create_node_id(
self.cg.get_parent_chunk_id(cc_ids[0], parent_layer),
root_chunk=parent_layer == self.cg.meta.layer_count,
)

while True:
parent = self.cg.id_client.create_node_id(
self.cg.get_parent_chunk_id(cc_ids[0], parent_layer),
root_chunk=parent_layer == self.cg.meta.layer_count,
)
_entry = self.cg.client.read_node(parent)
if _entry == {}:
break

self._new_ids_d[parent_layer].append(parent)
self._update_id_lineage(parent, cc_ids, layer, parent_layer)
self.cg.cache.children_cache[parent] = cc_ids
@@ -689,6 +704,8 @@ def run(self) -> Iterable:
return self._new_ids_d[self.cg.meta.layer_count]

def _update_root_id_lineage(self):
if self.stitch_mode:
return
new_roots = self._new_ids_d[self.cg.meta.layer_count]
former_roots = _get_flipped_ids(self._new_old_id_d, new_roots)
former_roots = np.unique(former_roots)
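Note: several edits.py hunks append .astype(basetypes.NODE_ID) after np.concatenate. A minimal sketch of the dtype promotion this presumably guards against, assuming one of the concatenated arrays (for example an empty 2D placeholder) is not already uint64; float64 cannot represent node IDs above 2**53 exactly.

import numpy as np

NODE_ID = np.uint64
empty_2d = np.empty((0, 2))                        # float64 placeholder (assumption)
edges = np.array([[1, 2], [3, 4]], dtype=NODE_ID)

merged = np.concatenate([empty_2d, edges])
print(merged.dtype)                                # float64: promotion silently drops uint64

merged = np.concatenate([empty_2d, edges]).astype(NODE_ID)
print(merged.dtype)                                # uint64, what downstream code expects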