Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/neuron_proofreader/proposal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def generate_proposals(

def is_mergeable(self, i, j):
    """
    Checks whether the nodes "i" and "j" are allowed to be merged.

    A pair is mergeable when at least one node is a leaf (degree 1),
    neither node is a branching point (degree of 3 or more), and the two
    nodes are not both somas.

    Parameters
    ----------
    i : hashable
        Node ID used to index "self.degree".
    j : hashable
        Node ID used to index "self.degree".

    Returns
    -------
    bool
        True if the pair satisfies all three conditions, False otherwise.
    """
    deg_i, deg_j = self.degree[i], self.degree[j]
    has_leaf = deg_i == 1 or deg_j == 1
    not_branching = deg_i < 3 and deg_j < 3
    both_somas = self.is_soma(i) and self.is_soma(j)
    return not both_somas and has_leaf and not_branching

def is_single_proposal(self, proposal):
"""
Expand Down
3 changes: 3 additions & 0 deletions src/neuron_proofreader/split_proofreading/split_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def __iter__(self):
yield HeteroGraphData(features)

# --- Helpers ---
def __getattr__(self, name):
    """
    Forwards lookups of attributes that are missing on this object to
    "self.graph".

    Python only calls this method after normal attribute lookup fails, so
    attributes defined on this object always take precedence. Only reads
    are forwarded; assignments still bind on this object.

    Parameters
    ----------
    name : str
        Name of the attribute that was not found on this object.

    Returns
    -------
    Any
        Value of the attribute "name" on "self.graph".

    Raises
    ------
    AttributeError
        If "graph" itself is not set on this object, or if "self.graph"
        does not have the attribute "name".
    """
    # If "graph" is not in the instance dict yet (e.g. before it is
    # assigned, or while the object is being unpickled or copied), then
    # evaluating "self.graph" below would re-enter this method without end.
    if name == "graph":
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )
    return getattr(self.graph, name)

def get_sampler(self):
"""
Gets a subgraph sampler that is used to iterate over dataset.
Expand Down
110 changes: 59 additions & 51 deletions src/neuron_proofreader/split_proofreading/split_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,18 @@ def _load_data(self, fragments_path, img_path, segmentation_path):
segmentation_path=segmentation_path,
soma_centroids=self.soma_centroids,
)
self.log(self.dataset.graph.summary(prefix="\nInitial"))
self.log(self.dataset.summary(prefix="\nInitial"))
self.save_fragment_ids()
self.save_graph("original_swcs")

# Postprocess fragments with somas
self.log(self.dataset.graph.remove_soma_merges())
self.log(self.dataset.graph.connect_soma_fragments())
self.log(self.dataset.remove_soma_merges())
self.log(self.dataset.connect_soma_fragments())

# Break high risk merges (if applicable)
if self.config.graph.remove_high_risk_merges:
self.log(self.dataset.graph.remove_high_risk_merges())
self.log(self.dataset.graph.summary(prefix="\nPre-Corrected"))
self.log(self.dataset.remove_high_risk_merges())
self.log(self.dataset.summary(prefix="\nPre-Corrected"))
self.save_graph("precorrected_swcs")

# Report runtime
Expand All @@ -165,48 +165,56 @@ def __call__(self, search_radius):

# Report results
t, unit = util.time_writer(time() - t0)
self.log(self.dataset.graph.summary(prefix="\nFinal"))
self.log(self.dataset.summary(prefix="\nFinal"))
self.log(f"Total Runtime: {t:.2f} {unit}\n")
self.save_results()

def multistep(self, search_radius, low_threshold=0.3, dt=0.1):
    """
    Runs inference in several rounds with a decreasing acceptance
    threshold.

    Two passes are made: the first only merges leaf-to-leaf proposals and
    the second considers every remaining proposal. Each pass starts at a
    threshold of 0.99 and lowers it by "dt" per round until it reaches
    "self.config.ml.threshold" or no proposals remain. After every round,
    proposals are filtered with "low_threshold".

    Parameters
    ----------
    search_radius
        Search radius forwarded to "generate_proposals".
    low_threshold : float, optional
        Threshold forwarded to "filter_proposals" after each round.
        Default is 0.3.
    dt : float, optional
        Amount that the acceptance threshold is lowered after each round.
        Default is 0.1.
    """
    t0 = time()
    self.generate_proposals(search_radius)
    total_proposals = self.dataset.n_proposals()
    for only_leaf2leaf in [True, False]:
        cnt = 0
        name = "_leaf2leaf" if only_leaf2leaf else ""
        new_threshold = 0.99
        while self.dataset.proposals:
            # Generate predictions
            cnt += 1
            print(
                f"Threshold={new_threshold} "
                f"w/ only_leaf2leaf={only_leaf2leaf}"
            )
            preds = self.predict_proposals(suffix=f"{name}_round={cnt}")

            # Merge accepted proposals
            cur_threshold = new_threshold
            self.merge_with_threshold_schedule(
                preds, cur_threshold, only_leaf2leaf=only_leaf2leaf
            )
            self.filter_proposals(preds, low_threshold)

            # Update threshold; stop once the configured floor is reached
            new_threshold = max(cur_threshold - dt, self.config.ml.threshold)
            if cur_threshold == new_threshold:
                break

    # Report results
    t, unit = util.time_writer(time() - t0)
    n_accepts = len(self.dataset.accepts)
    # Avoid dividing by zero when no proposals were generated
    p_accepts = n_accepts / total_proposals if total_proposals else 0.0
    self.log(self.dataset.summary(prefix="\nFinal"))
    self.log(f"Overall Acceptance Rate: {p_accepts:.2f}")
    self.log(f"Total Runtime: {t:.2f} {unit}\n")
    self.save_results()

# --- Core Routines ---
def filter_proposals(self, preds, threshold):
    """
    Removes proposals whose prediction is below the given threshold or
    whose nodes are no longer mergeable, then logs a short report.

    Parameters
    ----------
    preds : dict
        Dictionary that maps a proposal (pair of node IDs) to its
        prediction.
    threshold : float
        Proposals with a prediction strictly below this value are removed.
    """
    cnt = 0
    for proposal, pred in preds.items():
        # Earlier merges can change node degrees, so validity is rechecked
        is_valid = self.dataset.is_mergeable(*proposal)
        if pred < threshold or not is_valid:
            self.dataset.remove_proposal(proposal)
            cnt += 1

    self.log("\nFilter Proposals")
    self.log(f"# Proposals Removed: {cnt}")
    self.log(f"# Proposals Remaining: {self.dataset.n_proposals()}\n")

def generate_proposals(self, search_radius):
"""
Expand All @@ -222,13 +230,13 @@ def generate_proposals(self, search_radius):
t0 = time()
self.log("\nStep 2: Generate Proposals")
self.log(f"Search Radius: {search_radius}")
self.dataset.graph.generate_proposals(
self.dataset.generate_proposals(
search_radius,
allow_nonleaf_proposals=self.config.graph.allow_nonleaf_proposals,
)

n_proposals = format(self.dataset.graph.n_proposals(), ",")
n_proposals_blocked = self.dataset.graph.n_proposals_blocked
n_proposals = format(self.dataset.n_proposals(), ",")
n_proposals_blocked = self.dataset.n_proposals_blocked

# Report results
t, unit = util.time_writer(time() - t0)
Expand Down Expand Up @@ -259,7 +267,7 @@ def merge_with_threshold_schedule(
# Initializations
t0 = time()
self.log("\nStep 3: Run Inference")
n_proposals = self.dataset.graph.n_proposals()
n_proposals = self.dataset.n_proposals()
n_accepts = 0

# Progressive merging
Expand All @@ -278,7 +286,7 @@ def merge_with_threshold_schedule(

# Report results
t, unit = util.time_writer(time() - t0)
self.log(f"# Merges Blocked: {self.dataset.graph.n_merges_blocked}")
self.log(f"# Merges Blocked: {self.dataset.n_merges_blocked}")
self.log(f"# Accepted: {format(n_accepts, ',')}")
self.log(f"% Accepted: {100 * n_accepts / n_proposals:.2f}")
self.log(f"Module Runtime: {t:.2f} {unit}\n")
Expand All @@ -295,7 +303,7 @@ def predict_proposals(self, suffix=""):
"""
# Main
preds = dict()
pbar = tqdm(total=self.dataset.graph.n_proposals(), desc="Inference")
pbar = tqdm(total=self.dataset.n_proposals(), desc="Inference")
for data in self.dataset:
preds.update(self.predict(data))
pbar.update(data.n_proposals())
Expand All @@ -321,21 +329,21 @@ def merge_proposals(self, preds, threshold, only_leaf2leaf=False):
is False.
"""
n_accepts = 0
proposals = self.dataset.graph.sorted_proposals()
proposals = self.dataset.sorted_proposals()
for proposal in [p for p in proposals if p in preds]:
# Check for leaf2leaf condition
is_leaf2leaf = self.dataset.graph.is_leaf2leaf(proposal)
is_leaf2leaf = self.dataset.is_leaf2leaf(proposal)
if only_leaf2leaf and not is_leaf2leaf:
continue

# Check if proposal satisfies threshold
i, j = proposal
if preds[proposal] < threshold:
continue

# Check if proposal creates a loop
i, j = proposal
if not nx.has_path(self.dataset.graph, i, j):
self.dataset.graph.merge_proposal(proposal)
self.dataset.merge_proposal(proposal)
n_accepts += 1
del preds[proposal]
return n_accepts
Expand Down Expand Up @@ -399,7 +407,7 @@ def save_graph(self, dirname):
util.mkdir(os.path.join(self.output_dir, dirname))

# Save swcs
self.dataset.graph.to_zipped_swcs_multithreaded(temp_dir)
self.dataset.to_zipped_swcs_multithreaded(temp_dir)
zip_paths = util.list_paths(temp_dir, extension=".zip")
util.combine_zips(zip_paths, output_zip_path)
util.rmdir(temp_dir)
Expand All @@ -409,22 +417,22 @@ def save_proposal_results(self, preds_dict, suffix=""):
for proposal, pred in preds_dict.items():
# Extract info
i, j = proposal
segment_i = self.dataset.graph.node_swc_id(i)
segment_j = self.dataset.graph.node_swc_id(j)
segment_i = self.dataset.node_swc_id(i)
segment_j = self.dataset.node_swc_id(j)

# Add info
summary.append(
{
"Proposal": (segment_i, segment_j),
"Leaf2Leaf": self.dataset.graph.is_leaf2leaf(proposal),
"Length": self.dataset.graph.proposal_length(proposal),
"Leaf2Leaf": self.dataset.is_leaf2leaf(proposal),
"Length": self.dataset.proposal_length(proposal),
"Prediction": pred,
"Segment1": segment_i,
"Segment2": segment_j,
"Voxel1": self.dataset.graph.node_voxel(i),
"Voxel2": self.dataset.graph.node_voxel(j),
"World1": self.dataset.graph.node_xyz[i],
"World2": self.dataset.graph.node_xyz[j],
"Voxel1": self.dataset.node_voxel(i),
"Voxel2": self.dataset.node_voxel(j),
"World1": self.dataset.node_xyz[i],
"World2": self.dataset.node_xyz[j],
}
)

Expand All @@ -433,11 +441,11 @@ def save_proposal_results(self, preds_dict, suffix=""):
pd.DataFrame(summary).set_index("Proposal").to_csv(path)

def reconfigure_node_radius(self):
    """
    Resets every node radius in the graph to 1 and then sets the radius of
    both nodes of each accepted proposal to 6.

    The new array is assigned on "self.dataset.graph" directly. Assigning
    to "self.dataset.node_radius" would bind a new attribute on the
    dataset object, because the dataset only forwards attribute reads to
    its graph, and the graph's radii would be left unchanged.
    """
    graph = self.dataset.graph
    n_nodes = len(graph.node_radius)
    node_radius = np.ones((n_nodes), dtype=np.float16)
    for i, j in graph.accepts:
        node_radius[i] = 6
        node_radius[j] = 6
    graph.node_radius = node_radius

def save_connections(self):
"""
Expand All @@ -446,10 +454,10 @@ def save_connections(self):
"""
path = os.path.join(self.output_dir, "connections.txt")
with open(path, "w") as f:
for id1, id2 in self.dataset.graph.merged_ids:
for id1, id2 in self.dataset.merged_ids:
f.write(f"{id1}, {id2}" + "\n")

def save_fragment_ids(self):
    """
    Writes the values of "component_id_to_swc_id" to the file
    "segment_ids.txt" in the output directory.
    """
    path = f"{self.output_dir}/segment_ids.txt"
    segment_ids = list(self.dataset.component_id_to_swc_id.values())
    util.write_list(path, segment_ids)
Loading