Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 92 additions & 20 deletions cellmap_flow/post/postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
from scipy.ndimage import label
import mwatershed as mws
from scipy.ndimage import measurements
from scipy.ndimage import measurements, gaussian_filter
import fastremap
from funlib.math import cantor_number
import fastmorph
Expand Down Expand Up @@ -131,7 +131,9 @@ def is_segmentation(self):
class AffinityPostprocessor(PostProcessor):
def __init__(
self,
bias: float = 0.0,
adjacent_edge_bias: float = -0.4,
lr_bias_ratio: float = -0.175,
filter_val: float = 0.5,
neighborhood: str = """[
[1, 0, 0],
[0, 1, 0],
Expand All @@ -145,36 +147,106 @@ def __init__(
]""",
):
use_exact = "True"
self.bias = float(bias)
self.adjacent_edge_bias = float(adjacent_edge_bias)
self.lr_bias_ratio = float(lr_bias_ratio)
self.filter_val = float(filter_val)
self.neighborhood = ast.literal_eval(neighborhood)
self.use_exact = use_exact == "True"
self.use_exact = use_exact == "False"
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic appears inverted - setting use_exact = use_exact == "False" means use_exact will be True when the string is "False" and False otherwise. This seems counterintuitive and likely incorrect.

Suggested change
self.use_exact = use_exact == "False"
self.use_exact = use_exact == "True"

Copilot uses AI. Check for mistakes.
self.num_previous_segments = 0

def _process(self, data, chunk_num_voxels, chunk_corner):
data = data / 255.0
n_channels = data.shape[0]
self.neighborhood = self.neighborhood[:n_channels]
# raise Exception(data.max(), data.min(), self.neighborhood)
import numpy as np
from scipy.ndimage import measurements
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports should be at the top of the file, not inside a method. This import is also redundant as measurements is already imported at the top.

Suggested change
from scipy.ndimage import measurements

Copilot uses AI. Check for mistakes.

Comment on lines +157 to 159
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports should be at the top of the file, not inside a method. Move this import to the top of the file with other imports.

Suggested change
import numpy as np
from scipy.ndimage import measurements

Copilot uses AI. Check for mistakes.
segmentation = mws.agglom(
data.astype(np.float64) - self.bias,
self.neighborhood,
)
def filter_fragments(
self, affs_data: np.ndarray, fragments_data: np.ndarray, filter_val: float
) -> None:
"""Allows filtering of MWS fragments based on mean value of affinities & fragments. Will filter and update the fragment array in-place.

# filter fragments
average_affs = np.mean(data, axis=0)
Args:
aff_data (``np.ndarray``):
An array containing affinity data.

fragments_data (``np.ndarray``):
An array containing fragment data.

filter_val (``float``):
Threshold to filter if the average value falls below.
"""

filtered_fragments = []
average_affs: float = np.mean(affs_data.data, axis=0)
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessing .data attribute on affs_data may fail if affs_data is a numpy array rather than an object with a .data attribute. Use affs_data directly instead of affs_data.data.

Suggested change
average_affs: float = np.mean(affs_data.data, axis=0)
average_affs: float = np.mean(affs_data, axis=0)

Copilot uses AI. Check for mistakes.

fragment_ids = fastremap.unique(segmentation[segmentation > 0])
filtered_fragments: list = []

fragment_ids: np.ndarray = np.unique(fragments_data)

for fragment, mean in zip(
fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids)
fragment_ids, measurements.mean(average_affs, fragments_data, fragment_ids)
):
if mean >= self.bias:
if mean < filter_val:
filtered_fragments.append(fragment)

fastremap.mask_except(segmentation, filtered_fragments, in_place=True)
filtered_fragments: np.ndarray = np.array(
filtered_fragments, dtype=fragments_data.dtype
)
# replace: np.ndarray = np.zeros_like(filtered_fragments)
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code that is not being used. This line appears to be leftover development code.

Suggested change
# replace: np.ndarray = np.zeros_like(filtered_fragments)

Copilot uses AI. Check for mistakes.
fastremap.mask(fragments_data, filtered_fragments, in_place=True)

def _process(self, data, chunk_num_voxels, chunk_corner):
data[data < self.filter_val] = 0
if data.dtype == np.uint8:
logger.info("Assuming affinities are in [0,255]")
max_affinity_value: float = 255.0
data = data.astype(np.float64)
else:
data = data.astype(np.float64)
max_affinity_value: float = 1.0

data /= max_affinity_value

if data.max() < 1e-4:
segmentation = np.zeros(
data.shape, dtype=np.uint64 if self.use_exact else np.uint16
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape for the zeros array should match the spatial dimensions of the data, but data.shape includes the channel dimension. Use data.shape[1:] instead of data.shape to exclude the channel dimension.

Suggested change
data.shape, dtype=np.uint64 if self.use_exact else np.uint16
data.shape[1:], dtype=np.uint64 if self.use_exact else np.uint16

Copilot uses AI. Check for mistakes.
)
return np.expand_dims(segmentation, axis=0)

channels = [
channel for channel, ntp in enumerate(self.neighborhood) if ntp is not None
]
neighborhood = [self.neighborhood[channel] for channel in channels]

data = data[channels]
random_noise: float = np.random.randn(*data.shape) * 0.0001
smoothed_affs: np.ndarray = (
gaussian_filter(data, sigma=(0, *(np.amax(neighborhood, axis=0) / 3))) - 0.5
) * 0.001
shift: np.ndarray = np.array(
[
(
self.adjacent_edge_bias
if max(offset) <= 1
else np.linalg.norm(offset) * self.lr_bias_ratio
)
for offset in neighborhood
]
).reshape((-1, *((1,) * (len(data.shape) - 1))))

# raise Exception(data.max(), data.min(), self.neighborhood)

# segmentation = mws.agglom(
# data.astype(np.float64) - self.bias,
# self.neighborhood,
# )

# filter fragments
segmentation = mws.agglom(
data + shift + random_noise + smoothed_affs,
offsets=neighborhood,
)
if self.filter_val > 0.0:
self.filter_fragments(data, segmentation, self.filter_val)

# fragment_ids = fastremap.unique(segmentation[segmentation > 0])
Copy link

Copilot AI Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code. These lines appear to be leftover from the previous implementation.

Suggested change
# fragment_ids = fastremap.unique(segmentation[segmentation > 0])

Copilot uses AI. Check for mistakes.
# fastremap.mask_except(segmentation, filtered_fragments, in_place=True)
fastremap.renumber(segmentation, in_place=True)
unique_increment = chunk_num_voxels * pymorton.interleave(*chunk_corner)
if not self.use_exact:
Expand Down
5 changes: 3 additions & 2 deletions cellmap_flow/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _get_config(self):
config.output_channels = len(
config.channels
) # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld
config.block_shape = np.array(tuple(out_shape) + (len(channels),))
config.block_shape = np.array(tuple(out_shape) + (config.output_channels,))

return config

Expand Down Expand Up @@ -384,7 +384,8 @@ def get_dacapo_channels(task):
if hasattr(task, "channels"):
return task.channels
elif type(task).__name__ == "AffinitiesTask":
return ["x", "y", "z"]
# to be backwards compatible in case .channels or .neighborhood doesn't exist
return [f"aff_{'.'.join(map(str, n))}" for n in task.predictor.neighborhood]
else:
return ["membrane"]

Expand Down