Skip to content

Commit 69ed489

Browse files
authored
Add heterogeneity indices, printing indices/legend, fix index.html. (#56)
1 parent 21b8623 commit 69ed489

10 files changed

Lines changed: 852 additions & 254 deletions

docs/notebooks/structural_reliability.ipynb

Lines changed: 254 additions & 236 deletions
Large diffs are not rendered by default.

panel/index.html

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,34 +203,36 @@
203203
<fast-text-field id="search-input" placeholder="search" onInput="hideCards(event.target.value)"></fast-text-field>
204204
</section>
205205

206-
<section id="cards">
206+
<section id="cards">
207207
<ul class="cards-grid">
208+
<!-- Sampling card moved to first position -->
208209
<li class="card">
209-
<a class="card-link" href="./simdec_app.html" id="simdec_app">
210+
<a class="card-link" href="./sampling.html" id="sampling">
210211
<fast-card class="gallery-item">
211-
<object data="_static/thumbnails/simdec_app.png" type="image/png">
212+
<object data="_static/thumbnails/sampling.png" type="image/png">
212213
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
213214
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
214215
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
215216
</svg>
216217
</object>
217218
<div class="card-content">
218-
<h2 class="card-header">SimDec App</h2>
219+
<h2 class="card-header">Sampling</h2>
219220
</div>
220221
</fast-card>
221222
</a>
222223
</li>
224+
<!-- SimDec App card moved to second position -->
223225
<li class="card">
224-
<a class="card-link" href="./sampling.html" id="sampling">
226+
<a class="card-link" href="./simdec_app.html" id="simdec_app">
225227
<fast-card class="gallery-item">
226-
<object data="_static/thumbnails/sampling.png" type="image/png">
228+
<object data="_static/thumbnails/simdec_app.png" type="image/png">
227229
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
228230
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
229231
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
230232
</svg>
231233
</object>
232234
<div class="card-content">
233-
<h2 class="card-header">Sampling</h2>
235+
<h2 class="card-header">SimDec App</h2>
234236
</div>
235237
</fast-card>
236238
</a>

panel/simdec_app.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,11 @@ def filtered_si(sensitivity_indices_table, input_names):
170170

171171

172172
def explained_variance_80(sensitivity_indices_table):
173-
si = sensitivity_indices_table.value["Indices"]
174-
pos_80 = bisect.bisect_right(np.cumsum(si), 0.8)
173+
df = sensitivity_indices_table.value
174+
df = df[df["Inputs"] != "Sum of Indices"]
175+
si = df["Indices"].values
176+
target = 0.8 * np.sum(si)
177+
pos_80 = bisect.bisect_right(np.cumsum(si), target)
175178

176179
# pos_80 = max(2, pos_80)
177180
# pos_80 = min(len(si), pos_80)
@@ -225,9 +228,8 @@ def create_color_pickers(states, colors):
225228
@pn.cache
226229
def palette_(states: list[list[str]], colors_picked: list[list[float]]):
227230
cmaps = [single_color_to_colormap(color_picked) for color_picked in colors_picked]
228-
# Reverse order as in figures high values take the first colors
229231
states = [len(states_) for states_ in states]
230-
return sd.palette(states, cmaps=cmaps[::-1])
232+
return sd.palette(states, cmaps=cmaps)
231233

232234

233235
@pn.cache

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ dashboard = [
4141
"cryptography",
4242
]
4343

44+
display = [
45+
"ipython>=9.1"
46+
]
47+
4448
test = [
4549
"pytest",
4650
"pytest-cov",
@@ -55,7 +59,7 @@ doc = [
5559
]
5660

5761
dev = [
58-
"simdec[doc,test,dashboard]",
62+
"simdec[doc,test,dashboard, display]",
5963
"watchfiles",
6064
"pre-commit",
6165
]

src/simdec/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""SimDec main namespace."""
22
from simdec.decomposition import *
3+
from simdec.heterogeneity_indices import *
34
from simdec.sensitivity_indices import *
45
from simdec.visualization import *
56

@@ -11,4 +12,5 @@
1112
"two_output_visualization",
1213
"tableau",
1314
"palette",
15+
"heterogeneity_indices",
1416
]
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
from dataclasses import dataclass
2+
import logging
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pandas as pd
7+
8+
import simdec as sd
9+
10+
logger = logging.getLogger(__name__)
11+
12+
__all__ = ["heterogeneity_indices", "plot_heterogeneity"]
13+
14+
15+
@dataclass
16+
class HeterogeneityResult:
17+
summary: pd.DataFrame
18+
regional_profiles: pd.DataFrame
19+
split_name: str
20+
21+
22+
def heterogeneity_indices(
23+
output: pd.Series,
24+
inputs: pd.DataFrame,
25+
split_variable: str | pd.Series,
26+
n_subdivisions: int | None = None,
27+
plot: bool = False,
28+
) -> HeterogeneityResult:
29+
"""Heterogeneity indices.
30+
31+
Compute sensitivity-based heterogeneity across subdivisions
32+
of a variable.
33+
34+
Parameters
35+
----------
36+
output : pd.Series
37+
Model output vector.
38+
inputs : pd.DataFrame
39+
Input/feature matrix.
40+
split_variable : str or pd.Series
41+
Variable to split on. If string, must be a column in 'inputs'.
42+
n_subdivisions : int, optional
43+
Number of regions for continuous variables. Defaults to 4.
44+
plot : bool, default False
45+
If True, displays a stacked bar chart of regional sensitivity profiles
46+
by calling :func:`plot_heterogeneity`. The chart shows variance
47+
contributions of each input across subdivisions of ``split_variable``,
48+
ranked by global sensitivity indices. To capture the returned
49+
``matplotlib.axes.Axes`` object, call :func:`plot_heterogeneity`
50+
directly on the result instead.
51+
52+
Returns
53+
-------
54+
res : HeterogeneityResult
55+
An object with attributes:
56+
57+
summary : DataFrame
58+
A summary of calculated heterogeneity indices.
59+
regional_profiles : DataFrame
60+
Regional sensitivity indices for each input across subdivisions.
61+
split_name : str
62+
The name of the variable used to split the data.
63+
64+
"""
65+
y = pd.Series(output).reset_index(drop=True)
66+
X = pd.DataFrame(inputs).reset_index(drop=True)
67+
68+
if isinstance(split_variable, str):
69+
if split_variable not in X.columns:
70+
raise ValueError(f"'{split_variable}' not found in inputs.")
71+
z = X[split_variable].reset_index(drop=True)
72+
split_name = split_variable
73+
else:
74+
z = pd.Series(split_variable).reset_index(drop=True)
75+
split_name = getattr(split_variable, "name", "split_variable")
76+
77+
unique_vals = z.dropna().unique()
78+
n_unique = len(unique_vals)
79+
80+
# Determine if variable is categorical/binary
81+
is_categorical = (
82+
isinstance(z.dtype, pd.CategoricalDtype)
83+
or pd.api.types.is_object_dtype(z)
84+
or pd.api.types.is_string_dtype(z)
85+
or pd.api.types.is_bool_dtype(z)
86+
or n_unique <= 2
87+
)
88+
89+
if is_categorical:
90+
regions = z.astype("category")
91+
else:
92+
q = n_subdivisions if n_subdivisions is not None else 4
93+
try:
94+
regions = pd.qcut(z, q=q, duplicates="drop")
95+
except ValueError as e:
96+
raise ValueError(
97+
f"Failed to bin '{split_name}' into {q} quantiles: {e}"
98+
) from e
99+
100+
regional_profiles = []
101+
skipped = []
102+
103+
for region in regions.cat.categories:
104+
mask = regions == region
105+
n_in_region = mask.sum()
106+
107+
if n_in_region < 10:
108+
# Need enough samples for meaningful sensitivity indices
109+
skipped.append((region, n_in_region, "too few samples (< 10)"))
110+
continue
111+
112+
X_sub = X.loc[mask]
113+
y_sub = y.loc[mask]
114+
115+
# Skip if output has zero or near-zero variance in this region
116+
if y_sub.var() < 1e-12:
117+
skipped.append((region, n_in_region, "output variance ≈ 0"))
118+
continue
119+
120+
try:
121+
res = sd.sensitivity_indices(inputs=X_sub, output=y_sub)
122+
si_vals = np.asarray(res.si).ravel()
123+
124+
# Guard against NaN/Inf from degenerate sensitivity computation
125+
if not np.all(np.isfinite(si_vals)):
126+
skipped.append((region, n_in_region, "non-finite SI values"))
127+
continue
128+
129+
si_region = pd.Series(si_vals, index=X.columns, name=region)
130+
regional_profiles.append(si_region)
131+
132+
except Exception as e:
133+
skipped.append((region, n_in_region, f"exception: {e}"))
134+
continue
135+
136+
if skipped:
137+
logger.info("Skipped %d region(s) of '%s':", len(skipped), split_name)
138+
for reg, n, reason in skipped:
139+
logger.info(" - region=%r, n=%d, reason=%s", reg, n, reason)
140+
141+
if len(regional_profiles) < 2:
142+
total_regions = len(regions.cat.categories)
143+
valid = len(regional_profiles)
144+
raise ValueError(
145+
f"Not enough valid subdivisions to compute heterogeneity: "
146+
f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
147+
f"Skipped regions:\n"
148+
"\n".join(f" {r!r}: n={n}, {reason} " for r, n, reason in skipped),
149+
"\n\nTry: (1) reducing n_subdivisions, "
150+
"(2) using a different split_variable, or "
151+
"(3) ensuring more samples per region.",
152+
)
153+
154+
regional_si = pd.concat(regional_profiles, axis=1)
155+
156+
res_global = sd.sensitivity_indices(inputs=X, output=y)
157+
overall_si = pd.Series(
158+
np.asarray(res_global.si).ravel(),
159+
index=X.columns,
160+
name="Overall_SI",
161+
)
162+
163+
# Heterogeneity = 2 × population std dev across regions
164+
hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
165+
total_hetero = hetero_scores.mean()
166+
167+
hetero_col_name = f"Heterogeneity (across {split_name})"
168+
summary = pd.DataFrame(
169+
{"Overall_SI": overall_si, hetero_col_name: hetero_scores}
170+
).sort_values(by=hetero_col_name, ascending=False)
171+
summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]
172+
173+
result = HeterogeneityResult(summary, regional_si, split_name)
174+
175+
if plot:
176+
plot_heterogeneity(result)
177+
178+
return result
179+
180+
181+
def plot_heterogeneity(result: HeterogeneityResult, ax: plt.Axes = None) -> plt.Axes:
182+
"""Plot regional sensitivity profiles.
183+
184+
Parameters
185+
----------
186+
result : HeterogeneityResult
187+
The result object from heterogeneity_indices.
188+
ax : matplotlib.axes.Axes, optional
189+
Existing axes to plot on.
190+
191+
Returns
192+
-------
193+
ax : matplotlib.axes.Axes
194+
The axes with the plot.
195+
196+
"""
197+
summary = result.summary
198+
regional_si = result.regional_profiles
199+
split_name = result.split_name
200+
201+
hetero_col_name = [c for c in summary.columns if "Heterogeneity" in c][0]
202+
total_hetero = summary.loc["SUM / TOTAL", hetero_col_name]
203+
204+
plot_order = summary.index[summary.index != "SUM / TOTAL"]
205+
plot_order = (
206+
summary.loc[plot_order].sort_values(by="Overall_SI", ascending=False).index
207+
)
208+
209+
cmap = plt.colormaps["terrain"]
210+
colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(regional_si.index))]
211+
212+
data_to_plot = regional_si.loc[plot_order].T
213+
214+
if ax is None:
215+
_, ax = plt.subplots(figsize=(10, 6))
216+
217+
data_to_plot.plot(
218+
kind="bar",
219+
stacked=True,
220+
ax=ax,
221+
color=colors,
222+
edgecolor="white",
223+
width=0.8,
224+
)
225+
226+
ax.set_title(
227+
f"Sensitivity Profiles across {split_name}\n"
228+
f"(Total Heterogeneity: {total_hetero:.3f})",
229+
fontsize=10,
230+
)
231+
232+
ax.set_ylabel("Variance Contribution", fontsize=8)
233+
ax.set_xlabel(f"Regions of {split_name}", fontsize=8)
234+
235+
ax.legend(
236+
title="Inputs (Ranked by Global SI)",
237+
bbox_to_anchor=(1.05, 1),
238+
loc="upper left",
239+
)
240+
241+
ax.tick_params(axis="x", labelrotation=45)
242+
ax.grid(axis="y", linestyle="--", alpha=0.7)
243+
244+
if plt.get_backend().lower() != "agg":
245+
plt.tight_layout()
246+
247+
return ax

0 commit comments

Comments
 (0)