Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions src/neuron_proofreader/merge_proofreading/merge_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,37 @@ def get_detected_sites(self, threshold):
nodes = np.where(self.node_preds >= threshold)[0]
return [self.dataset.graph.node_xyz[i] for i in nodes]

def save_parameters(self, output_dir):
json_path = os.path.join(output_dir, "detection_parameters.json")
parameters = {
"accept_threshold": self.threshold,
"is_multimodal": self.dataset.is_multimodal,
"min_search_size": self.dataset.min_size,
"patch_shape": self.patch_shape,
"remove_detected_sites": self.remove_detected_sites,
"search_mode": self.dataset.search_mode,
"subgraph_radius": self.dataset.subgraph_radius,
}
util.write_json(json_path, parameters)

def save_results(
self, output_dir, output_prefix_s3=None, save_fragments=True
):
self.save_sites(output_dir)
if save_fragments:
fragments_path = os.path.join(output_dir, "fragments.zip")
self.dataset.graph.to_zipped_swcs(fragments_path)

# Upload results to S3 (if applicable)
if output_prefix_s3:
bucket_name, prefix = util.parse_cloud_path(output_prefix_s3)
util.upload_dir_to_s3(output_dir, bucket_name, prefix)

def save_sites(self, output_dir):
# Get predicted merge sites
nodes = np.where(self.node_preds >= self.threshold)[0]
detected_sites = [self.dataset.graph.node_xyz[i] for i in nodes]
print("# Sites Saved:", len(nodes))

# Save predicted merge sites
zip_path = os.path.join(output_dir, "detected_sites.zip")
Expand All @@ -191,28 +216,21 @@ def save_results(
radius=10,
)

# Save fragments
if save_fragments:
fragments_path = os.path.join(output_dir, "fragments.zip")
self.dataset.graph.to_zipped_swcs(fragments_path)

# Upload results to S3 (if applicable)
if output_prefix_s3:
bucket_name, prefix = util.parse_cloud_path(output_prefix_s3)
util.upload_dir_to_s3(output_dir, bucket_name, prefix)
def save_train_dataset(self, output_dir):
# Extract fragments to save
roots = list()
visited_ids = set()
for i in np.where(self.node_preds >= self.threshold)[0]:
cc_id = self.dataset.graph.node_component_ids[i]
if cc_id not in visited_ids:
roots.append(i)
visited_ids.add(cc_id)

def save_parameters(self, output_dir):
json_path = os.path.join(output_dir, "detection_parameters.json")
parameters = {
"accept_threshold": self.threshold,
"is_multimodal": self.dataset.is_multimodal,
"min_search_size": self.dataset.min_size,
"patch_shape": self.patch_shape,
"remove_detected_sites": self.remove_detected_sites,
"search_mode": self.dataset.search_mode,
"subgraph_radius": self.dataset.subgraph_radius,
}
util.write_json(json_path, parameters)
# Save fragments
zip_path = os.path.join(output_dir, "fragments.zip")
self.dataset.graph._batch_to_zipped_swcs(roots, zip_path, False)
self.save_sites(output_dir)
print("# Fragments Saved:", len(roots))


# --- Data Handling ---
Expand Down Expand Up @@ -297,7 +315,7 @@ def find_fragments_to_search(self):
for nodes in nx.connected_components(self.graph):
# Compute path length
node = util.sample_once(list(nodes))
length = self.graph.path_length(root=node, max_depth=self.min_size)
length = self.graph.cable_length(max_depth=self.min_size, root=node)

# Check if path length satisfies threshold
if length > self.min_size:
Expand Down
Loading