Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions Functional_Fusion/atlas_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,46 @@ def parcel_recombine(label_vector,parcels_selected,label_id=None,label_name=None
raise ValueError('parcels_selected must be a list')
return label_vector_new, label_id_new, label_name_new


def parcel_combine(img, output_filename=None):
"""
Combines multiple ROI mask NIfTI files into a single NIfTI file where each ROI has a unique integer label.

Parameters:
- roi_files (list of str or Nifti1Image): List of paths to NIfTI mask files or list of NIfTI mask files.
- output_filename (str): Path to save the combined NIfTI file.

Returns:
- Saves a NIfTI file where each ROI has a unique label.
"""
# Load the first image to get shape and affine transformation
if isinstance(img[0], str):
reference_img = nb.load(img[0])
if isinstance(img[0], nb.Nifti1Image):
reference_img = img[0]
combined_data = np.zeros(reference_img.shape, dtype=np.int16)

# Assign unique labels to each ROI
for i, mask in enumerate(img, start=1):
if isinstance(mask, str):
roi_img = nb.load(mask)
if isinstance(mask, nb.Nifti1Image):
roi_img = mask
roi_data = roi_img.get_fdata()

# Ensure binary mask (in case input masks have non-binary values)
roi_mask = roi_data > 0

# Assign a unique label to this ROI
combined_data[roi_mask] = i

# Save the combined ROI mask as a new NIfTI file
combined_img = nb.Nifti1Image(combined_data, reference_img.affine, reference_img.header)
nb.save(combined_img, output_filename)

return combined_img


class Atlas:
def __init__(self, name, structure='unknown', space='unknown'):
""" The Atlas class implements the mapping from the P brain locations back to the defining
Expand Down Expand Up @@ -1090,12 +1130,15 @@ def build(self, depths=[0, 0.2, 0.4, 0.6, 0.8, 1.0]):
indices[i, :, :] = (1 - depths[i]) * c1 + depths[i] * c2

self.vox_list, good = nt.coords_to_linvidxs(indices, self.mask_img, mask=True)
all = good.sum(axis=0)
# all = good.sum(axis=0)
_, invx, count = np.unique(self.vox_list, return_inverse=True, return_counts=True)
# print(f'{self.name} has {np.sum(all==0)} vertices without data')
all[all == 0] = 1
self.vox_weight = good / all
# all[all == 0] = 1
self.vox_weight = count[invx]# good / all
self.vox_list = self.vox_list.T
self.vox_weight = self.vox_weight.T
self.vox_weight = self.vox_weight.T

pass

def get_data_nifti(fnames, atlas_maps):
"""Extracts the data for a list of fnames
Expand Down Expand Up @@ -1201,9 +1244,27 @@ def exclude_overlapping_voxels(amap, exclude='all', exclude_thres=0.9):
vox_j, weight_j = amap[j].vox_list, amap[j].vox_weight
vox_k, weight_k = amap[k].vox_list, amap[k].vox_weight

EQ = vox_j.flatten()[:, np.newaxis] == vox_k.flatten()[np.newaxis, :]
# EQ = vox_j.flatten()[:, np.newaxis] == vox_k.flatten()[np.newaxis, :]
#
# idx_j, idx_k = np.where(EQ)
vox_j = vox_j.flatten()
vox_k = vox_k.flatten()

# Sort vox_k and keep track of the original indices
sort_idx = np.argsort(vox_k)
vox_k_sorted = vox_k[sort_idx]

# Check which elements in vox_j exist in vox_k
mask = np.isin(vox_j, vox_k_sorted)

# Find the corresponding indices in vox_k
idx_j = np.where(mask)[0] # Indices in vox_j
idx_k = np.searchsorted(vox_k_sorted, vox_j[mask]) # Indices in sorted vox_k

idx_j, idx_k = np.where(EQ)
# Convert back to original vox_k indices
idx_k = sort_idx[idx_k]

print(f'found {len(idx_j)} overlapping voxels')

for idx_j_v, idx_k_v in zip(idx_j, idx_k):
wj, wk = weight_j.flatten()[idx_j_v], weight_k.flatten()[idx_k_v]
Expand Down
4 changes: 2 additions & 2 deletions Functional_Fusion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def agg_data(info, by, over, subset=None):
return data_info, C


def agg_parcels(data, label_vec, fcn=np.nanmean):
def agg_parcels(data, label_vec, fcn=np.nanmean, **kwargs):
""" Aggregates data over colums to condense to parcels

Args:
Expand All @@ -248,7 +248,7 @@ def agg_parcels(data, label_vec, fcn=np.nanmean):
parcel_data = np.zeros(psize)
for i, l in enumerate(labels):
parcel_data[..., i] = fcn(
data[..., label_vec == l], axis=len(psize) - 1)
data[..., label_vec == l], axis=len(psize) - 1, **kwargs)
return parcel_data, labels

def combine_parcel_labels(labels_org,labels_new, labelvec_org=None):
Expand Down