🔨 Fix SemanticSegmentor Memory Spill Issue #990
base: dev-define-engines-abc
Conversation
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##    dev-define-engines-abc    #990   +/- ##
=============================================
  Coverage          ?    95.12%
=============================================
  Files             ?        79
  Lines             ?      9996
  Branches          ?      1288
=============================================
  Hits              ?      9509
  Misses            ?       445
  Partials          ?        42
=============================================
```

☔ View full report in Codecov by Sentry.
Pull request overview
This PR targets out-of-memory crashes in SemanticSegmentor.infer_wsi() when processing large WSIs by reducing peak RAM usage and spilling intermediate results to disk more aggressively.
Changes:
- Adds disk-backed intermediates (Zarr/temp dirs) to reduce RAM spikes during WSI merging and masked-output alignment.
- Adjusts horizontal merging logic to bound per-row allocations.
- Updates zarr-output tests to tolerate slightly higher mean prediction values.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tiatoolbox/models/engine/semantic_segmentor.py | Refactors WSI inference/merging and introduces new disk-backed caching/temp-dir handling intended to prevent OOM. |
| tests/engines/test_semantic_segmentor.py | Relaxes mean-prediction assertion upper bounds for WSI zarr outputs. |
Comments suppressed due to low confidence (3)
tiatoolbox/models/engine/semantic_segmentor.py:571
- When intermediate spilling occurs (`canvas_zarr is not None`), the remaining in-memory `canvas`/`count` produced after the last spill are never appended to the Zarr datasets. The code then overwrites `canvas`/`count` with `da.from_zarr(...)`, effectively dropping the tail rows and producing truncated outputs. Restore a final flush into `canvas_zarr`/`count_zarr` before wrapping, or concatenate the existing Zarr-backed arrays with the remaining in-memory chunks before the vertical merge.
```python
zarr_group = None
if canvas_zarr is not None:
    # Wrap zarr in dask array
    canvas = da.from_zarr(canvas_zarr, chunks=canvas_zarr.chunks)
    count = da.from_zarr(count_zarr, chunks=count_zarr.chunks)
    zarr_group = zarr.open(canvas_zarr.store.path, mode="a")
```
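A minimal sketch of the suggested final flush, reusing the names from the snippet above and assuming Zarr v2 arrays that can grow along axis 0 via `zarr.Array.append`:

```python
if canvas_zarr is not None:
    # Flush any tail rows produced after the last spill, so the
    # Zarr-backed arrays contain the full canvas before wrapping.
    if canvas is not None and canvas.shape[0] > 0:
        canvas_zarr.append(canvas, axis=0)
        count_zarr.append(count, axis=0)
    # Now wrap the complete Zarr arrays as lazy Dask arrays.
    canvas = da.from_zarr(canvas_zarr, chunks=canvas_zarr.chunks)
    count = da.from_zarr(count_zarr, chunks=count_zarr.chunks)
    zarr_group = zarr.open(canvas_zarr.store.path, mode="a")
```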
tiatoolbox/models/engine/semantic_segmentor.py:1091
`merged_shape` is computed using `batch_xs`/`batch_xe`, but `merge_batch_to_canvas` still uses the original global `xs:xe` values from `output_locs_` when writing into the local row canvas. If `batch_xs` is non-zero (e.g., sparse/masked rows), this will either index out of bounds or place blocks at the wrong x-offset. Adjust `output_locs_` to be relative to `batch_xs` (and ensure downstream vertical merging accounts for any x-offset), or keep a consistent global-width canvas per row.
```python
# Compute span only for the current row to avoid allocating a canvas
# covering the entire slide width.
batch_xs = np.min(output_locs_[:, 0], axis=0)
batch_xe = np.max(output_locs_[:, 2], axis=0)
merged_shape = (canvas_np_.shape[1], batch_xe - batch_xs, canvas_np.shape[3])
canvas_merge, count_merge = merge_batch_to_canvas(
    blocks=canvas_np_,
    output_locations=output_locs_,
    merged_shape=merged_shape,
)
```
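A sketch of the first suggested fix, assuming the columns of `output_locs_` are `(xs, ys, xe, ye)`: shift the x-coordinates to be relative to the row's left edge before merging, then account for the offset when placing the row:

```python
batch_xs = np.min(output_locs_[:, 0])
batch_xe = np.max(output_locs_[:, 2])

# Make x-coordinates row-local so writes stay inside the row canvas
# of width (batch_xe - batch_xs).
local_locs = output_locs_.copy()
local_locs[:, 0] -= batch_xs  # shift xs into row-local coordinates
local_locs[:, 2] -= batch_xs  # shift xe into row-local coordinates

merged_shape = (canvas_np_.shape[1], batch_xe - batch_xs, canvas_np_.shape[3])
canvas_merge, count_merge = merge_batch_to_canvas(
    blocks=canvas_np_,
    output_locations=local_locs,
    merged_shape=merged_shape,
)
# Downstream vertical merging must then place this row at x-offset batch_xs.
```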
tiatoolbox/models/engine/semantic_segmentor.py:1438
`prepare_full_batch` now returns a Zarr array (`zarr.zeros(...)`), but the function signature/docstring still claim it returns `np.ndarray`, and the caller later passes this into utilities that expect NumPy/Dask arrays. This can break concatenation and/or force full materialization via `__array__`. Either return a concrete NumPy array (only when safe) or return a Dask array (e.g., `da.from_zarr`) and update the downstream merge/concatenation path to handle it consistently.
```python
store = zarr.DirectoryStore(str(temp_dir))
full_batch_output = zarr.zeros(
    shape=(total_size, *sample_shape),
    chunks=(len(batch_output), *sample_shape),
    dtype=batch_output.dtype,
    store=store,
    overwrite=True,
)
# Place matching outputs using matching indices
full_batch_output[matches] = batch_output
```
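One way to address this, sketched under the assumption that the downstream merge path can consume Dask arrays: wrap the Zarr array lazily instead of returning it raw.

```python
import dask.array as da

# Place matching outputs using matching indices (writes go to disk).
full_batch_output[matches] = batch_output

# Return a lazy Dask view instead of the raw zarr.Array, so downstream
# concatenation works without materializing the full array in RAM.
return da.from_zarr(full_batch_output)
```

The return annotation and docstring would then need to change from `np.ndarray` to `da.Array` (or a union of the two if the in-memory NumPy fast path is kept).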
```diff
-# Initialize full output array
-full_batch_output = np.zeros(
-    shape=(total_size, *batch_output.shape[1:]),
+store = zarr.DirectoryStore(str(temp_dir))
```
Wouldn't it be slow when you save to zarr instead of numpy?
This is actually the main cause of the memory issue.

```
File "/media/u1910100/data/GitHub/tiatoolbox/tiatoolbox/models/engine/semantic_segmentor.py", line 1381, in prepare_full_batch
    full_batch_output = np.zeros(
                        ^^^^^^^^^
numpy._core._exceptions._ArrayMemoryError: Unable to allocate 33.0 GiB for an array with shape (6762, 512, 512, 5) and data type float32
```

This is the log from when the program crashed. The NumPy array is too large; we can't use NumPy here.
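The reported size is exactly what that shape implies for float32 (4 bytes per element):

```python
>>> 6762 * 512 * 512 * 5 * 4 / 1024**3  # elements x 4 bytes, in GiB
33.017578125
```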
Can we not check the size of the array before doing this? If the array is small enough to fit in memory, it should just use NumPy.
Yes, we can. I'll change it.
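A sketch of such a guard, with the threshold factor and variable names chosen for illustration (not the final implementation):

```python
import numpy as np
import psutil

# Estimate the allocation and fall back to disk only when it is too big.
est_bytes = (
    final_size
    * int(np.prod(sample_shape))
    * np.dtype(batch_output.dtype).itemsize
)
use_numpy = est_bytes < 0.5 * psutil.virtual_memory().available  # e.g. half of free RAM

if use_numpy:
    full_batch_output = np.zeros((final_size, *sample_shape), dtype=batch_output.dtype)
else:
    ...  # disk-backed Zarr array as shown above
```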
```python
pad_len = len(full_output_locs)
if not use_numpy:
    # Resize zarr array to accommodate padding
    full_batch_output.resize(total_size + pad_len, *sample_shape)
```
Copilot AI · Jan 29, 2026
`zarr.Array.resize` expects the new shape as a single tuple. Calling `full_batch_output.resize(total_size + pad_len, *sample_shape)` will raise a `TypeError` (in Zarr v2) and break the last-batch padding path when `use_numpy` is false. Pass a single shape tuple instead (consistent with other `resize` calls in this file).
```diff
-full_batch_output.resize(total_size + pad_len, *sample_shape)
+full_batch_output.resize((total_size + pad_len, *sample_shape))
```
```python
# Calculate final size including potential padding
final_size = total_size
if is_last and len(full_output_locs):
    final_size += len(full_output_locs)
```
Copilot AI · Jan 29, 2026
`final_size` is computed using `len(full_output_locs)` before `full_output_locs` is sliced (`full_output_locs = full_output_locs[total_size:]`). For the last batch this overestimates the required padding and can significantly inflate the RAM size estimate / preallocated NumPy array. Compute padding based on the remaining `full_output_locs` after slicing (or compute `pad_len = max(len(full_output_locs) - total_size, 0)` up-front).
```diff
-# Calculate final size including potential padding
-final_size = total_size
-if is_last and len(full_output_locs):
-    final_size += len(full_output_locs)
+# Calculate final size including potential padding for remaining locations
+pad_len = 0
+if is_last and len(full_output_locs):
+    # Remaining locations after consuming indices up to total_size
+    pad_len = max(len(full_output_locs) - total_size, 0)
+final_size = total_size + pad_len
```
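A quick numeric check of the difference (values are illustrative):

```python
# 10 output locations in total; 8 are consumed by this batch.
full_output_locs = list(range(10))
total_size = 8

# Old: final_size = 8 + 10 = 18, overestimating by the 8 already consumed.
# New: pad only by the remainder after slicing.
pad_len = max(len(full_output_locs) - total_size, 0)  # 2
assert total_size + pad_len == 10
```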
```python
else:
    # Array too large, use zarr backed by disk to avoid RAM spikes
    # Use a unique temp subdirectory per call to avoid chunk-shape clashes
    msg = (
        f"Estimated peak memory usage for full batch output: "
        f"{peak_bytes / (1024**3):.2f} GB exceeds threshold of "
        f"{memory_available / (1024**3):.2f} GB. "
        f"Allocating full batch output of size "
        f"{final_size}x{sample_shape} using Zarr on disk."
    )
    logger.info(msg)
```

```python
total_size = np.max(matches).astype(np.uint32) + 1
save_path_dir = Path(save_path)
save_path_dir.mkdir(parents=True, exist_ok=True)
temp_dir = Path(
    tempfile.mkdtemp(prefix="full_batch_tmp_", dir=str(save_path_dir))
)
```

```diff
-# Initialize full output array
-full_batch_output = np.zeros(
-    shape=(total_size, *batch_output.shape[1:]),
-    dtype=batch_output.dtype,
-)
+store = zarr.DirectoryStore(str(temp_dir))
+full_batch_output = zarr.zeros(
+    shape=(total_size, *sample_shape),
+    chunks=(len(batch_output), *sample_shape),
+    dtype=batch_output.dtype,
+    store=store,
+    overwrite=True,
+)
```
Copilot AI · Jan 29, 2026
New behavior switches `prepare_full_batch` between in-memory NumPy and disk-backed Zarr based on available memory, and also writes temp stores under `full_batch_tmp`. There isn't a focused unit test covering the Zarr-backed branch (including correct shape/padding behavior and that temp directories get cleaned up). Consider adding a test that forces low available memory (mock `psutil.virtual_memory().available`) to exercise the Zarr path and verifies the returned array contents and cleanup.
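A sketch of such a test; the `prepare_full_batch` call signature and fixtures here are assumptions, not the test suite's actual helpers:

```python
from unittest import mock

import numpy as np

def test_prepare_full_batch_zarr_fallback(tmp_path):
    """Force the low-memory branch and check the contents are preserved."""
    batch_output = np.random.rand(4, 8, 8, 2).astype(np.float32)
    matches = np.array([0, 1, 2, 3])

    fake_mem = mock.Mock(available=1)  # pretend almost no RAM is free
    with mock.patch("psutil.virtual_memory", return_value=fake_mem):
        out = prepare_full_batch(  # hypothetical signature
            batch_output=batch_output,
            matches=matches,
            save_path=tmp_path,
            is_last=True,
        )

    np.testing.assert_array_equal(np.asarray(out)[matches], batch_output)
```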
This pull request introduces significant improvements to memory management and efficiency in the semantic segmentation engine, especially for large whole-slide image (WSI) processing. The main changes focus on incremental processing and disk-backed storage to avoid excessive RAM usage, as well as more robust cleanup of temporary files. There are also adjustments to test tolerances and some bug fixes in array handling.
Memory management and efficiency improvements:
- The `prepare_full_batch` function now dynamically decides whether to use in-memory NumPy arrays or disk-backed Zarr arrays for large batch outputs, based on available system memory and a configurable threshold. This prevents memory spikes when processing large WSIs. [1] [2] [3] [4]
- The `save_to_cache` function has been refactored to incrementally write Dask array blocks to Zarr on disk, avoiding materializing large arrays in memory and reducing peak RAM usage. [1] [2]
- Memory checks in `infer_wsi` now use up-to-date available memory rather than an initial snapshot, and intermediate results are spilled to disk when thresholds are exceeded. [1] [2] [3]

Robustness and cleanup:
- Temporary files and directories created during processing are now cleaned up more robustly.
Bug fixes and test adjustments:
- `merge_batch_to_canvas` was updated to ensure compatibility with both NumPy and Dask arrays.
- `merge_horizontal` was reworked to compute spans and concatenate outputs only for the current row, improving correctness and efficiency. [1] [2]

Previous Problem:
I encountered out-of-memory issues and Python kept crashing when processing a relatively large WSI.
The example slide I was trying to run was:
https://huggingface.co/datasets/TIACentre/TIAToolBox_Remote_Samples/blob/main/sample_wsis/D_P000019_PAS_CPG.tif. The code I was trying to run was:Before this PR, the code kept crashing on my workstation, which has 32GBs of RAM, memory spiked to 100% just before it crashed.