Skip to content

Commit 4ccfa5f

Browse files
committed
Only reindex after deleting all the ROIs in multiple removal
1 parent 849ca11 commit 4ccfa5f

3 files changed

Lines changed: 259 additions & 7 deletions

File tree

cellpose/gui/delete_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Utilities for deleting and relabeling GUI masks.
3+
"""
4+
5+
import numpy as np
6+
7+
8+
def normalize_remove_ids(remove_ids, ncells):
9+
"""Return unique valid label IDs in descending order."""
10+
remove_ids = np.asarray(remove_ids, dtype=np.int64).reshape(-1)
11+
if remove_ids.size == 0 or ncells <= 0:
12+
return np.zeros(0, dtype=np.int64)
13+
valid = (remove_ids > 0) & (remove_ids <= int(ncells))
14+
if not np.any(valid):
15+
return np.zeros(0, dtype=np.int64)
16+
remove_ids = np.unique(remove_ids[valid])
17+
return np.sort(remove_ids)[::-1]
18+
19+
20+
def batch_delete_reindex(cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids):
21+
"""Delete labels and reindex all state in one pass.
22+
23+
Returns updated `(cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids, remove_mask)`.
24+
"""
25+
if cellpix.shape != outpix.shape:
26+
raise ValueError("cellpix and outpix must have the same shape")
27+
28+
ncells = int(len(cellcolors) - 1)
29+
remove_ids = normalize_remove_ids(remove_ids, ncells)
30+
if remove_ids.size == 0:
31+
remove_mask = np.zeros(ncells + 1, dtype=bool)
32+
return (
33+
cellpix,
34+
outpix,
35+
ismanual,
36+
cellcolors,
37+
list(zdraw),
38+
remove_ids,
39+
remove_mask,
40+
)
41+
42+
remove_mask = np.zeros(ncells + 1, dtype=bool)
43+
remove_mask[remove_ids] = True
44+
keep_mask = ~remove_mask
45+
46+
lut_dtype = cellpix.dtype if np.issubdtype(cellpix.dtype, np.integer) else np.int64
47+
relabel_map = np.cumsum(keep_mask, dtype=lut_dtype) - 1
48+
relabel_map[remove_mask] = 0
49+
50+
cellpix = relabel_map[cellpix]
51+
outpix = relabel_map[outpix]
52+
ismanual = ismanual[keep_mask[1:]]
53+
cellcolors = cellcolors[keep_mask]
54+
zdraw = [z for z, keep in zip(zdraw, keep_mask[1:]) if keep]
55+
56+
return cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids, remove_mask

cellpose/gui/gui.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from scipy.stats import mode
1616
import cv2
1717

18-
from . import guiparts, menus, io
18+
from . import guiparts, menus, io, delete_utils
1919
from .. import models, core, dynamics, version, train
2020
from ..utils import download_url_to_file, masks_to_outlines, diameters
2121
from ..io import get_image_files, imsave, imread
@@ -1090,12 +1090,60 @@ def unselect_cell_multi(self, idx):
10901090
def remove_cell(self, idx):
10911091
if isinstance(idx, (int, np.integer)):
10921092
idx = [idx]
1093-
# because the function remove_single_cell updates the state of the cellpix and outpix arrays
1094-
# by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order
1095-
# so that the indices are correct
1096-
idx.sort(reverse=True)
1097-
for i in idx:
1098-
self.remove_single_cell(i)
1093+
idx = delete_utils.normalize_remove_ids(idx, self.ncells.get())
1094+
if idx.size == 0:
1095+
return
1096+
1097+
if idx.size == 1:
1098+
self.remove_single_cell(int(idx[0]))
1099+
else:
1100+
self.selected = 0
1101+
remove_mask = np.zeros(self.ncells.get() + 1, dtype=bool)
1102+
remove_mask[idx] = True
1103+
1104+
if self.currentZ < self.cellpix.shape[0]:
1105+
self.layerz[remove_mask[self.cellpix[self.currentZ]]] = np.array([0, 0, 0,
1106+
0])
1107+
1108+
if self.NZ == 1:
1109+
last_idx = int(idx[-1])
1110+
cp_last = self.cellpix[0] == last_idx
1111+
op_last = self.outpix[0] == last_idx
1112+
self.removed_cell = [
1113+
self.ismanual[last_idx - 1], self.cellcolors[last_idx],
1114+
np.nonzero(cp_last),
1115+
np.nonzero(op_last)
1116+
]
1117+
self.redo.setEnabled(True)
1118+
1119+
ar_all, ac_all = np.nonzero(remove_mask[self.cellpix[0]])
1120+
coord_map = {}
1121+
if ar_all.size > 0:
1122+
labels = self.cellpix[0, ar_all, ac_all]
1123+
order = np.argsort(labels, kind="mergesort")
1124+
labels = labels[order]
1125+
ar_all = ar_all[order]
1126+
ac_all = ac_all[order]
1127+
unique_labels, first_inds = np.unique(labels, return_index=True)
1128+
last_inds = np.append(first_inds[1:], labels.size)
1129+
for label, i0, i1 in zip(unique_labels, first_inds, last_inds):
1130+
coord_map[int(label)] = (ar_all[i0:i1], ac_all[i0:i1])
1131+
1132+
for i in idx:
1133+
ar, ac = coord_map.get(
1134+
int(i), (np.zeros(0, np.int64), np.zeros(0, np.int64)))
1135+
d = datetime.datetime.now()
1136+
self.track_changes.append(
1137+
[d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
1138+
print("GUI_INFO: removed cell %d" % (i - 1))
1139+
else:
1140+
for i in idx:
1141+
print("GUI_INFO: removed cell %d" % (i - 1))
1142+
1143+
(self.cellpix, self.outpix, self.ismanual, self.cellcolors, self.zdraw, _,
1144+
_) = delete_utils.batch_delete_reindex(self.cellpix, self.outpix,
1145+
self.ismanual, self.cellcolors,
1146+
self.zdraw, idx)
10991147
self.ncells -= len(idx) # _save_sets uses ncells
11001148
self.update_layer()
11011149

tests/test_gui_delete_utils.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import numpy as np
2+
import importlib.util
3+
from pathlib import Path
4+
5+
6+
_DELETE_UTILS_PATH = (
7+
Path(__file__).resolve().parents[1] / "cellpose/gui/delete_utils.py"
8+
)
9+
_DELETE_UTILS_SPEC = importlib.util.spec_from_file_location(
10+
"cellpose_gui_delete_utils", _DELETE_UTILS_PATH
11+
)
12+
delete_utils = importlib.util.module_from_spec(_DELETE_UTILS_SPEC)
13+
_DELETE_UTILS_SPEC.loader.exec_module(delete_utils)
14+
15+
16+
def _legacy_remove_state(cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids):
17+
cellpix = cellpix.copy()
18+
outpix = outpix.copy()
19+
ismanual = ismanual.copy()
20+
cellcolors = cellcolors.copy()
21+
zdraw = list(zdraw)
22+
23+
for idx in remove_ids:
24+
cp = cellpix == idx
25+
op = outpix == idx
26+
cellpix[cp] = 0
27+
outpix[op] = 0
28+
cellpix[cellpix > idx] -= 1
29+
outpix[outpix > idx] -= 1
30+
ismanual = np.delete(ismanual, idx - 1)
31+
cellcolors = np.delete(cellcolors, [idx], axis=0)
32+
del zdraw[idx - 1]
33+
34+
return cellpix, outpix, ismanual, cellcolors, zdraw
35+
36+
37+
def _random_state(seed, nz=1, ly=64, lx=64, ncells=40):
38+
rng = np.random.default_rng(seed)
39+
dtype = np.uint16 if ncells < 2**16 - 1 else np.uint32
40+
41+
cellpix = rng.integers(0, ncells + 1, size=(nz, ly, lx), dtype=dtype)
42+
force_idx = rng.choice(cellpix.size, size=ncells, replace=False)
43+
cellpix.flat[force_idx] = np.arange(1, ncells + 1, dtype=dtype)
44+
45+
outline_mask = rng.random(cellpix.shape) < 0.2
46+
outpix = np.where(outline_mask, cellpix, 0).astype(dtype, copy=False)
47+
48+
ismanual = rng.integers(0, 2, size=ncells, dtype=np.uint8).astype(bool)
49+
cellcolors = rng.integers(0, 256, size=(ncells + 1, 3), dtype=np.uint8)
50+
cellcolors[0] = np.array([255, 255, 255], dtype=np.uint8)
51+
52+
zdraw = []
53+
for _ in range(ncells):
54+
nplanes = int(rng.integers(1, max(2, nz + 1)))
55+
zdraw.append(list(rng.integers(0, max(1, nz), size=nplanes)))
56+
57+
return cellpix, outpix, ismanual, cellcolors, zdraw
58+
59+
60+
def _assert_state_equal(expected, actual):
61+
exp_cellpix, exp_outpix, exp_ismanual, exp_cellcolors, exp_zdraw = expected
62+
got_cellpix, got_outpix, got_ismanual, got_cellcolors, got_zdraw = actual
63+
64+
assert np.array_equal(exp_cellpix, got_cellpix)
65+
assert np.array_equal(exp_outpix, got_outpix)
66+
assert np.array_equal(exp_ismanual, got_ismanual)
67+
assert np.array_equal(exp_cellcolors, got_cellcolors)
68+
assert len(exp_zdraw) == len(got_zdraw)
69+
for z0, z1 in zip(exp_zdraw, got_zdraw):
70+
assert np.array_equal(np.asarray(z0), np.asarray(z1))
71+
72+
73+
def test_batch_delete_reindex_matches_legacy_small_example():
74+
cellpix = np.array(
75+
[[[1, 1, 2, 2], [1, 3, 3, 2], [4, 4, 5, 5], [4, 0, 5, 5]]], dtype=np.uint16
76+
)
77+
outpix = np.array(
78+
[[[1, 0, 2, 0], [0, 3, 0, 2], [4, 0, 5, 0], [0, 0, 0, 5]]], dtype=np.uint16
79+
)
80+
ismanual = np.array([True, False, True, False, True])
81+
cellcolors = np.array(
82+
[[255, 255, 255], [10, 0, 0], [20, 0, 0], [30, 0, 0], [40, 0, 0], [50, 0, 0]],
83+
dtype=np.uint8,
84+
)
85+
zdraw = [[0], [0], [0], [0], [0]]
86+
87+
remove_ids = np.array([5, 3, 2], dtype=np.int64)
88+
89+
expected = _legacy_remove_state(
90+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
91+
)
92+
got = delete_utils.batch_delete_reindex(
93+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
94+
)[:5]
95+
_assert_state_equal(expected, got)
96+
97+
98+
def test_batch_delete_reindex_matches_legacy_random_2d():
99+
for seed in range(20):
100+
cellpix, outpix, ismanual, cellcolors, zdraw = _random_state(seed, nz=1)
101+
rng = np.random.default_rng(seed + 1000)
102+
ncells = len(cellcolors) - 1
103+
remove_n = int(rng.integers(1, ncells + 1))
104+
remove_ids = rng.choice(np.arange(1, ncells + 1), size=remove_n, replace=False)
105+
remove_ids = delete_utils.normalize_remove_ids(remove_ids, ncells)
106+
107+
expected = _legacy_remove_state(
108+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
109+
)
110+
got = delete_utils.batch_delete_reindex(
111+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
112+
)[:5]
113+
_assert_state_equal(expected, got)
114+
115+
116+
def test_batch_delete_reindex_matches_legacy_random_3d():
117+
for seed in range(12):
118+
cellpix, outpix, ismanual, cellcolors, zdraw = _random_state(seed + 100, nz=4)
119+
rng = np.random.default_rng(seed + 2000)
120+
ncells = len(cellcolors) - 1
121+
remove_n = int(rng.integers(1, ncells + 1))
122+
remove_ids = rng.choice(np.arange(1, ncells + 1), size=remove_n, replace=False)
123+
remove_ids = delete_utils.normalize_remove_ids(remove_ids, ncells)
124+
125+
expected = _legacy_remove_state(
126+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
127+
)
128+
got = delete_utils.batch_delete_reindex(
129+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
130+
)[:5]
131+
_assert_state_equal(expected, got)
132+
133+
134+
def test_batch_delete_reindex_noop_invalid_ids():
135+
cellpix, outpix, ismanual, cellcolors, zdraw = _random_state(999, nz=1)
136+
ncells = len(cellcolors) - 1
137+
remove_ids = np.array([0, -1, ncells + 10], dtype=np.int64)
138+
139+
got = delete_utils.batch_delete_reindex(
140+
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
141+
)
142+
got_state = got[:5]
143+
out_ids = got[5]
144+
remove_mask = got[6]
145+
146+
_assert_state_equal((cellpix, outpix, ismanual, cellcolors, zdraw), got_state)
147+
assert out_ids.size == 0
148+
assert not remove_mask.any()

0 commit comments

Comments
 (0)