Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/neuron_proofreader/machine_learning/exaspim_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def __init__(
self.skeletons = dict()

# --- Ingest Data ---
def ingest_brain(self, brain_id, img_path, segmentation_path, swc_path):
def ingest_brain(
self, brain_id, img_path, segmentation_path=None, swc_path=None
):
"""
Loads a brain image, label mask, and skeletons, then stores each in
internal dictionaries.
Expand All @@ -99,21 +101,25 @@ def ingest_brain(self, brain_id, img_path, segmentation_path, swc_path):
Unique identifier for the brain corresponding to the image.
img_path : str
Path to whole-brain image to be read.
segmentation_path : str
Path to segmentation.
swc_path : str
Path to SWC files.
segmentation_path : str, optional
Path to segmentation. Default is None.
swc_path : str, optional
Path to SWC files. Default is None.
"""
# Load data
self.imgs[brain_id] = TensorStoreReader(img_path)
self.segmentations[brain_id] = TensorStoreReader(segmentation_path)
self._load_segmentation(brain_id, segmentation_path)
self._load_swcs(brain_id, swc_path)

# Check image shapes
shape1 = self.imgs[brain_id].shape()[2::]
shape2 = self.segmentations[brain_id].shape()
assert shape1 == shape2, f"img_shape={shape1}, mask_shape={shape2}"

def _load_segmentation(self, brain_id, path):
if path:
self.segmentations[brain_id] = TensorStoreReader(path)

def _load_swcs(self, brain_id, swc_path):
if swc_path:
# Initializations
Expand Down
6 changes: 1 addition & 5 deletions src/neuron_proofreader/split_proofreading/split_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
self,
fragments_path,
img_path,
model_path,
output_dir,
model,
config,
Expand All @@ -73,8 +72,6 @@ def __init__(
Path to SWC files to be loaded into graph.
img_path : str
Path to whole-brain image corresponding to the given fragments.
model_path : str
Path to checkpoint file containing model weights.
output_dir : str
Directory where the results of the inference will be saved.
config : Config
Expand All @@ -93,7 +90,7 @@ def __init__(
self.accepted_proposals = list()
self.config = config
self.img_path = img_path
self.model = model
self.model = model.to(config.ml.device)
self.output_dir = output_dir
self.soma_centroids = soma_centroids

Expand All @@ -105,7 +102,6 @@ def __init__(

# Load data
self._load_data(fragments_path, img_path, segmentation_path)
ml_util.load_model(self.model, model_path, device=config.ml.device)

def _load_data(self, fragments_path, img_path, segmentation_path):
"""
Expand Down
9 changes: 1 addition & 8 deletions src/neuron_proofreader/utils/ml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,7 @@ def load_model(model, model_path, device="cuda"):
device : str, optional
Device to load the model onto. Default is "cuda".
"""
state_dict = torch.load(model_path, map_location=device)
fixed_state_dict = {}
for k, v in state_dict.items():
if k.startswith("output.") and not k.startswith("output.net."):
k = k.replace("output.", "output.net.", 1)
fixed_state_dict[k] = v

model.load_state_dict(fixed_state_dict, strict=False)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

Expand Down
Loading