Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

xaiflow integrates seamlessly with MLflow to generate interactive HTML reports for SHAP analysis. Instead of static charts and images, you get rich, interactive visualizations that stakeholders can explore and understand.

Here should the video go:
[![xaiflow showcase](video/video_thumbnail.png)](https://github.com/user-attachments/assets/f508fa6f-ab0f-493d-a892-ed958331e30a)
*Click the image above to watch the feature showcase video.*

## What We're Trying to Achieve

Most ML workflows produce explanations as static images or basic charts, which creates several problems:
Expand Down Expand Up @@ -38,10 +42,9 @@ with mlflow.start_run():

# Add interactive explainable AI reports
plugin = XaiflowPlugin()
plugin.log_feature_importance_report(
plugin.log_xai_report(
feature_names=X.columns.tolist(),
shap_values=shap_values,
report_name="model_explanation.html"
)
```

Expand Down Expand Up @@ -81,7 +84,7 @@ xaiflow/
### Core Components

**MLflow Integration** (`mlflow_plugin.py`)
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_feature_importance_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_xai_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.

**Report Generation** (`report_generator.py`)
The `ReportGenerator` class converts SHAP data into interactive HTML reports using Jinja2 templating. It handles template loading, asset bundling, and data injection into the frontend components.
Expand Down Expand Up @@ -127,11 +130,10 @@ feature_encodings = {
'region': {0: 'North', 1: 'South', 2: 'East', 3: 'West'}
}

plugin.log_feature_importance_report(
plugin.log_xai_report(
feature_names=feature_names,
shap_values=shap_values,
feature_encodings=feature_encodings,
report_name="enhanced_report.html"
)
```

Expand Down
21 changes: 9 additions & 12 deletions examples/notebooks/auto_mpg_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "79de156f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tobias/programming/cloudexplain/ce-mlflow-extension/ce-mlflow-extension/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"c:\\programming\\cloudexplain\\xflow\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
Expand Down Expand Up @@ -42,21 +42,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "0da5d7e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded bundle.js content (218107 characters)\n",
"Saved report data to test_report_data.json\n",
"logged to test_report.html\n",
"Feature importance report logged to MLflow: reports/test_report_auto_mpg.html\n",
"Run ID: 72ea715bfe9c42bc840388933f6999a8. If you are running mlflow locally use:\n",
"Loaded bundle.js content (225719 characters)\n",
"Feature importance report logged to MLflow: reports/feature_importance_report.html\n",
"Run ID: 7521c3f260f84a5d8e038a13bc91498b. If you are running mlflow locally use:\n",
"python -m mlflow ui --port 5000\n",
"Then open http://localhost:5000/#/experiments/921177506761828334/runs/72ea715bfe9c42bc840388933f6999a8 to view the report.\n"
"Then open http://localhost:5000/#/experiments/557047036753041520/runs/7521c3f260f84a5d8e038a13bc91498b to view the report. Note: it's important to start mlflow in the directory in which you execute the notebook.\n"
]
}
],
Expand Down Expand Up @@ -91,10 +89,9 @@
" feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},\n",
" 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},\n",
" 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}\n",
" artifact_path = plugin.log_feature_importance_report(\n",
" artifact_path = plugin.log_xai_report(\n",
" feature_names=list(X.columns),\n",
" shap_values=shap_values,\n",
" report_name=\"test_report_auto_mpg.html\",\n",
" feature_encodings=feature_encodings\n",
" )\n",
" run_id = mlflow.active_run().info.run_id\n",
Expand All @@ -119,7 +116,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,
Expand Down
3 changes: 1 addition & 2 deletions examples/notebooks/azure_ml_auto_mpg_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,9 @@
" feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},\n",
" 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},\n",
" 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}\n",
" artifact_path = plugin.log_feature_importance_report(\n",
" artifact_path = plugin.log_xai_report(\n",
" feature_names=list(X.columns),\n",
" shap_values=shap_values,\n",
" report_name=\"test_report_auto_mpg.html\",\n",
" feature_encodings=feature_encodings\n",
" )\n",
" run_id = mlflow.active_run().info.run_id\n",
Expand Down
7 changes: 4 additions & 3 deletions examples/scripts/auto_mpg_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},
'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},
'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}
artifact_path = plugin.log_feature_importance_report(
artifact_path = plugin.log_xai_report(
feature_names=list(X.columns),
shap_values=shap_values,
report_name="test_report_auto_mpg.html",
feature_encodings=feature_encodings
feature_encodings=feature_encodings,
# assign each sample to a custom group label
group_labels=["Custom Group " + str(i % 4) for i in range(len(X))],
)
run_id = mlflow.active_run().info.run_id
print(f"Run ID: {run_id}. If you are running mlflow locally use:\npython -m mlflow ui --port 5000\nThen open http://localhost:5000/#/experiments/{mlflow.get_experiment_by_name(experiment_name).experiment_id}/runs/{run_id} to view the report.",
Expand Down
2 changes: 1 addition & 1 deletion src/xaiflow/mlflow_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self):
self.template_dir = os.path.join(os.path.dirname(__file__), 'templates')
self.env = Environment(loader=FileSystemLoader(self.template_dir))

def log_feature_importance_report(
def log_xai_report(
self,
feature_names: List[str],
shap_values: Explanation,
Expand Down
6 changes: 3 additions & 3 deletions src/xflow.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ with mlflow.start_run():

# Add interactive explainable AI reports
plugin = CEMLflowPlugin()
plugin.log_feature_importance_report(
plugin.log_xai_report(
feature_names=X.columns.tolist(),
shap_values=shap_values,
report_name="model_explanation.html"
Expand Down Expand Up @@ -121,7 +121,7 @@ xaiflow/
### Core Components

**MLflow Integration** (`mlflow_plugin.py`)
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_feature_importance_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_xai_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.

**Report Generation** (`report_generator.py`)
The `ReportGenerator` class converts SHAP data into interactive HTML reports using Jinja2 templating. It handles template loading, asset bundling, and data injection into the frontend components.
Expand Down Expand Up @@ -167,7 +167,7 @@ feature_encodings = {
'region': {0: 'North', 1: 'South', 2: 'East', 3: 'West'}
}

plugin.log_feature_importance_report(
plugin.log_xai_report(
feature_names=feature_names,
shap_values=shap_values,
feature_encodings=feature_encodings,
Expand Down
11 changes: 4 additions & 7 deletions tests/test_mlflow_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from xgboost import XGBClassifier
from catboost import CatBoostClassifier
from sklearn.preprocessing import LabelEncoder
import shap
from typing import Callable
Expand Down Expand Up @@ -187,7 +185,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
mocker.patch("mlflow.log_artifact")

with mlflow.start_run(run_name="auto_mpg_test"):
plugin.log_feature_importance_report(
plugin.log_xai_report(
shap_values=shap_values,
feature_encodings=feature_encodings,
feature_names=list(X.columns),
Expand Down Expand Up @@ -263,7 +261,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
mocker.patch("mlflow.log_artifact")

with mlflow.start_run(run_name="auto_mpg_test"):
plugin.log_feature_importance_report(
plugin.log_xai_report(
shap_values=shap_values,
feature_encodings=feature_encodings,
feature_names=list(X.columns),
Expand Down Expand Up @@ -319,11 +317,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
mocker.patch("mlflow.log_artifact")

with mlflow.start_run(run_name="auto_mpg_test"):
plugin.log_feature_importance_report(
plugin.log_xai_report(
shap_values=shap_values,
feature_encodings=feature_encodings,
feature_names=list(X.columns),
group_labels=["Group 1", "Group 2", "Group 3", "Group 4"] * int(len(shap_values) / 4) # Example group labels
)
html_content_click_test(Path(output_path))
# return html_content
html_content_click_test(Path(output_path))
Binary file added video/video_thumbnail.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading