Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion tests/engines/test_multi_task_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,10 @@ class FakeVM:
)

# --- Call function ---
new_zarr, new_da = _save_multitask_vertical_to_cache(
new_zarr, new_da, zarr_group = _save_multitask_vertical_to_cache(
probabilities_zarr=probabilities_zarr,
probabilities_da=probabilities_da,
zarr_group=None,
probabilities=probabilities,
idx=idx,
tqdm_loop=tqdm_loop,
Expand All @@ -905,11 +906,44 @@ class FakeVM:

# new_zarr must be a real zarr array
assert isinstance(new_zarr[idx], zarr.Array)
assert zarr_group is not None

# Data was written correctly
assert np.array_equal(new_zarr[idx][:], np.array([[1, 2, 3]]))


def test_multitask_vertical_merge_continues_after_zarr_spill(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check that all row chunks are still merged once results spill to Zarr."""

    class FakeVM:
        """Replacement for psutil.virtual_memory() with almost no free memory."""

        available = 1

    # Calling psutil.virtual_memory() now builds a FakeVM whose `available` is 1.
    monkeypatch.setattr(psutil, "virtual_memory", FakeVM)

    block_chunks = (2, 3, 1)
    values = np.arange(8 * 3, dtype=np.float32).reshape(8, 3, 1)
    canvas = [da.from_array(values, chunks=block_chunks)]
    count = [da.from_array(np.ones_like(values), chunks=block_chunks)]
    # Four non-overlapping bands of two rows each, spanning rows 0 to 8.
    output_locs_y = np.array([[0, 2], [2, 4], [4, 6], [6, 8]])

    result = merge_multitask_vertical_chunkwise(
        canvas=canvas,
        count=count,
        output_locs_y_=output_locs_y,
        zarr_group=None,
        save_path=tmp_path / "vertical.zarr",
        memory_threshold=0,
        output_shape=(8, 3),
        verbose=False,
    )

    # With a count of one everywhere, the merged output must equal the input.
    merged = result[0]
    assert merged.shape == values.shape
    assert np.array_equal(merged.compute(), values)


def test_qupath_feature_class_dict_lookup_fails() -> None:
"""Test qupath_feature_class_dict lookup fails."""
qupath_json = DaskDelayedJSONStore.__new__(DaskDelayedJSONStore)
Expand Down
26 changes: 26 additions & 0 deletions tests/engines/test_semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,32 @@ def test_merge_vertical_chunkwise_memory_threshold_triggered() -> None:
assert np.all(zarr_group["probabilities"][:] == data)


def test_merge_vertical_chunkwise_multi_row_overlap() -> None:
    """Verify vertical merging when a row overlaps more than one later row."""
    band_shape = (4, 2, 1)
    # Three constant-valued bands (1, 2 and 4) stacked along the first axis.
    data = np.concatenate(
        [np.full(band_shape, fill, dtype=np.float32) for fill in (1, 2, 4)],
        axis=0,
    )
    canvas = da.from_array(data, chunks=band_shape)
    count = da.from_array(np.ones_like(data, dtype=np.uint8), chunks=band_shape)
    # Each band starts one row after the previous one, so the first band
    # overlaps both of the bands that follow it.
    output_locs_y_ = np.array([[0, 4], [1, 5], [2, 6]])

    result = merge_vertical_chunkwise(
        canvas=canvas,
        count=count,
        output_locs_y_=output_locs_y_,
        zarr_group=None,
        save_path=Path("unused"),
        verbose=False,
    )

    # Per-row mean of the bands covering that row:
    # 1, (1+2)/2, (1+2+4)/3, (1+2+4)/3, (2+4)/2, 4.
    expected_rows = np.array([1, 1.5, 7 / 3, 7 / 3, 3, 4], dtype=np.float32)
    expected = np.broadcast_to(expected_rows[:, None, None], (6, 2, 1))
    np.testing.assert_allclose(result.compute(), expected)


def test_raise_value_error_return_labels_wsi(
remote_sample: Callable,
track_tmp_path: Path,
Expand Down
33 changes: 20 additions & 13 deletions tiatoolbox/models/engine/multi_task_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2608,19 +2608,24 @@ def merge_multitask_vertical_chunkwise(
chunk_shape=chunk_shape,
probabilities_zarr=probabilities_zarr[idx],
probabilities_da=probabilities_da[idx],
zarr_group=zarr_group,
zarr_group=(
zarr_group if probabilities_zarr[idx] is not None else None
),
Comment on lines +2611 to +2613
name=f"probabilities/{idx}",
)

probabilities_zarr, probabilities_da = _save_multitask_vertical_to_cache(
probabilities_zarr=probabilities_zarr,
probabilities_da=probabilities_da,
probabilities=probabilities,
idx=idx,
tqdm_loop=tqdm_loop,
save_path=save_path,
chunk_shape=chunk_shape,
memory_threshold=memory_threshold,
probabilities_zarr, probabilities_da, zarr_group = (
_save_multitask_vertical_to_cache(
probabilities_zarr=probabilities_zarr,
probabilities_da=probabilities_da,
zarr_group=zarr_group,
probabilities=probabilities,
idx=idx,
tqdm_loop=tqdm_loop,
save_path=save_path,
chunk_shape=chunk_shape,
memory_threshold=memory_threshold,
)
)

if next_chunk is not None:
Expand All @@ -2647,13 +2652,14 @@ def merge_multitask_vertical_chunkwise(
def _save_multitask_vertical_to_cache(
probabilities_zarr: list[zarr.Array] | list[None],
probabilities_da: list[da.Array] | list[None],
zarr_group: zarr.Group | None,
probabilities: np.ndarray,
idx: int,
tqdm_loop: tqdm,
save_path: Path,
chunk_shape: tuple,
memory_threshold: int = 80,
) -> tuple[list[zarr.Array], list[da.Array] | None]:
) -> tuple[list[zarr.Array], list[da.Array] | None, zarr.Group | None]:
"""Helper function to save to zarr if vertical merge is out of memory."""
used_percent = 0
if probabilities_da[idx] is not None:
Expand All @@ -2669,7 +2675,8 @@ def _save_multitask_vertical_to_cache(
f"Saving intermediate results to disk."
)
update_tqdm_desc(tqdm_loop=tqdm_loop, desc=msg)
zarr_group = zarr.open(str(save_path), mode="a")
if zarr_group is None:
zarr_group = zarr.open(str(save_path), mode="a")
probabilities_zarr[idx] = zarr_group.create_array(
name=f"probabilities/{idx}",
shape=probabilities_da[idx].shape,
Expand All @@ -2681,7 +2688,7 @@ def _save_multitask_vertical_to_cache(
update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc)
probabilities_da[idx] = None

return probabilities_zarr, probabilities_da
return probabilities_zarr, probabilities_da, zarr_group


def _clear_zarr(
Expand Down
179 changes: 136 additions & 43 deletions tiatoolbox/models/engine/semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,84 @@ def get_wsi_output_shape(dataset: object) -> tuple[int, int] | None:
return int(wsi_shape[1]), int(wsi_shape[0])


def _get_vertical_chunk_locations(
    output_locs_y: np.ndarray,
    num_chunks: int,
) -> np.ndarray:
    """Collect the distinct vertical chunk locations, ordered by start row.

    Args:
        output_locs_y (np.ndarray):
            2D array with one location per row; duplicate rows are collapsed
            and the first column is used as the ordering key.
        num_chunks (int):
            Number of distinct locations the caller expects.

    Returns:
        np.ndarray:
            The distinct rows of `output_locs_y` as int64, sorted by their
            first column.

    Raises:
        ValueError:
            If the number of distinct rows differs from `num_chunks`.

    """
    distinct_locs = np.unique(output_locs_y, axis=0)
    order = np.argsort(distinct_locs[:, 0], kind="stable")

    if order.shape[0] != num_chunks:
        msg = (
            "Number of vertical output locations does not match the number "
            "of merged canvas chunks."
        )
        raise ValueError(msg)

    return distinct_locs[order].astype(np.int64, copy=False)


def _aggregate_vertical_segment(
    active_chunks: list[tuple[int, int, np.ndarray, np.ndarray]],
    start_y: int,
    end_y: int,
) -> np.ndarray:
    """Average all active chunks covering a finalized vertical segment.

    Every active chunk contributes the rows it shares with the half-open
    interval ``[start_y, end_y)``. Values and counts are summed per row and
    the summed values are divided by the summed counts.

    Args:
        active_chunks (list[tuple[int, int, np.ndarray, np.ndarray]]):
            Non-empty list of ``(chunk_start_y, chunk_end_y, chunk,
            chunk_count)`` tuples. The first axis of `chunk` and
            `chunk_count` is indexed relative to `chunk_start_y`.
        start_y (int):
            First row (inclusive) of the segment.
        end_y (int):
            Last row (exclusive) of the segment.

    Returns:
        np.ndarray:
            Array with ``end_y - start_y`` rows holding the averaged values.
            It has zero rows when ``end_y <= start_y``.

    """
    first_chunk = active_chunks[0][2]
    first_count = active_chunks[0][3]

    # The final division by a float32 count promotes the values; the empty
    # result uses the same dtype so both paths agree.
    result_dtype = np.result_type(first_chunk.dtype, np.float32)
    if end_y <= start_y:
        return np.empty((0, *first_chunk.shape[1:]), dtype=result_dtype)

    height = end_y - start_y
    segment = np.zeros((height, *first_chunk.shape[1:]), dtype=first_chunk.dtype)

    # Stays uint32 for unsigned/bool counts. Widened for float or signed
    # counts, which an in-place add into uint32 would reject under NumPy's
    # "same_kind" casting rule.
    count_dtype = np.result_type(
        np.uint32,
        *(active_chunk[3].dtype for active_chunk in active_chunks),
    )
    segment_count = np.zeros(
        (height, *first_count.shape[1:]),
        dtype=count_dtype,
    )

    for chunk_start_y, chunk_end_y, chunk, chunk_count in active_chunks:
        overlap_start = max(start_y, chunk_start_y)
        overlap_end = min(end_y, chunk_end_y)
        if overlap_end <= overlap_start:
            # This chunk does not reach into the requested segment.
            continue

        # Same rows expressed in chunk-local and segment-local coordinates.
        source = slice(overlap_start - chunk_start_y, overlap_end - chunk_start_y)
        target = slice(overlap_start - start_y, overlap_end - start_y)

        segment[target] += chunk[source]
        segment_count[target] += chunk_count[source]

    # Positions nothing wrote to keep a count of 1 to avoid dividing by zero.
    segment_count = np.where(segment_count == 0, 1, segment_count)
    return segment / segment_count.astype(np.float32)


def _store_vertical_segment(
    probabilities: np.ndarray,
    output_shape: tuple[int, int] | None,
    written_height: int,
    chunk_shape: tuple[int, ...],
    probabilities_zarr: zarr.Array | None,
    probabilities_da: da.Array | None,
    zarr_group: zarr.Group | None,
) -> tuple[zarr.Array | None, da.Array | None, int, bool]:
    """Clip a finalized vertical segment and hand it to the probability store.

    The segment first goes through `clip_probabilities_to_shape`. If that
    call signals a stop, or leaves no rows, the storage handles are returned
    untouched. Otherwise the clipped rows are passed to
    `store_probabilities`.

    Args:
        probabilities (np.ndarray):
            Segment to clip and store.
        output_shape (tuple[int, int] | None):
            Forwarded to `clip_probabilities_to_shape`.
        written_height (int):
            Forwarded to `clip_probabilities_to_shape`.
        chunk_shape (tuple[int, ...]):
            Forwarded to `store_probabilities`.
        probabilities_zarr (zarr.Array | None):
            Current Zarr store handle, if any.
        probabilities_da (da.Array | None):
            Current Dask store handle, if any.
        zarr_group (zarr.Group | None):
            Forwarded to `store_probabilities`.

    Returns:
        tuple[zarr.Array | None, da.Array | None, int, bool]:
            The (possibly updated) Zarr and Dask handles, the written height
            reported by the clipping step, and the stop flag.

    """
    clipped, new_height, should_stop = clip_probabilities_to_shape(
        probabilities=probabilities,
        output_shape=output_shape,
        written_height=written_height,
    )

    # Short-circuit keeps `clipped` untouched when a stop was requested.
    nothing_to_store = should_stop or clipped.shape[0] == 0
    if not nothing_to_store:
        probabilities_zarr, probabilities_da = store_probabilities(
            probabilities=clipped,
            chunk_shape=chunk_shape,
            probabilities_zarr=probabilities_zarr,
            probabilities_da=probabilities_da,
            zarr_group=zarr_group,
        )
        should_stop = False

    return probabilities_zarr, probabilities_da, new_height, should_stop


def merge_vertical_chunkwise(
canvas: da.Array,
count: da.Array,
Expand All @@ -1410,8 +1488,8 @@ def merge_vertical_chunkwise(

This function processes vertically stacked image blocks (`canvas`) and their
associated count arrays to compute normalized probabilities. It handles overlapping
regions between chunks by applying seam folding and trimming halos to ensure smooth
transitions. If a Zarr group is provided, the result is stored incrementally.
regions between chunks by keeping active rows until no later chunk can contribute
to them. If a Zarr group is provided, the result is stored incrementally.

Args:
canvas (da.Array):
Expand Down Expand Up @@ -1441,54 +1519,61 @@ def merge_vertical_chunkwise(
or constructed in memory.

"""
y0s, y1s = np.unique(output_locs_y_[:, 0]), np.unique(output_locs_y_[:, 1])
overlaps = np.append(y1s[:-1] - y0s[1:], 0)

num_chunks = canvas.numblocks[0]
probabilities_zarr, probabilities_da = None, None
chunk_shape = tuple(chunk[0] for chunk in canvas.chunks)
written_height = 0
chunk_locs = _get_vertical_chunk_locations(output_locs_y_, num_chunks)

tqdm_loop = tqdm(
overlaps,
range(num_chunks),
leave=False,
desc="Merging rows",
disable=not verbose,
)

used_percent = 0

curr_chunk = canvas.blocks[0, 0].compute()
curr_count = count.blocks[0, 0].compute()
next_chunk = canvas.blocks[1, 0].compute() if num_chunks > 1 else None
next_count = count.blocks[1, 0].compute() if num_chunks > 1 else None

active_chunks: list[tuple[int, int, np.ndarray, np.ndarray]] = []
probabilities = np.empty(0)
current_y = int(chunk_locs[0, 0])
should_stop = False

for i, overlap in enumerate(tqdm_loop):
if next_chunk is not None and overlap > 0:
curr_chunk[-overlap:] += next_chunk[:overlap]
curr_count[-overlap:] += next_count[:overlap]

# Normalize
curr_count = np.where(curr_count == 0, 1, curr_count)
probabilities = curr_chunk / curr_count.astype(np.float32)
for chunk_idx in tqdm_loop:
chunk_start_y, chunk_end_y = map(int, chunk_locs[chunk_idx])

probabilities, written_height, should_stop = clip_probabilities_to_shape(
probabilities=probabilities,
output_shape=output_shape,
written_height=written_height,
)
if should_stop:
break

probabilities_zarr, probabilities_da = store_probabilities(
probabilities=probabilities,
chunk_shape=chunk_shape,
probabilities_zarr=probabilities_zarr,
probabilities_da=probabilities_da,
zarr_group=zarr_group,
)
if active_chunks and chunk_start_y > current_y:
probabilities = _aggregate_vertical_segment(
active_chunks=active_chunks,
start_y=current_y,
end_y=chunk_start_y,
)
probabilities_zarr, probabilities_da, written_height, should_stop = (
_store_vertical_segment(
probabilities=probabilities,
output_shape=output_shape,
written_height=written_height,
chunk_shape=chunk_shape,
probabilities_zarr=probabilities_zarr,
probabilities_da=probabilities_da,
zarr_group=zarr_group,
)
)
if should_stop:
break

current_y = chunk_start_y
active_chunks = [
active_chunk
for active_chunk in active_chunks
if active_chunk[1] > current_y
]

chunk = canvas.blocks[chunk_idx, 0].compute()
chunk_count = count.blocks[chunk_idx, 0].compute()
valid_chunk_end_y = min(chunk_end_y, chunk_start_y + chunk.shape[0])
if valid_chunk_end_y > chunk_start_y:
active_chunks.append((chunk_start_y, valid_chunk_end_y, chunk, chunk_count))

if probabilities_da is not None:
vm = psutil.virtual_memory()
Expand All @@ -1514,16 +1599,24 @@ def merge_vertical_chunkwise(
probabilities_da = None
update_tqdm_desc(tqdm_loop=tqdm_loop, desc=desc)

if next_chunk is not None:
curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:]

if i + 2 < num_chunks:
next_chunk = canvas.blocks[i + 2, 0].compute()
next_count = count.blocks[i + 2, 0].compute()
else:
next_chunk, next_count = None, None
if active_chunks and not should_stop:
final_y = max(active_chunk[1] for active_chunk in active_chunks)
probabilities = _aggregate_vertical_segment(
active_chunks=active_chunks,
start_y=current_y,
end_y=final_y,
)
probabilities_zarr, probabilities_da, _, _ = _store_vertical_segment(
probabilities=probabilities,
output_shape=output_shape,
written_height=written_height,
chunk_shape=chunk_shape,
probabilities_zarr=probabilities_zarr,
probabilities_da=probabilities_da,
zarr_group=zarr_group,
)

if probabilities_zarr:
if probabilities_zarr is not None:
return _get_probabilities_da_from_zarr(
zarr_group=zarr_group,
probabilities_zarr=probabilities_zarr,
Expand Down
Loading