Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/neuron_proofreader/split_proofreading/split_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __init__(
config,
gt_path=None,
metadata_path=None,
segmentation_path=None,
soma_centroids=set(),
):
"""
Expand All @@ -61,9 +60,6 @@ def __init__(
metadata_path : str, optional
Patch to JSON file containing metadata on block that fragments
were extracted from. Default is None.
segmentation_path : str, optional
Path to the segmentation that fragments were generated from.
Default is None.
soma_centroids : List[Tuple[int]], optional
Phyiscal coordinates of soma centroids. Default is an empty list.
"""
Expand All @@ -82,7 +78,6 @@ def __init__(
img_path,
brightness_clip=self.config.ml.brightness_clip,
patch_shape=self.config.ml.patch_shape,
segmentation_path=segmentation_path,
)

def _load_graph(self, fragments_path, metadata_path=None):
Expand Down Expand Up @@ -174,7 +169,6 @@ def add_dataset(
config,
gt_path=None,
metadata_path=None,
segmentation_path=None,
soma_centroids=list(),
):
"""
Expand All @@ -195,9 +189,6 @@ def add_dataset(
metadata_path : str, optional
Patch to JSON file containing metadata on block that fragments
were extracted from. Default is None.
segmentation_path : str, optional
Path to the segmentation that fragments were generated from.
Default is None.
soma_centroids : List[Tuple[int]], optional
Phyiscal coordinates of soma centroids. Default is an empty list.
"""
Expand All @@ -208,7 +199,6 @@ def add_dataset(
config,
gt_path=gt_path,
metadata_path=metadata_path,
segmentation_path=segmentation_path,
soma_centroids=soma_centroids,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from concurrent.futures import as_completed, ThreadPoolExecutor
from scipy.spatial import KDTree

from skimage.transform import resize
from torch_geometric.data import HeteroData

Expand All @@ -34,7 +35,6 @@ def __init__(
brightness_clip=400,
padding=50,
patch_shape=(96, 96, 96),
segmentation_path=None,
):
"""
Instantiates a FeaturePipeline object.
Expand All @@ -53,8 +53,6 @@ def __init__(
patch_shape : Tuple[int], optional
Shape of image patch expected by the vision model. Default is (96,
96, 96).
segmentation_path : str, optional
Path to segmentation of whole-brain dataset.
"""
self.extractors = [
SkeletonFeatureExtractor(graph),
Expand All @@ -64,8 +62,7 @@ def __init__(
brightness_clip=brightness_clip,
patch_shape=patch_shape,
padding=padding,
segmentation_path=segmentation_path,
),
),
]

def __call__(self, subgraph):
Expand Down Expand Up @@ -226,7 +223,6 @@ def __init__(
brightness_clip=400,
patch_shape=(96, 96, 96),
padding=40,
segmentation_path=None,
):
"""
Instantiates an ImageExtractor object.
Expand All @@ -245,8 +241,6 @@ def __init__(
padding : int, optional
Number of voxels to be added in each dimension from start and end
point of proposal for image patch extraction. Default is 40.
segmentation_path : str, optional
Path to segmentation of whole-brain dataset.
"""
# Instance attributes
self.brightness_clip = brightness_clip
Expand All @@ -256,10 +250,6 @@ def __init__(

# Image reader
self.img = img_util.TensorStoreReader(img_path)
if segmentation_path:
self.segmentation = img_util.TensorStoreReader(segmentation_path)
else:
self.segmentation = None

def __call__(self, subgraph, features):
"""
Expand Down Expand Up @@ -304,16 +294,16 @@ def init_extractor(self, proposal):
Returns
-------
extractor : PatchFeatureExtractor
Feature extractor configured with the cropped image, segmentation
mask, spatial offset, and patch shape.
Feature extractor configured with the cropped image, segment mask,
spatial offset, and patch shape.
"""
# Compute patch specs
center, shape = self.compute_crop(proposal)
offset = img_util.get_offset(center, shape)

# Read images
img = self.read_image(center, shape)
mask = self.read_segmentation(center, shape)
mask = self.create_segment_mask(proposal, img.shape, offset)

# Create patch feature extractor
extractor = PatchFeatureExtractor(
Expand All @@ -322,6 +312,25 @@ def init_extractor(self, proposal):
return extractor

# --- Helpers ---
def create_segment_mask(self, proposal, shape, offset):
# Find nearby nodes
center = self.graph.proposal_midpoint(proposal)
nodes = self.graph.kdtree.query_ball_point(center, self.padding + 10)

# Populate mask
mask = np.zeros(shape)
visited = set()
for i in nodes:
voxel_i = self.graph.node_local_voxel(i, offset)
for j in self.graph.neighbors(i):
if frozenset({i, j}) not in visited and j in nodes:
voxel_j = self.graph.node_local_voxel(j, offset)
voxels = geometry_util.make_digital_line(voxel_i, voxel_j)
img_util.annotate_voxels(mask, voxels, val=0.25)
visited.add(frozenset({i, j}))
return mask


def read_image(self, center, shape):
"""
Reads the image patch specified by the given center and shape.
Expand All @@ -337,23 +346,6 @@ def read_image(self, center, shape):
patch = np.minimum(patch, self.brightness_clip)
return img_util.normalize(patch)

def read_segmentation(self, center, shape):
"""
Reads the segmentation patch specified by the given center and shape.

Parameters
----------
center : Tuple[int]
Center of segmentation patch to be read.
shape : Tuple[int]
Center of segmentation patch to be read.
"""
if self.segmentation:
patch = self.segmentation.read(center, shape)
return 0.25 * (patch > 0).astype(float)
else:
return np.zeros(shape)

def compute_crop(self, proposal):
"""
Extracts an intensity profile along a set of voxel coordinates.
Expand All @@ -364,15 +356,14 @@ def compute_crop(self, proposal):
Image with shape (2, H, W, D) containing a raw image and proposal
mask channels.
"""
# Get info
node1, node2 = tuple(proposal)
voxel1 = self.graph.node_voxel(node1)
voxel2 = self.graph.node_voxel(node2)
# Node info
node1, node2 = proposal
voxel1 = np.array(self.graph.node_voxel(node1))
voxel2 = np.array(self.graph.node_voxel(node2))

# Compute bounds
bounds = img_util.get_minimal_bbox([voxel1, voxel2], self.padding)
center = tuple([int((v1 + v2) / 2) for v1, v2 in zip(voxel1, voxel2)])
length = np.max([u - l for u, l in zip(bounds["max"], bounds["min"])])
center = tuple(((voxel1 + voxel2) / 2).astype(int))
length = np.max(np.abs(voxel2 - voxel1)) + 2 * self.padding
return center, (length, length, length)


Expand Down Expand Up @@ -455,11 +446,6 @@ def get_intensity_profile(self):
profile = np.concatenate(
(branch1_profile, proposal_profile, branch2_profile)
)

# Adjust intensities
max_val = np.max(profile) + 1e-5
self.img = np.minimum(max_val, self.img) / (max_val + 1e-5)
profile /= (max_val + 1e-5)
return profile

def get_branch_profile(self, node):
Expand Down
11 changes: 2 additions & 9 deletions src/neuron_proofreader/split_proofreading/split_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
model,
config,
log_preamble="",
segmentation_path=None,
soma_centroids=list(),
):
"""
Expand All @@ -81,9 +80,6 @@ def __init__(
log_preamble : str, optional
String to be added to the beginning of log. Default is an empty
string.
segmentation_path : str, optional
Path to segmentation corresponding to the given fragments. Default
is None.
soma_centroids : List[Tuple[float]], optional
Physcial coordinates of soma centroids. Default is an empty list.
"""
Expand All @@ -102,9 +98,9 @@ def __init__(
self.log(log_preamble)

# Load data
self._load_data(fragments_path, img_path, segmentation_path)
self._load_data(fragments_path, img_path)

def _load_data(self, fragments_path, img_path, segmentation_path):
def _load_data(self, fragments_path, img_path):
"""
Builds a graph from the given fragments.

Expand All @@ -114,8 +110,6 @@ def _load_data(self, fragments_path, img_path, segmentation_path):
Path to SWC files to be loaded into graph.
img_path : str
Path to whole-brain image corresponding to the given fragments.
segmentation_path : str
Path to segmentation corresponding to the given fragments.
"""
# Load data
t0 = time()
Expand All @@ -124,7 +118,6 @@ def _load_data(self, fragments_path, img_path, segmentation_path):
fragments_path,
img_path,
self.config,
segmentation_path=segmentation_path,
soma_centroids=self.soma_centroids,
)
self.log(self.dataset.graph.summary(prefix="\nInitial"))
Expand Down
24 changes: 0 additions & 24 deletions src/neuron_proofreader/utils/img_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,30 +443,6 @@ def get_contained_voxels(voxels, shape, buffer=0):
return [v for v in voxels if is_contained(v, shape, buffer)]


def get_minimal_bbox(voxels, buffer=0):
"""
Gets the min and max coordinates of a bounding box that contains "voxels".

Parameters
----------
voxels : numpy.ndarray
Array containing voxel coordinates.
buffer : int, optional
Constant value added/subtracted from the max/min coordinates of the
bounding box. Default is 0.

Returns
-------
bbox : Dict[str, numpy.ndarray]
Bounding box.
"""
bbox = {
"min": np.floor(np.min(voxels, axis=0) - buffer).astype(int),
"max": np.ceil(np.max(voxels, axis=0) + buffer).astype(int),
}
return bbox


def get_neighbors(voxel, shape):
"""
Gets the neighbors of a given voxel coordinate.
Expand Down
Loading