Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/neuron_proofreader/merge_proofreading/merge_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def load_fragment_graphs(
# Remove groundtruth skeletons
for swc_id in graph.swc_ids():
if swc_id.lower().startswith("n"):
component_id = util.find_key(
graph.component_id_to_swc_id, swc_id
)
component_id = graph.component_id_from_swc_id(swc_id)
nodes = graph.nodes_with_component_id(component_id)
graph.remove_nodes(nodes, relabel_nodes=False)

Expand Down
57 changes: 39 additions & 18 deletions src/neuron_proofreader/split_proofreading/split_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

"""

from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torch.utils.data import IterableDataset
from tqdm import tqdm

Expand Down Expand Up @@ -42,6 +44,7 @@ def __init__(
config,
gt_path=None,
metadata_path=None,
prefetch=4,
segmentation_path=None,
soma_centroids=set(),
):
Expand Down Expand Up @@ -70,6 +73,7 @@ def __init__(
# Instance attributes
self.config = config
self.gt_path = gt_path
self.prefetch = prefetch
self.transform = ImageTransforms() if config.ml.transform else False

# Build graph
Expand Down Expand Up @@ -174,6 +178,7 @@ def add_dataset(
config,
gt_path=None,
metadata_path=None,
prefetch=4,
segmentation_path=None,
soma_centroids=list(),
):
Expand All @@ -195,6 +200,8 @@ def add_dataset(
metadata_path : str, optional
Path to JSON file containing metadata on block that fragments
were extracted from. Default is None.
prefetch : int, optional
Number of batches to prefetch. Default is 4.
segmentation_path : str, optional
Path to the segmentation that fragments were generated from.
Default is None.
Expand All @@ -208,6 +215,7 @@ def add_dataset(
config,
gt_path=gt_path,
metadata_path=metadata_path,
prefetch=prefetch,
segmentation_path=segmentation_path,
soma_centroids=soma_centroids,
)
Expand All @@ -223,21 +231,39 @@ def __iter__(self):
targets : torch.Tensor
Ground truth labels.
"""
# Initializations
samplers = self.init_samplers()
while len(samplers) > 0:
key = self.get_next_key(samplers)
try:
# Extract features
subgraph = next(samplers[key])
queue = Queue(maxsize=self.prefetch * len(samplers))
active_keys = set(samplers.keys())

# Launch one prefetch thread per dataset
with ThreadPoolExecutor(max_workers=len(samplers)) as executor:
for key, sampler in samplers.items():
executor.submit(self._worker, key, sampler, queue)

# Consume from queue until all datasets exhausted
while active_keys:
key, inputs, targets = queue.get()
if inputs is StopIteration:
active_keys.discard(key)
continue
if isinstance(inputs, Exception):
raise inputs
yield inputs, targets

def _worker(self, key, sampler, queue):
"""
Runs in a background thread, prefetches extracted features into queue.
"""
try:
for subgraph in sampler:
features = self.datasets[key].feature_extractor(subgraph)
data = HeteroGraphData(features)

# Get training inputs
inputs = data.get_inputs()
targets = data.get_targets()
yield inputs, targets
except StopIteration:
del samplers[key]
queue.put((key, data.get_inputs(), data.get_targets()))
except Exception as e:
queue.put((key, e, None))
finally:
queue.put((key, StopIteration, None))

def generate_proposals(self, search_radius):
"""
Expand All @@ -248,16 +274,11 @@ def generate_proposals(self, search_radius):
search_radius : float
Search radius used to generate proposals.
"""
# Proposal generation
for key in tqdm(self.datasets, desc="Generate Proposals"):
self.datasets[key].graph.generate_proposals(
search_radius, allow_nonleaf_proposals=True
)

# Report results
print("# Proposals:", self.n_proposals())
print("% Accepts:", self.p_accepts())

# --- Helpers ---
def __len__(self):
"""
Expand Down Expand Up @@ -333,7 +354,7 @@ def p_accepts(self):
accepts_cnt = 0
for dataset in self.datasets.values():
accepts_cnt += len(dataset.graph.gt_accepts)
return accepts_cnt / (self.n_proposals() + 1e-5)
return 100 * accepts_cnt / (self.n_proposals() + 1e-5)

def save_examples_summary(self, path):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,23 +356,28 @@ def read_segmentation(self, center, shape):

def compute_crop(self, proposal):
"""
Extracts an intensity profile along a set of voxel coordinates.
Computes the center and cubic shape of the image patch for a proposal.

Parameters
----------
proposal : Frozenset[int]
Proposal to compute image crop of.

Returns
-------
profile : numpy.ndarray
Image with shape (2, H, W, D) containing a raw image and proposal
mask channels.
center : Tuple[int]
Center of the bounding box between the two proposal nodes.
shape : Tuple[int]
Cubic shape large enough to contain both nodes with padding.
"""
# Get info
node1, node2 = tuple(proposal)
voxel1 = self.graph.node_voxel(node1)
voxel2 = self.graph.node_voxel(node2)
# Node info
node1, node2 = proposal
voxel1 = np.array(self.graph.node_voxel(node1))
voxel2 = np.array(self.graph.node_voxel(node2))

# Compute bounds
bounds = img_util.get_minimal_bbox([voxel1, voxel2], self.padding)
center = tuple([int((v1 + v2) / 2) for v1, v2 in zip(voxel1, voxel2)])
length = np.max([u - l for u, l in zip(bounds["max"], bounds["min"])])
center = tuple(((voxel1 + voxel2) / 2).astype(int))
length = np.max(np.abs(voxel2 - voxel1)) + 2 * self.padding
return center, (length, length, length)


Expand Down Expand Up @@ -455,11 +460,6 @@ def get_intensity_profile(self):
profile = np.concatenate(
(branch1_profile, proposal_profile, branch2_profile)
)

# Adjust intensities
max_val = np.max(profile) + 1e-5
self.img = np.minimum(max_val, self.img) / (max_val + 1e-5)
profile /= (max_val + 1e-5)
return profile

def get_branch_profile(self, node):
Expand Down Expand Up @@ -496,8 +496,8 @@ def _extract_profile(self, voxels):
Image with shape (2, H, W, D) containing a raw image and proposal
mask channels.
"""
voxels = check_list_length(voxels, min_length=16)
profile = np.array([self.img[tuple(voxel)] for voxel in voxels])
voxels = np.asarray(check_list_length(voxels, min_length=16))
profile = self.img[voxels[:, 0], voxels[:, 1], voxels[:, 2]]
profile = np.append(profile, [profile.mean(), profile.std()])
return profile

Expand Down
24 changes: 0 additions & 24 deletions src/neuron_proofreader/utils/img_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,30 +443,6 @@ def get_contained_voxels(voxels, shape, buffer=0):
return [v for v in voxels if is_contained(v, shape, buffer)]


def get_minimal_bbox(voxels, buffer=0):
"""
Gets the min and max coordinates of a bounding box that contains "voxels".

Parameters
----------
voxels : numpy.ndarray
Array containing voxel coordinates.
buffer : int, optional
Constant value added/subtracted from the max/min coordinates of the
bounding box. Default is 0.

Returns
-------
bbox : Dict[str, numpy.ndarray]
Bounding box.
"""
bbox = {
"min": np.floor(np.min(voxels, axis=0) - buffer).astype(int),
"max": np.ceil(np.max(voxels, axis=0) + buffer).astype(int),
}
return bbox


def get_neighbors(voxel, shape):
"""
Gets the neighbors of a given voxel coordinate.
Expand Down
Loading