Skip to content
2 changes: 1 addition & 1 deletion afqinsight/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def from_study(study, verbose=None):
dataset_kwargs = {
"sarica": {
"dwi_metrics": ["md", "fa"],
"target_cols": ["class"],
"target_cols": ["class", "age"],
"label_encode_cols": ["class"],
},
"weston-havens": {"dwi_metrics": ["md", "fa"], "target_cols": ["Age"]},
Expand Down
143 changes: 97 additions & 46 deletions afqinsight/parametric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Perform linear modeling at leach node along the tract."""
"""Perform linear modeling at each node along the tract."""

import numpy as np
import pandas as pd
Expand All @@ -11,11 +11,11 @@
def node_wise_regression(
afq_dataset,
tract,
metric,
formula,
group="group",
group=None,
lme=False,
rand_eff="subjectID",
impute="median",
):
"""Model group differences using node-wise regression along the length of the tract.

Expand All @@ -26,13 +26,10 @@ def node_wise_regression(
----------
afq_dataset: AFQDataset
Loaded AFQDataset object

tract: str
String specifying the tract to model

metric: str
String specifying which diffusion metric to use as an outcome
eg. 'fa'

formula: str
An R-style formula <https://www.statsmodels.org/dev/example_formulas.html>
specifying the regression model to fit at each node. This can take the form
Expand All @@ -46,20 +43,23 @@ def node_wise_regression(
mixed-effects models. If using anything other than the default value,
this column must be present in the 'target_cols' of the AFQDataset object

impute: str or None, default='median'
String specifying the imputation strategy to use for missing data.


Returns
-------
tract_dict: dict
A dictionary with the following key-value pairs:

{'tract': tract,
'reference_coefs': coefs_default,
'group_coefs': coefs_treat,
'reference_CI': cis_default,
'group_CI': cis_treat,
'pvals': pvals,
'reject_idx': reject_idx,
'model_fits': fits}
'reference_coefs': coefs_default,
'group_coefs': coefs_treat,
'reference_CI': cis_default,
'group_CI': cis_treat,
'pvals': pvals,
'reject_idx': reject_idx,
'model_fits': fits}

tract: str
The tract described by this dictionary
Expand All @@ -72,7 +72,7 @@ def node_wise_regression(
group_coefs: list of floats
A list of beta-weights representing the average group effect metric
for the treatment group on a diffusion metric at a given location
along the tract
along the tract. If group is None, this will be a list of zeros.

reference_CI: np.array of np.array
A numpy array containing a series of numpy arrays indicating the
Expand All @@ -82,7 +82,8 @@ def node_wise_regression(
group_CI: np.array of np.array
A numpy array containing a series of numpy arrays indicating the
95% confidence interval around the estimated beta-weight of the
treatment effect at a given location along the tract
treatment effect at a given location along the tract. If group is
None, this will be an array of zeros.

pvals: list of floats
A list of p-values testing whether or not the beta-weight of the
Expand All @@ -96,8 +97,13 @@ def node_wise_regression(
A list of the statsmodels object fit along the length of the nodes

"""
X = SimpleImputer(strategy="median").fit_transform(afq_dataset.X)
afq_dataset.target_cols[0] = group
if impute is not None:
X = SimpleImputer(strategy=impute).fit_transform(afq_dataset.X)

if group is not None:
afq_dataset.target_cols[0] = group

metric = formula.split("~")[0].strip()

tract_data = (
pd.DataFrame(columns=afq_dataset.feature_names, data=X)
Expand All @@ -106,12 +112,13 @@ def node_wise_regression(
)

pvals = np.zeros(tract_data.shape[-1])
pvals_corrected = np.zeros(tract_data.shape[-1])
coefs_default = np.zeros(tract_data.shape[-1])
coefs_treat = np.zeros(tract_data.shape[-1])
cis_default = np.zeros((tract_data.shape[-1], 2))
cis_treat = np.zeros((tract_data.shape[-1], 2))
reject = np.zeros(tract_data.shape[-1], dtype=bool)
fits = {}

# Loop through each node and fit model
for ii, column in enumerate(tract_data.columns):
# fit linear mixed-effects model
Expand All @@ -125,7 +132,6 @@ def node_wise_regression(

model = smf.mixedlm(formula, this, groups=rand_eff)
fit = model.fit()
fits[column] = fit

# fit OLS model
else:
Expand All @@ -135,31 +141,76 @@ def node_wise_regression(

model = OLS.from_formula(formula, this)
fit = model.fit()
fits[column] = fit

fits[ii] = fit
# pull out coefficients, CIs, and p-values from our model
coefs_default[ii] = fit.params.filter(regex="Intercept", axis=0).iloc[0]
coefs_treat[ii] = fit.params.filter(regex=group, axis=0).iloc[0]

cis_default[ii] = (
fit.conf_int(alpha=0.05).filter(regex="Intercept", axis=0).values
)
cis_treat[ii] = fit.conf_int(alpha=0.05).filter(regex=group, axis=0).values
pvals[ii] = fit.pvalues.filter(regex=group, axis=0).iloc[0]

# Correct p-values for multiple comparisons
reject, pval_corrected, _, _ = multipletests(pvals, alpha=0.05, method="fdr_bh")
reject_idx = np.where(reject)

tract_dict = {
"tract": tract,
"reference_coefs": coefs_default,
"group_coefs": coefs_treat,
"reference_CI": cis_default,
"group_CI": cis_treat,
"pvals": pvals,
"reject_idx": reject_idx,
"model_fits": fits,
}

return tract_dict

if group is not None:
coefs_treat[ii] = fit.params.filter(regex=group, axis=0).iloc[0]

cis_default[ii] = (
fit.conf_int(alpha=0.05).filter(regex="Intercept", axis=0).values
)
cis_treat[ii] = fit.conf_int(alpha=0.05).filter(regex=group, axis=0).values
pvals[ii] = fit.pvalues.filter(regex=group, axis=0).iloc[0]

# Correct p-values for multiple comparisons
reject, pvals_corrected, _, _ = multipletests(
pvals, alpha=0.05, method="fdr_bh"
)

reject = np.where(reject, 1, 0)

return pd.DataFrame(
{
"reference_coefs": coefs_default,
"group_coefs": coefs_treat,
"reference_CI_lb": cis_default[:, 0],
"reference_CI_ub": cis_default[:, 1],
"group_CI_lb": cis_treat[:, 0],
"group_CI_ub": cis_treat[:, 1],
"pvals": pvals,
"pvals_corrected": pvals_corrected,
"reject_idx": reject,
}
), fits


class RegressionResults(object):
    """Plain attribute container for the outputs of a node-wise regression.

    Parameters
    ----------
    kwargs : dict
        Mapping from result name to value. Each recognised key is copied
        onto the instance as an attribute of the same name; keys that are
        absent from the mapping produce an attribute set to ``None``.
        Recognised keys are ``tract``, ``reference_coefs``, ``group_coefs``,
        ``reference_ci``, ``group_ci``, ``pvals``, ``pvals_corrected``,
        ``reject_idx`` and ``model_fits``.
    """

    def __init__(self, kwargs):
        # Note that ``kwargs`` is a single positional mapping, not ``**kwargs``.
        result_fields = (
            "tract",
            "reference_coefs",
            "group_coefs",
            "reference_ci",
            "group_ci",
            "pvals",
            "pvals_corrected",
            "reject_idx",
            "model_fits",
        )
        for field_name in result_fields:
            setattr(self, field_name, kwargs.get(field_name, None))


class NodeWiseRegression(object):
    """Apply one node-wise regression formula to several tracts.

    Parameters
    ----------
    formula : str
        Formula forwarded unchanged to ``node_wise_regression`` for every
        tract.
    lme : bool, default=False
        Forwarded to ``node_wise_regression`` as its ``lme`` argument.

    Attributes
    ----------
    is_fitted : bool
        ``False`` until :meth:`fit` has completed successfully.
    result_ : dict
        Set by :meth:`fit`. Maps each tract name to whatever
        ``node_wise_regression`` returned for that tract.
    """

    def __init__(self, formula, lme=False):
        self.formula = formula
        self.lme = lme
        # Define the flag up front so that calling predict() on an unfitted
        # model raises the intended ValueError instead of AttributeError.
        self.is_fitted = False

    def fit(self, dataset, tracts, group=None, rand_eff="subjectID"):
        """Run ``node_wise_regression`` on every tract and store the results.

        Parameters
        ----------
        dataset : object
            Passed as the first argument to ``node_wise_regression``.
        tracts : iterable of str
            Tract names to model; each becomes a key of ``result_``.
        group : str or None, default=None
            Forwarded as the ``group`` argument.
        rand_eff : str, default="subjectID"
            Forwarded as the ``rand_eff`` argument.

        Returns
        -------
        self : NodeWiseRegression
            The fitted instance, to allow call chaining.
        """
        # Build the full mapping before publishing it, so a failure on one
        # tract does not leave a half-populated ``result_`` behind.
        results = {
            tract: node_wise_regression(
                dataset,
                tract,
                self.formula,
                lme=self.lme,
                group=group,
                rand_eff=rand_eff,
            )
            for tract in tracts
        }
        self.result_ = results
        self.is_fitted = True
        return self

    def predict(self, dataset, tract, metric, group="group", rand_eff="subjectID"):
        """Validate that a fitted result exists for ``tract``.

        NOTE(review): no prediction is computed yet. After the two checks
        below succeed the method falls through and returns ``None``;
        ``dataset``, ``metric``, ``group`` and ``rand_eff`` are currently
        unused.

        Raises
        ------
        ValueError
            If :meth:`fit` has not been called, or if ``tract`` was not among
            the tracts passed to :meth:`fit`.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted yet. Please call fit() method first.")
        result = self.result_.get(tract, None)
        if result is None:
            raise ValueError(f"Tract {tract} not found in the fitted model.")
2 changes: 1 addition & 1 deletion afqinsight/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_from_study(study):
"n_subjects": 48,
"n_features": 4000,
"n_groups": 40,
"target_cols": ["class"],
"target_cols": ["class", "age"],
},
"weston-havens": {
"n_subjects": 77,
Expand Down
43 changes: 43 additions & 0 deletions afqinsight/tests/test_parametric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

from afqinsight import AFQDataset
from afqinsight.parametric import NodeWiseRegression, node_wise_regression


def test_node_wise_regression():
    """Check node-wise regression output for group, group + age and age-only models.

    ``node_wise_regression`` returns a ``(results, fits)`` pair, where
    ``results`` is the per-node DataFrame and ``fits`` holds the fitted model
    objects. Only the DataFrame is inspected here, so the pair is unpacked
    rather than indexed with a column name, which would fail on a tuple.
    """
    # Per-tract result tables for each model variant
    group_dict = {}
    group_age_dict = {}
    age_dict = {}

    data = AFQDataset.from_study("sarica")
    tracts = ["Right Corticospinal", "Right SLF"]
    for tract in tracts:
        for lme in [True, False]:
            # Run different versions of this: with age, without age, only with
            # age:
            group_dict[tract], _ = node_wise_regression(
                data, tract, "fa ~ C(group)", lme=lme, group="group"
            )
            group_age_dict[tract], _ = node_wise_regression(
                data, tract, "fa ~ C(group) + age", lme=lme, group="group"
            )
            age_dict[tract], _ = node_wise_regression(
                data, tract, "fa ~ age", lme=lme
            )

            # One p-value per node along the tract
            assert group_dict[tract]["pvals"].shape == (100,)
            assert group_age_dict[tract]["pvals"].shape == (100,)
            assert age_dict[tract]["pvals"].shape == (100,)

    assert np.any(group_dict["Right Corticospinal"]["pvals_corrected"] < 0.05)
    assert np.all(group_dict["Right SLF"]["pvals_corrected"] > 0.05)
    assert np.any(group_age_dict["Right Corticospinal"]["pvals_corrected"] < 0.05)
    assert np.all(group_age_dict["Right SLF"]["pvals_corrected"] > 0.05)


def test_NodeWiseRegression():
    """Check that NodeWiseRegression.fit stores one result per requested tract."""
    data = AFQDataset.from_study("sarica")
    tracts = ["Left Corticospinal", "Left SLF"]
    for lme in [True, False]:
        model = NodeWiseRegression("fa ~ C(group) + age", lme=lme)
        fitted = model.fit(data, tracts, group="group")

        # fit() returns the estimator itself and marks it as fitted
        assert fitted is model
        assert model.is_fitted
        # exactly the requested tracts are present in the stored results
        assert set(model.result_) == set(tracts)
3 changes: 2 additions & 1 deletion examples/plot_als_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

X = afqdata.X
y = afqdata.y.astype(float) # SGL expects float targets
is_als = y[:, 0]
groups = afqdata.groups
feature_names = afqdata.feature_names
group_names = afqdata.group_names
Expand Down Expand Up @@ -117,7 +118,7 @@
# scikit-learn functions

scores = cross_validate(
pipe, X, y, cv=5, return_train_score=True, return_estimator=True
pipe, X, is_als, cv=5, return_train_score=True, return_estimator=True
)

# Display results
Expand Down
8 changes: 4 additions & 4 deletions examples/plot_als_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@


# Loop through the data and generate plots
for i, tract in enumerate(tracts):
for ii, tract in enumerate(tracts):
# fit node-wise regression for each tract based on model formula
tract_dict = node_wise_regression(afqdata, tract, "fa", "fa ~ C(group)")
tract_dict = node_wise_regression(afqdata, tract, "fa ~ C(group)", group="group")

row = i // num_cols
col = i % num_cols
row = ii // num_cols
col = ii % num_cols

axes[row][col].set_title(tract)

Expand Down
Loading