Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 49 additions & 33 deletions src/neuron_proofreader/merge_proofreading/merge_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import networkx as nx
import numpy as np
import os
import pandas as pd
import torch

from neuron_proofreader.machine_learning.point_cloud_models import (
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
# Instance attributes
self.dataset = dataset
self.device = device
self.node_preds = np.ones((len(dataset.graph.node_xyz))) * 1e-2
self.node_preds = np.ones((len(dataset.node_xyz))) * 1e-2
self.patch_shape = dataset.patch_shape
self.remove_detected_sites = remove_detected_sites
self.threshold = threshold
Expand Down Expand Up @@ -111,7 +112,7 @@ def filter_with_nms(self, merge_sites, likelihoods):
while merge_sites:
# Local max
root = merge_sites.pop()
xyz_root = self.dataset.graph.node_xyz[root]
xyz_root = self.dataset.node_xyz[root]
if root in merge_sites_set:
filtered_merge_sites.add(root)
merge_sites_set.remove(root)
Expand All @@ -125,7 +126,7 @@ def filter_with_nms(self, merge_sites, likelihoods):
# Visit node
i, dist_i = queue.pop()
if i in merge_sites_set:
xyz_i = self.dataset.graph.node_xyz[i]
xyz_i = self.dataset.node_xyz[i]
iou = img_util.compute_iou3d(
xyz_i, xyz_root, self.patch_shape, self.patch_shape
)
Expand All @@ -134,8 +135,8 @@ def filter_with_nms(self, merge_sites, likelihoods):
self.node_preds[i] = 1e-2

# Populate queue
for j in self.dataset.graph.neighbors(i):
dist_j = dist_i + self.dataset.graph.dist(i, j)
for j in self.dataset.neighbors(i):
dist_j = dist_i + self.dataset.dist(i, j)
if j not in visited and dist_j < self.patch_shape[0]:
queue.append((j, dist_j))
visited.add(j)
Expand All @@ -145,26 +146,26 @@ def remove_merge_sites(self, merge_site_nodes, max_depth=10):
rm_nodes = set()
for root in tqdm(merge_site_nodes, desc="Remove Merge Sites"):
# Extract neighborhood
root = self.dataset.graph.find_nearby_branching_node(root)
nbhd = self.dataset.graph.nodes_within_distance(root, max_depth)
root = self.dataset.find_nearby_branching_node(root)
nbhd = self.dataset.nodes_within_distance(root, max_depth)

# Check for branching node in neighborhood
for i in list(nbhd):
if i != root and self.dataset.graph.degree[i] >= 3:
nbhd_i = self.dataset.graph.nodes_within_distance(root, 8)
if i != root and self.dataset.degree[i] >= 3:
nbhd_i = self.dataset.nodes_within_distance(root, 8)
nbhd.extend(nbhd_i)

# Add nodes to removal list
rm_nodes.update(set(nbhd))

# Update graph
self.dataset.graph.remove_nodes(rm_nodes)
self.dataset.remove_nodes(rm_nodes)
print("# Nodes Deleted:", len(rm_nodes))

# --- Helpers ---
def get_detected_sites(self, threshold):
nodes = np.where(self.node_preds >= threshold)[0]
return [self.dataset.graph.node_xyz[i] for i in nodes]
return [self.dataset.node_xyz[i] for i in nodes]

def save_parameters(self, output_dir):
json_path = os.path.join(output_dir, "detection_parameters.json")
Expand All @@ -185,17 +186,26 @@ def save_results(
self.save_sites(output_dir)
if save_fragments:
fragments_path = os.path.join(output_dir, "fragments.zip")
self.dataset.graph.to_zipped_swcs(fragments_path)
self.dataset.to_zipped_swcs(fragments_path)

# Upload results to S3 (if applicable)
if output_prefix_s3:
bucket_name, prefix = util.parse_cloud_path(output_prefix_s3)
util.upload_dir_to_s3(output_dir, bucket_name, prefix)

def save_sites(self, output_dir):
# Save model predictions
df = pd.DataFrame(columns=["World", "Segment_ID", "Prediction"])
df["World"] = self.dataset.node_xyz
df["Prediction"] = self.node_preds
df["Segment_ID"] = [
self.dataset.node_segment_id(i) for i in self.dataset.nodes
]
df.to_csv(os.path.join(output_dir, "model_predictions.csv"))

# Get predicted merge sites
nodes = np.where(self.node_preds >= self.threshold)[0]
detected_sites = [self.dataset.graph.node_xyz[i] for i in nodes]
detected_sites = [self.dataset.node_xyz[i] for i in nodes]
print("# Sites Saved:", len(nodes))

# Save predicted merge sites
Expand All @@ -213,14 +223,14 @@ def save_train_dataset(self, output_dir):
roots = list()
visited_ids = set()
for i in np.where(self.node_preds >= self.threshold)[0]:
cc_id = self.dataset.graph.node_component_id[i]
cc_id = self.dataset.node_component_id[i]
if cc_id not in visited_ids:
roots.append([i])
visited_ids.add(cc_id)

# Save fragments
zip_path = os.path.join(output_dir, "fragments.zip")
self.dataset.graph._batch_to_zipped_swcs(roots, zip_path, False)
self.dataset._batch_to_zipped_swcs(roots, zip_path, False)
self.save_sites(output_dir)
print("# Fragments Saved:", len(roots))

Expand Down Expand Up @@ -279,7 +289,7 @@ def __iter__(self):
# Search graph
visited_ids = set()
for u in self.graph.leaf_nodes():
component_id = self.graph.node_component_id[u]
component_id = self.node_component_id[u]
if component_id not in visited_ids and component_id in valid_ids:
visited_ids.add(component_id)
yield from self._generate_batches_from_component(u)
Expand Down Expand Up @@ -311,11 +321,11 @@ def find_fragments_to_search(self):

# Check if path length satisfies threshold
if length > self.min_size:
component_ids.add(self.graph.node_component_id[node])
component_ids.add(self.node_component_id[node])
return component_ids

def get_patch_centers(self, nodes):
patch_centers = [self.graph.node_voxel(i) for i in nodes]
patch_centers = [self.node_voxel(i) for i in nodes]
return np.array(patch_centers, dtype=int)

def get_label_mask(self, nodes, img_shape, offset):
Expand All @@ -331,8 +341,8 @@ def get_label_mask(self, nodes, img_shape, offset):
# Annotate mask
subgraph = self.get_contained_subgraph(nodes, img_shape, offset)
for i, j in subgraph.edges:
voxel_i = self.graph.node_voxel(i) - offset
voxel_j = self.graph.node_voxel(j) - offset
voxel_i = self.node_voxel(i) - offset
voxel_j = self.node_voxel(j) - offset
voxels = geometry_util.make_digital_line(voxel_i, voxel_j)
img_util.annotate_voxels(segment_mask, voxels)
return segment_mask
Expand All @@ -344,13 +354,13 @@ def get_contained_subgraph(self, nodes, img_shape, offset):
while queue:
# Visit node
i = queue.pop()
voxel_i = self.graph.node_voxel(i) - offset
voxel_i = self.node_voxel(i) - offset
if not img_util.is_contained(voxel_i, img_shape, buffer=1):
continue

# Update queue
for j in self.graph.neighbors(i):
voxel_j = self.graph.node_voxel(j) - offset
for j in self.neighbors(i):
voxel_j = self.node_voxel(j) - offset
if img_util.is_contained(voxel_j, img_shape):
subgraph.add_edge(i, j)
if j not in visited:
Expand All @@ -359,7 +369,7 @@ def get_contained_subgraph(self, nodes, img_shape, offset):
return subgraph

def is_contained(self, node):
voxel = self.graph.node_voxel(node)
voxel = self.node_voxel(node)
shape = self.img_reader.shape()[2::]
buffer = np.max(self.patch_shape) + 1
return img_util.is_contained(voxel, shape, buffer=buffer)
Expand All @@ -380,7 +390,7 @@ def read_superchunk(self, nodes):

def is_near_leaf(self, node, threshold=20):
# Check if node is branching
if self.graph.degree[node] > 2:
if self.degree[node] > 2:
return False

# Search neighborhood
Expand All @@ -389,12 +399,12 @@ def is_near_leaf(self, node, threshold=20):
while len(queue) > 0:
# Visit node
i, dist_i = queue.pop()
if self.graph.degree[i] == 1:
if self.degree[i] == 1:
return True

# Update queue
for j in self.graph.neighbors(i):
dist_j = dist_i + self.graph.dist(i, j)
for j in self.neighbors(i):
dist_j = dist_i + self.dist(i, j)
if j not in visited and dist_j < threshold:
queue.append((j, dist_j))
visited.add(j)
Expand Down Expand Up @@ -489,7 +499,7 @@ def _generate_batch_nodes(self, root):
nodes = list()
for i, j in nx.dfs_edges(self.graph, source=root):
# Check if starting new batch
self.distance_traversed += self.graph.dist(i, j)
self.distance_traversed += self.dist(i, j)
if len(nodes) == 0:
if self.is_node_valid(i):
root = i
Expand All @@ -499,7 +509,7 @@ def _generate_batch_nodes(self, root):
continue

# Check whether to yield batch
is_node_far = self.graph.dist(root, j) > 512
is_node_far = self.dist(root, j) > 512
is_batch_full = len(nodes) == self.batch_size
if is_node_far or is_batch_full:
# Yield nodes in batch
Expand All @@ -509,8 +519,8 @@ def _generate_batch_nodes(self, root):
nodes = list()

# Visit j
is_next = self.graph.dist(last_node, j) >= self.step_size - 2
is_branching = self.graph.degree[j] >= 3
is_next = self.dist(last_node, j) >= self.step_size - 2
is_branching = self.degree[j] >= 3
if (is_next or is_branching) and self.is_node_valid(j):
last_node = j
nodes.append(j)
Expand Down Expand Up @@ -560,6 +570,9 @@ def _get_multimodal_batch(self, nodes, img, offset):
return nodes, batch

# --- Helpers ---
def __getattr__(self, name):
return getattr(self.graph, name)

def estimate_iterations(self):
"""
Estimates the number of iterations required to search graph.
Expand All @@ -573,7 +586,7 @@ def estimate_iterations(self):
total_cable_length = 0
n_fragments = 0
for nodes in map(list, nx.connected_components(self.graph)):
cable_length = self.graph.cable_length(root=nodes[0])
cable_length = self.cable_length(root=nodes[0])
if cable_length > self.min_size:
total_cable_length += cable_length
n_fragments += 1
Expand Down Expand Up @@ -651,6 +664,9 @@ def _generate_batch_nodes(self, root):
root = j

# --- Helpers ---
def __getattr__(self, name):
return getattr(self.graph, name)

def estimate_iterations(self):
"""
Estimates the number of iterations required to search graph.
Expand Down
Loading