Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/neuron_proofreader/machine_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ def run(self, train_dataloader, val_dataloader):
# Step scheduler
self.scheduler.step()

def train_step(self, train_dataloader, epoch):
def train_step(self, dataloader, epoch):
"""
Performs a single training epoch over the provided DataLoader.

Parameters
----------
train_dataloader : torch.utils.data.DataLoader
dataloader : torch.utils.data.DataLoader
DataLoader for the training dataset.
epoch : int
Current training epoch.
Expand All @@ -159,7 +159,7 @@ def train_step(self, train_dataloader, epoch):
"""
self.model.train()
loss, y, hat_y = list(), list(), list()
for x_i, y_i in train_dataloader:
for x_i, y_i in dataloader:
# Forward pass
self.optimizer.zero_grad()
hat_y_i, loss_i = self.forward_pass(x_i, y_i)
Expand All @@ -184,13 +184,13 @@ def train_step(self, train_dataloader, epoch):
self.update_tensorboard(stats, epoch, "train_")
return stats

def validate_step(self, val_dataloader, epoch):
def validate_step(self, dataloader, epoch):
"""
Performs a full validation loop over the given dataloader.

Parameters
----------
val_dataloader : torch.utils.data.DataLoader
dataloader : torch.utils.data.DataLoader
DataLoader for the validation dataset.
epoch : int
Current training epoch.
Expand All @@ -213,7 +213,7 @@ def validate_step(self, val_dataloader, epoch):
# Iterate over dataset
self.model.eval()
with torch.no_grad():
for x, y in val_dataloader:
for x, y in dataloader:
# Run model
hat_y, loss = self.forward_pass(x, y)

Expand Down
55 changes: 14 additions & 41 deletions src/neuron_proofreader/merge_proofreading/merge_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,15 @@ def get_random_negative_site(self):
label : int
Label of example.
"""
# Sample graph
brain_id = self.sample_brain_id()

# Sample node on graph
outcome = random.random()
cnt = 0
while True:
# Sample node
cnt += 1
if outcome < 0.4:
# Any node
node = util.sample_once(self.graphs[brain_id].nodes)
#elif outcome < 0.5:
# # Node close to soma
# node = self.sample_node_nearby_soma(brain_id)
elif outcome < 0.8:
# Branching node
branching_nodes = self.graphs[brain_id].branching_nodes()
Expand Down Expand Up @@ -460,6 +456,10 @@ def get_random_negative_site(self):
if not self.is_nearby_merge_site(brain_id, node):
return brain_id, subgraph, 0

# Check number of tries
if cnt > 20:
outcome = 1

def get_img_patch(self, brain_id, center):
"""
Extracts and normalizes a 3D image patch from the specified whole-
Expand Down Expand Up @@ -529,7 +529,7 @@ def __len__(self):
int
Number of positive examples of merge sites.
"""
return len(self.merge_sites_df)
return 2 * len(self.merge_sites_df)

def check_nearby_branching(
self, brain_id, root, max_depth=60, use_gt=False
Expand Down Expand Up @@ -618,37 +618,6 @@ def is_nearby_merge_site(self, brain_id, node):
dist, _ = self.merge_site_kdtrees[brain_id].query(xyz)
return dist < 100

def relabel_nodes(self):
"""
Reassigns contiguous node IDs and update all dependent structures.
"""
# Set node ids
old_node_ids = np.array(self.nodes, dtype=int)
new_node_ids = np.arange(len(old_node_ids))

# Set edge ids
old_to_new = dict(zip(old_node_ids, new_node_ids))
old_edge_ids = list(self.edges)
old_irr_edge_ids = self.irreducible.edges
edge_attrs = {(i, j): data for i, j, data in self.edges(data=True)}

# Reset graph
self.clear()
for (i, j) in old_edge_ids:
self.add_edge(old_to_new[i], old_to_new[j], **edge_attrs[(i, j)])

self.irreducible.clear()
for (i, j) in old_irr_edge_ids:
self.irreducible.add_edge(old_to_new[i], old_to_new[j])

# Update attributes
self.node_radius = self.node_radius[old_node_ids]
self.node_xyz = self.node_xyz[old_node_ids]
self.node_component_id = self.node_component_id[old_node_ids]

self.reassign_component_ids()
self.set_kdtree()

def sample_node_nearby_soma(self, brain_id):
subgraph = self.gt_graphs[brain_id].get_rooted_subgraph(0, 600)
gt_node = util.sample_once(subgraph.nodes)
Expand Down Expand Up @@ -730,7 +699,7 @@ def get_site(self, idx):
return self.get_indexed_positive_site(idx)
elif np.random.random() < self.random_negative_example_prob:
return self.get_random_negative_site()
elif abs(idx) < len(self):
elif abs(idx) < len(self.merge_sites_df):
return self.get_indexed_negative_site(abs(idx))
else:
return self.get_random_negative_site()
Expand All @@ -745,8 +714,9 @@ def get_idxs(self):
numpy.ndarray
Example indices to iterate over.
"""
n_negative_examples = int(len(self) * (1 + self.negative_bias))
return np.arange(-n_negative_examples + 1, len(self))
n_pos_examples = len(self.merge_sites_df)
n_negative_examples = int(n_pos_examples * (1 + self.negative_bias))
return np.arange(-n_negative_examples + 1, n_pos_examples)


class MergeSiteValDataset(MergeSiteDataset):
Expand Down Expand Up @@ -1134,3 +1104,6 @@ def _load_image_graph_batch(self, idxs):
}
)
return batch, ml_util.to_tensor(targets)

def __len__(self):
return 2 * len(self.dataset)
Loading