Skip to content

Commit f44b399

Browse files
committed
refactor: consolidate shared components and reorganize schema modules
- Move common functions (risk categorization, calibration, gompertz model) to helpers.py - Consolidate schema modules into algorithm-specific files (phenoage.py, score2.py) - Rename coefficient classes for clarity and consistency - Update imports across all compute modules to use new schema structure - Remove redundant schema files (core.py, markers.py, units.py)
1 parent 30324a0 commit f44b399

15 files changed

Lines changed: 272 additions & 285 deletions

File tree

tests/test_phenoage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from vitals.phenoage import compute
5+
from vitals.models import phenoage
66

77
OUT_FILEPATH = Path(__file__).parent / "inputs" / "phenoage"
88

@@ -25,7 +25,7 @@
2525
)
2626
def test_phenoage(filename, expected):
2727
# Get the actual fixture value using request.getfixturevalue
28-
age, pred_age, accl_age = compute.biological_age(OUT_FILEPATH / filename)
28+
age, pred_age, accl_age = phenoage.compute(OUT_FILEPATH / filename)
2929
expected_age, expected_pred_age, expected_accl_age = expected
3030

3131
assert age == expected_age

tests/test_score2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from vitals.score2 import compute
5+
from vitals.models import score2
66

77
OUT_FILEPATH = Path(__file__).parent / "inputs" / "score2"
88

@@ -26,9 +26,7 @@
2626
)
2727
def test_score2(filename, expected):
2828
# Get the actual fixture value using request.getfixturevalue
29-
age, pred_risk, pred_risk_category = compute.cardiovascular_risk(
30-
OUT_FILEPATH / filename
31-
)
29+
age, pred_risk, pred_risk_category = score2.compute(OUT_FILEPATH / filename)
3230
expected_age, expected_risk, expected_category = expected
3331

3432
assert age == expected_age

tests/test_score2_diabetes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from vitals.score2_diabetes import compute
5+
from vitals.models import score2_diabetes
66

77
OUT_FILEPATH = Path(__file__).parent / "inputs" / "score2_diabetes"
88

@@ -26,7 +26,7 @@ def test_score2_diabetes(filename, expected):
2626
They need to be calculated using MDCalc and updated before running tests.
2727
"""
2828
# Get the actual fixture value
29-
age, pred_risk, pred_risk_category = compute.cardiovascular_risk(
29+
age, pred_risk, pred_risk_category = score2_diabetes.compute(
3030
OUT_FILEPATH / filename
3131
)
3232
expected_age, expected_risk, expected_category = expected

vitals/biomarkers/helpers.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from collections.abc import Callable
22
from pathlib import Path
3-
from typing import Any, TypedDict, TypeVar
3+
from typing import Any, Literal, TypeAlias, TypedDict, TypeVar
44

5+
import numpy as np
56
from pydantic import BaseModel
67

7-
from vitals.schemas.units import PhenoageUnits, Score2DiabetesUnits, Score2Units
8+
from vitals.schemas import phenoage, score2
89

10+
RiskCategory: TypeAlias = Literal["Low to moderate", "High", "Very high"]
911
Biomarkers = TypeVar("Biomarkers", bound=BaseModel)
10-
Units = PhenoageUnits | Score2Units | Score2DiabetesUnits
12+
Units = phenoage.Units | score2.Units | score2.UnitsDiabetes
1113

1214

1315
class ConversionInfo(TypedDict):
@@ -198,3 +200,57 @@ def extract_biomarkers_from_json(
198200
extracted_values[field_name] = value
199201

200202
return biomarker_class(**extracted_values)
203+
204+
205+
def determine_risk_category(age: float, calibrated_risk: float) -> RiskCategory:
206+
"""
207+
Determine cardiovascular risk category based on age and calibrated risk percentage.
208+
209+
Args:
210+
age: Patient's age in years
211+
calibrated_risk: Calibrated 10-year CVD risk as a percentage
212+
213+
Returns:
214+
Risk stratification category
215+
"""
216+
if age < 50:
217+
if calibrated_risk < 2.5:
218+
return "Low to moderate"
219+
elif calibrated_risk < 7.5:
220+
return "High"
221+
else:
222+
return "Very high"
223+
else: # age 50-69
224+
if calibrated_risk < 5:
225+
return "Low to moderate"
226+
elif calibrated_risk < 10:
227+
return "High"
228+
else:
229+
return "Very high"
230+
231+
232+
def apply_calibration(uncalibrated_risk: float, scale1: float, scale2: float) -> float:
233+
"""
234+
Apply regional calibration to uncalibrated risk estimate.
235+
236+
Args:
237+
uncalibrated_risk: Raw risk estimate from the Cox model
238+
scale1: First calibration scale parameter
239+
scale2: Second calibration scale parameter
240+
241+
Returns:
242+
Calibrated 10-year CVD risk as a percentage
243+
"""
244+
return float(
245+
(1 - np.exp(-np.exp(scale1 + scale2 * np.log(-np.log(1 - uncalibrated_risk)))))
246+
* 100
247+
)
248+
249+
250+
def gompertz_mortality_model(weighted_risk_score: float) -> float:
251+
params = phenoage.Gompertz()
252+
return 1 - np.exp(
253+
-np.exp(weighted_risk_score)
254+
* (np.exp(120 * params.lambda_) - 1)
255+
/ params.lambda_
256+
)
Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,12 @@
11
from pathlib import Path
22

33
import numpy as np
4-
from pydantic import BaseModel
54

65
from vitals.biomarkers import helpers
7-
from vitals.schemas.markers import PhenoageMarkers
8-
from vitals.schemas.units import PhenoageUnits
6+
from vitals.schemas.phenoage import Gompertz, LinearModel, Markers, Units
97

108

11-
class LinearModel(BaseModel):
12-
"""
13-
Coefficients used to calculate the PhenoAge from Levine et al 2018
14-
"""
15-
16-
intercept: float = -19.9067
17-
albumin: float = -0.0336
18-
creatinine: float = 0.0095
19-
glucose: float = 0.1953
20-
log_crp: float = 0.0954
21-
lymphocyte_percent: float = -0.0120
22-
mean_cell_volume: float = 0.0268
23-
red_cell_distribution_width: float = 0.3306
24-
alkaline_phosphatase: float = 0.00188
25-
white_blood_cell_count: float = 0.0554
26-
age: float = 0.0804
27-
28-
29-
class Gompertz(BaseModel):
30-
"""
31-
Parameters of the Gompertz distribution for PhenoAge computation
32-
"""
33-
34-
lambda_: float = 0.0192
35-
coef1: float = 141.50225
36-
coef2: float = -0.00553
37-
coef3: float = 0.090165
38-
39-
40-
def __gompertz_mortality_model(weighted_risk_score: float) -> float:
41-
__params = Gompertz()
42-
return 1 - np.exp(
43-
-np.exp(weighted_risk_score)
44-
* (np.exp(120 * __params.lambda_) - 1)
45-
/ __params.lambda_
46-
)
47-
48-
49-
def biological_age(filepath: str | Path) -> tuple[float, float, float]:
9+
def compute(filepath: str | Path) -> tuple[float, float, float]:
5010
"""
5111
The Phenoage score is calculated as a weighted (coefficients available in Levine et al 2018)
5212
linear combination of these variables, which was then transformed into units of years using 2 parametric
@@ -57,14 +17,14 @@ def biological_age(filepath: str | Path) -> tuple[float, float, float]:
5717
# Extract biomarkers from JSON file
5818
biomarkers = helpers.extract_biomarkers_from_json(
5919
filepath=filepath,
60-
biomarker_class=PhenoageMarkers,
61-
biomarker_units=PhenoageUnits(),
20+
biomarker_class=Markers,
21+
biomarker_units=Units(),
6222
)
6323

6424
age = biomarkers.age
6525
coef = LinearModel()
6626

67-
if isinstance(biomarkers, PhenoageMarkers):
27+
if isinstance(biomarkers, Markers):
6828
weighted_risk_score = (
6929
coef.intercept
7030
+ (coef.albumin * biomarkers.albumin)
@@ -81,7 +41,9 @@ def biological_age(filepath: str | Path) -> tuple[float, float, float]:
8141
+ (coef.white_blood_cell_count * biomarkers.white_blood_cell_count)
8242
+ (coef.age * biomarkers.age)
8343
)
84-
gompertz = __gompertz_mortality_model(weighted_risk_score=weighted_risk_score)
44+
gompertz = helpers.gompertz_mortality_model(
45+
weighted_risk_score=weighted_risk_score
46+
)
8547
model = Gompertz()
8648
pred_age = (
8749
model.coef1 + np.log(model.coef2 * np.log(1 - gompertz)) / model.coef3
@@ -90,17 +52,3 @@ def biological_age(filepath: str | Path) -> tuple[float, float, float]:
9052
return (age, pred_age, accl_age)
9153
else:
9254
raise ValueError(f"Invalid biomarker class used: {biomarkers}")
93-
94-
95-
# if __name__ == "__main__":
96-
# from pathlib import Path
97-
# input_dir = Path("tests/outputs")
98-
# output_dir = Path("tests/outputs")
99-
100-
# for input_file in input_dir.glob("*.json"):
101-
# if "patient" not in str(input_file):
102-
# continue
103-
104-
# # Update biomarker data
105-
# age, pred_age, accl_age = biological_age(str(input_file))
106-
# print(f"Chrono Age: {age} ::: Predicted Age: {pred_age} ::: Accel {accl_age}")
Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010
import numpy as np
1111

1212
from vitals.biomarkers import helpers
13-
from vitals.schemas.coefficients import Score2FemaleCoefficients, Score2MaleCoefficients
14-
from vitals.schemas.core import (
13+
from vitals.schemas.score2 import (
1514
BaselineSurvival,
1615
CalibrationScales,
17-
RiskCategory,
18-
apply_calibration,
19-
determine_risk_category,
16+
FemaleCoefficientsBaseModel,
17+
MaleCoefficientsBaseModel,
18+
Markers,
19+
Units,
2020
)
21-
from vitals.schemas.markers import Score2Markers
22-
from vitals.schemas.units import Score2Units
2321

2422

25-
def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategory]:
23+
def compute(
24+
filepath: str | Path,
25+
) -> tuple[float, float, helpers.RiskCategory]:
2626
"""
2727
Calculate the 10-year cardiovascular disease risk using the SCORE2 algorithm.
2828
@@ -46,11 +46,11 @@ def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategor
4646
# Extract biomarkers from JSON file
4747
biomarkers = helpers.extract_biomarkers_from_json(
4848
filepath=filepath,
49-
biomarker_class=Score2Markers,
50-
biomarker_units=Score2Units(),
49+
biomarker_class=Markers,
50+
biomarker_units=Units(),
5151
)
5252

53-
if not isinstance(biomarkers, Score2Markers):
53+
if not isinstance(biomarkers, Markers):
5454
raise ValueError(f"Invalid biomarker class used: {biomarkers}")
5555

5656
age: float = biomarkers.age
@@ -73,14 +73,14 @@ def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategor
7373
baseline_survival_model = BaselineSurvival()
7474
calibration_scales = CalibrationScales()
7575

76-
coef: Score2MaleCoefficients | Score2FemaleCoefficients
76+
coef: MaleCoefficientsBaseModel | FemaleCoefficientsBaseModel
7777
if is_male:
78-
coef = Score2MaleCoefficients()
78+
coef = MaleCoefficientsBaseModel()
7979
baseline_survival = baseline_survival_model.male
8080
scale1 = calibration_scales.male_scale1
8181
scale2 = calibration_scales.male_scale2
8282
else:
83-
coef = Score2FemaleCoefficients()
83+
coef = FemaleCoefficientsBaseModel()
8484
baseline_survival = baseline_survival_model.female
8585
scale1 = calibration_scales.female_scale1
8686
scale2 = calibration_scales.female_scale2
@@ -101,9 +101,13 @@ def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategor
101101
uncalibrated_risk: float = 1 - np.power(baseline_survival, np.exp(linear_pred))
102102

103103
# Apply calibration for Belgium (Low Risk region)
104-
calibrated_risk: float = apply_calibration(uncalibrated_risk, scale1, scale2)
104+
calibrated_risk: float = helpers.apply_calibration(
105+
uncalibrated_risk, scale1, scale2
106+
)
105107

106108
# Determine risk category based on age
107-
risk_category: RiskCategory = determine_risk_category(age, calibrated_risk)
109+
risk_category: helpers.RiskCategory = helpers.determine_risk_category(
110+
age, calibrated_risk
111+
)
108112

109113
return (age, round(calibrated_risk, 2), risk_category)
Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,19 @@
1111
import numpy as np
1212

1313
from vitals.biomarkers import helpers
14-
from vitals.schemas.coefficients import (
15-
Score2DiabetesFemaleCoefficients,
16-
Score2DiabetesMaleCoefficients,
17-
)
18-
from vitals.schemas.core import (
14+
from vitals.schemas.score2 import (
1915
BaselineSurvival,
2016
CalibrationScales,
21-
RiskCategory,
22-
apply_calibration,
23-
determine_risk_category,
17+
FemaleCoefficientsDiabeticModel,
18+
MaleCoefficientsDiabeticModel,
19+
MarkersDiabetes,
20+
UnitsDiabetes,
2421
)
25-
from vitals.schemas.markers import Score2DiabetesMarkers
26-
from vitals.schemas.units import Score2DiabetesUnits
2722

2823

29-
def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategory]:
24+
def compute(
25+
filepath: str | Path,
26+
) -> tuple[float, float, helpers.RiskCategory]:
3027
"""
3128
Calculate the 10-year cardiovascular disease risk using the SCORE2-Diabetes algorithm.
3229
@@ -51,11 +48,11 @@ def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategor
5148
# Extract biomarkers from JSON file
5249
biomarkers = helpers.extract_biomarkers_from_json(
5350
filepath=filepath,
54-
biomarker_class=Score2DiabetesMarkers,
55-
biomarker_units=Score2DiabetesUnits(),
51+
biomarker_class=MarkersDiabetes,
52+
biomarker_units=UnitsDiabetes(),
5653
)
5754

58-
if not isinstance(biomarkers, Score2DiabetesMarkers):
55+
if not isinstance(biomarkers, MarkersDiabetes):
5956
raise ValueError(f"Invalid biomarker class used: {biomarkers}")
6057

6158
age: float = biomarkers.age
@@ -86,14 +83,14 @@ def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategor
8683
baseline_survival_model = BaselineSurvival()
8784
calibration_scales = CalibrationScales()
8885

89-
coef: Score2DiabetesMaleCoefficients | Score2DiabetesFemaleCoefficients
86+
coef: MaleCoefficientsDiabeticModel | FemaleCoefficientsDiabeticModel
9087
if is_male:
91-
coef = Score2DiabetesMaleCoefficients()
88+
coef = MaleCoefficientsDiabeticModel()
9289
baseline_survival = baseline_survival_model.male
9390
scale1 = calibration_scales.male_scale1
9491
scale2 = calibration_scales.male_scale2
9592
else:
96-
coef = Score2DiabetesFemaleCoefficients()
93+
coef = FemaleCoefficientsDiabeticModel()
9794
baseline_survival = baseline_survival_model.female
9895
scale1 = calibration_scales.female_scale1
9996
scale2 = calibration_scales.female_scale2
@@ -122,9 +119,13 @@ def cardiovascular_risk(filepath: str | Path) -> tuple[float, float, RiskCategor
122119
uncalibrated_risk: float = 1 - np.power(baseline_survival, np.exp(linear_pred))
123120

124121
# Apply calibration for Belgium (Low Risk region)
125-
calibrated_risk: float = apply_calibration(uncalibrated_risk, scale1, scale2)
122+
calibrated_risk: float = helpers.apply_calibration(
123+
uncalibrated_risk, scale1, scale2
124+
)
126125

127126
# Determine risk category based on age
128-
risk_category: RiskCategory = determine_risk_category(age, calibrated_risk)
127+
risk_category: helpers.RiskCategory = helpers.determine_risk_category(
128+
age, calibrated_risk
129+
)
129130

130131
return (age, round(calibrated_risk, 2), risk_category)

0 commit comments

Comments
 (0)