Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,19 @@ Built on top of [MLflow](https://mlflow.org/) for experiment tracking, [SHAP](ht

---

Made by the [CloudExplain Team](https://cloudexplain.eu)
Made by the [CloudExplain Team](https://cloudexplain.eu)

## Development Notes

- **All styling must go in `report.html`**: Do not use inline styles or Svelte `<style>` blocks for the main report UI. This is required because styles are not reliably handed over or injected from Svelte to the final HTML report. Always update or add CSS in `src/xaiflow/templates/report.html` for any UI/UX changes.
- **Always rebuild the frontend bundle before running Python tests**: The Svelte/JS bundle (`bundle.js`) must be up-to-date for the Python tests to work correctly. Before running any Python tests (e.g., Playwright or integration tests), always run:

```bash
make build && python -m pytest tests/test_mlflow_plugin.py
```
or, for all tests:
```bash
make build && python -m pytest
```
- **Frontend changes require a rebuild**: Any change to Svelte components or frontend logic requires a new build of `bundle.js` to be reflected in the generated reports and tests.
- **UI/UX review**: When making UI changes, always check the result in the browser and ensure the layout matches the design intent. Use only relative units (rem, em, %) for all sizing and spacing in CSS.
9 changes: 9 additions & 0 deletions src/xaiflow/mlflow_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def log_feature_importance_report(
shap_values: Explanation,
feature_encodings: Optional[Dict[str, Dict[int, str]]] = None,
importance_values: List[float] | np.ndarray = None,
group_labels: Optional[List[str]] = None,
run_id: Optional[str] = None,
artifact_path: str = "reports",
report_name: str = "feature_importance_report.html",
Expand All @@ -44,6 +45,7 @@ def log_feature_importance_report(
feature_names: List of feature names
importance_values: List of importance values corresponding to features
shap_values: Optional SHAP values matrix (samples x features)
group_labels: Optional list of group labels for each sample
run_id: MLflow run ID (uses active run if None)
artifact_path: Path within MLflow artifacts to store the report
report_name: Name of the HTML report file
Expand Down Expand Up @@ -74,6 +76,10 @@ def log_feature_importance_report(
shap_values = shap_values[..., -1]
base_values = float(base_values[-1])

if group_labels is not None:
if len(group_labels) != shap_values.shape[0]:
raise ValueError("group_labels length must match the number of samples in shap_values.")

# Use active run if no run_id provided
if run_id is None:
active_run = mlflow.active_run()
Expand Down Expand Up @@ -105,6 +111,7 @@ def log_feature_importance_report(
html_content = self._generate_html_content(
importance_data=importance_data,
shap_values=shap_values,
group_labels=group_labels or [], # Default to empty list if None
feature_values=feature_values,
base_values=base_values,
feature_encodings=feature_encodings,
Expand Down Expand Up @@ -143,6 +150,7 @@ def _generate_html_content(
importance_data: Dict[str, Any],
shap_values: List[List[float]],
feature_values: List[float] = None,
group_labels: List[str] = None,
base_values: List[float] = None,
feature_encodings: Optional[Dict[str, Dict[int, str]]] = None,
feature_names: List[str] = None
Expand Down Expand Up @@ -202,6 +210,7 @@ def _generate_html_content(
timestamp=current_time,
importance_data=importance_data, # Pass as Python dict
shap_values=shap_values, # Pass as Python list
group_labels=group_labels or [], # Pass as Python list or empty list
feature_values=feature_values, # Pass as Python list or None
base_values=base_values or [0] * 10, # Todo: fix this once we hand over numpy arrays
feature_encodings=feature_encodings or {}, # Pass as optional dict
Expand Down
10 changes: 5 additions & 5 deletions src/xaiflow/templates/assets/bundle.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/xaiflow/templates/assets/bundle.js.map

Large diffs are not rendered by default.

187 changes: 95 additions & 92 deletions src/xaiflow/templates/components/ChartManager.svelte
Original file line number Diff line number Diff line change
@@ -1,25 +1,62 @@
<script lang="ts">
import ImportanceChart2 from './ImportanceChart2.svelte';
import ScatterShapValues from './ScatterShapValues.svelte';
import DeepDiveManager from './DeepDiveManager.svelte';

// Props using Svelte 5 runes
interface Props {
importanceData: { feature_name: string; importance: number }[];
shapValues: number[][];
featureValues: number[][];
featureEncodings?: { [key: string]: any }[]; // For feature value mapping
baseValues: number[] | number; // Base values for SHAP calculations
featureNames?: string[]; // Optional prop for feature names
isHigherOutputBetter?: boolean; // Optional prop to determine if higher output is better
groupLabels: string[]; // Optional prop for group labels
}

let { importanceData, shapValues, featureValues, featureEncodings = {} }: Props = $props();
let { importanceData,
shapValues,
featureValues,
featureEncodings = {},
baseValues,
featureNames,
isHigherOutputBetter,
groupLabels,
}: Props = $props();

// Reactive state for selected label using $state
let selectedLabel: string | null = $state(null);
let showDeepDive = $state(false);
let selectedGroup: string | null = $state(null);
// Compute unique group labels
let uniqueGroups: string[] = $derived(Array.from(new Set(groupLabels || [])));
console.log('ChartManager: Loaded with props:', {
importanceData,
shapValues,
featureValues,
featureEncodings,
baseValues,
featureNames,
isHigherOutputBetter,
groupLabels
});

// Compute selectedShapValues based on selectedGroup
let selectedShapValues = $derived((selectedGroup && selectedGroup !== "" && selectedGroup !== "All")
? shapValues.filter((_, idx) => groupLabels[idx] === selectedGroup)
: shapValues);

let selectedFeatureValues = $derived((selectedGroup && selectedGroup !== "" && selectedGroup !== "All")
? featureValues.filter((_, idx) => groupLabels[idx] === selectedGroup)
: featureValues);
console.log('ChartManager: selectedShapValues computed:', selectedShapValues);

console.log("ChartManager", importanceData);
console.log('ChartManager: 1/4 command in file');
let featureNames = $derived(
importanceData.map(item => item.feature_name)
);
// let featureNames = $derived(
// importanceData.map(item => item.feature_name)
// );
console.log('ChartManager: 2/4 command in file');

console.log('ChartManager: called');
Expand All @@ -42,94 +79,60 @@
</script>

<div class="chart-manager">
<div class="charts-row">
<div class="chart-section">
<h3>Feature Importance Chart</h3>
<div class="chart-container">
<ImportanceChart2
data={importanceData}
bind:selectedLabel={selectedLabel}
on:labelSelected={handleLabelSelection}
/>
</div>
<div style="display: flex; gap: 1.5rem; align-items: center; margin-bottom: 1.5rem; justify-content: space-between;">
<div style="display: flex; gap: 1.5rem; align-items: center;">
<button type="button" on:click={() => showDeepDive = false} class:selected={!showDeepDive}>Charts</button>
<button id="deepdive-button" type="button" on:click={() => showDeepDive = true} class:selected={showDeepDive}>Deep Dive</button>
</div>

<div class="chart-section">
<h3>SHAP Values</h3>
<div class="chart-container">
<ScatterShapValues
shapValues={shapValues}
featureValues={featureValues}
bind:selectedFeatureIndex={selectedFeatureIndex}
selectedFeature={selectedLabel}
bind:selectedLabel={selectedLabel}
isHigherOutputBetter={true}
featureEncodings={featureEncodings}
/>
{#if uniqueGroups.length > 0}
<div style="margin-left: auto;">
<label for="group-dropdown" style="margin-right: 0.5em; font-size: 1em;">Group:</label>
<select id="group-dropdown" bind:value={selectedGroup} on:change={(e) => selectedGroup = e.target.value} style="font-size: 1em; padding: 0.3em 0.7em;">
<option value="">All</option>
{#each uniqueGroups as group}
<option value={group}>{group}</option>
{/each}
</select>
</div>
</div>
{/if}
</div>
</div>

<style>
.chart-manager {
width: 100%;
}

.charts-row {
display: flex;
flex-direction: row;
gap: 20px;
margin-bottom: 30px;
width: 100%;
align-items: stretch;
}

.chart-section {
flex: 1;
min-width: 0; /* Allows flex items to shrink below their natural width */
width: 50%;
max-width: 50%;
}

.chart-section h3 {
margin-bottom: 15px;
color: #ff0000; /* Changed to red to test props propagation */
text-align: center;
font-weight: bold;
}

.chart-container {
height: 500px;
width: 100%;
border: 1px solid #e0e0e0;
border-radius: 8px;
padding: 10px;
background-color: #fafafa;
box-sizing: border-box;
}

.selected-info {
background-color: #f0f8ff;
padding: 15px;
border-radius: 5px;
border-left: 4px solid #007acc;
margin-top: 20px;
}

.selected-info p {
margin: 0;
font-size: 16px;
}

/* Responsive design for smaller screens */
@media (max-width: 768px) {
.charts-row {
flex-direction: column;
}

.chart-container {
height: 400px;
}
}
</style>
{#if !showDeepDive}
<div class="charts-row">
<div class="chart-section">
<h3>Feature Importance Chart</h3>
<div class="chart-container">
<ImportanceChart2
data={importanceData}
bind:selectedLabel={selectedLabel}
on:labelSelected={handleLabelSelection}
/>
</div>
</div>

<div class="chart-section">
<h3>SHAP Values</h3>
<div class="chart-container">
<ScatterShapValues
shapValues={selectedShapValues}
featureValues={selectedFeatureValues}
bind:selectedFeatureIndex={selectedFeatureIndex}
bind:selectedFeature={selectedLabel}
isHigherOutputBetter={true}
featureEncodings={featureEncodings}
/>
</div>
</div>
</div>
{:else}
<DeepDiveManager
shapValues={selectedShapValues}
featureValues={selectedFeatureValues}
selectedFeatureIndex={selectedFeatureIndex}
selectedFeature={selectedLabel}
baseValues={baseValues}
featureEncodings={featureEncodings}
isHigherOutputBetter={true}
featureNames={featureNames}
/>
{/if}
</div>
16 changes: 12 additions & 4 deletions src/xaiflow/templates/components/DeepDiveChart.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
let pointBackgroundColor;
let cumulativeValues: number;
let maxCumulativeValue: number;
let minCumulativeValue: number;

function getScreenSizeFlags() {
const width = window.innerWidth;
Expand Down Expand Up @@ -115,13 +116,14 @@
console.log("DeepDiveChart: NEW 2 Updating chart with new data", cumulativeValues);
chart.data.datasets[0].data = cumulativeValues;
maxCumulativeValue = Math.max(...cumulativeValues.map(d => d[1]));
minCumulativeValue = Math.min(...cumulativeValues.map(d => d[0]));
chart.data.datasets[0].backgroundColor = pointBackgroundColor;
// Dynamically update y-axis min and max
console.log("DeepDiveChart: Updating chart with new data", cumulativeValues, pointBackgroundColor);
if (chart.options.scales?.y) {
console.log("DeepDiveChart: Updating y-axis min and max to ", Math.floor(minOfData), Math.ceil(maxOfData * 1.05));
chart.options.scales.y.min = Math.floor(minOfData);
chart.options.scales.y.max = Math.ceil(maxOfData * 1.05);
console.log("DeepDiveChart: Updating y-axis min and max to ", Math.floor(minOfData), Math.ceil(maxOfData * 1.05), maxCumulativeValue, minCumulativeValue);
chart.options.scales.y.min = Math.floor(minCumulativeValue * 0.95);
chart.options.scales.y.max = Math.ceil(maxCumulativeValue * 1.05);
}
// Update x-axis rotation based on screen size
if (chart.options.scales?.x?.ticks) {
Expand Down Expand Up @@ -411,7 +413,13 @@
});
</script>

<canvas id="deepdive-canvas" bind:this={chartCanvas}></canvas>
<div style="position: relative; width: 100%; height: 100%;">
<canvas id="deepdive-canvas" bind:this={chartCanvas}></canvas>
<div class="deepdive-prediction-box">
<div><strong>prediction:</strong> {Math.round((base_value + singleShapValues.reduce((a, b) => a + b, 0)) * 100) / 100}</div>
<div><strong>baseline:</strong> {Math.round(base_value * 100) / 100}</div>
</div>
</div>

<style>
canvas {
Expand Down
Loading
Loading