SHAP Integration

MLflow's built-in SHAP integration provides automatic model explanations and feature importance analysis during evaluation. SHAP (SHapley Additive exPlanations) values help you understand what drives your model's predictions, making your ML models more interpretable and trustworthy.

Quick Start: Automatic SHAP Explanations

Enable SHAP explanations during model evaluation with a simple configuration:

import mlflow
import xgboost as xgb
import shap
from sklearn.model_selection import train_test_split
from mlflow.models import infer_signature

# Load the UCI Adult Dataset
X, y = shap.datasets.adult()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)

# Train model
model = xgb.XGBClassifier().fit(X_train, y_train)

# Create evaluation dataset
eval_data = X_test.copy()
eval_data["label"] = y_test

with mlflow.start_run():
    # Log model
    signature = infer_signature(X_test, model.predict(X_test))
    mlflow.sklearn.log_model(model, name="model", signature=signature)
    model_uri = mlflow.get_artifact_uri("model")

    # Evaluate with SHAP explanations enabled
    result = mlflow.evaluate(
        model_uri,
        eval_data,
        targets="label",
        model_type="classifier",
        evaluators=["default"],
        evaluator_config={"log_explainer": True},  # Enable SHAP logging
    )

print("SHAP artifacts generated:")
for artifact_name in result.artifacts:
    if "shap" in artifact_name.lower():
        print(f"  - {artifact_name}")

This automatically generates:

  • Feature importance plots showing which features matter most
  • SHAP summary plots displaying feature impact distributions
  • SHAP explainer model saved for future use on new data
  • Individual prediction explanations for sample predictions
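
To pull the generated plots out of the tracking server, for example to embed them in a report, something like the following sketch works with the artifacts returned by mlflow.evaluate (each entry exposes a uri attribute; exact artifact names may vary between MLflow versions):

# Sketch: download the SHAP plots logged during evaluation
for name, artifact in result.artifacts.items():
    if "shap" in name.lower():
        local_path = mlflow.artifacts.download_artifacts(artifact_uri=artifact.uri)
        print(f"Downloaded {name} to {local_path}")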

Understanding SHAP Outputs

Feature Importance Visualization

MLflow automatically creates SHAP-based feature importance charts:

# The evaluation generates several SHAP visualizations:
# - shap_feature_importance_plot.png: Bar chart of average feature importance
# - shap_summary_plot.png: Dot plot showing feature impact distribution
# - explainer model: Saved SHAP explainer for generating new explanations

# Access the results
print(f"Model accuracy: {result.metrics['accuracy_score']:.3f}")
print("Generated SHAP artifacts:")
for name, path in result.artifacts.items():
    if "shap" in name:
        print(f"  {name}: {path}")

SHAP feature importance plot logged to MLflow when SHAP evaluation is enabled

Configuring SHAP Explanations

Control how SHAP explanations are generated:

# Advanced SHAP configuration
shap_config = {
    "log_explainer": True,  # Save the explainer model
    "explainer_type": "exact",  # Use exact SHAP values (slower but precise)
    "max_error_examples": 100,  # Number of error cases to explain
    "log_model_explanations": True,  # Log individual prediction explanations
}

result = mlflow.evaluate(
    model_uri,
    eval_data,
    targets="label",
    model_type="classifier",
    evaluators=["default"],
    evaluator_config=shap_config,
)

Configuration Options

Explainer Types

  • "exact": Precise SHAP values using the exact algorithm (slower)
  • "permutation": Permutation-based explanations (faster, approximate)
  • "partition": Tree-based explanations for tree models

Output Control

  • log_explainer: Whether to save the SHAP explainer as a model
  • max_error_examples: Number of misclassified examples to explain in detail
  • log_model_explanations: Whether to log explanations for individual predictions
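
For example, when exact values are too slow on a wide dataset, a faster approximate configuration using the same options might look like this sketch:

# Sketch: trade precision for speed using the options described above
fast_shap_config = {
    "log_explainer": True,
    "explainer_type": "permutation",  # approximate, but much faster than "exact"
    "log_model_explanations": False,  # skip per-prediction explanations to save time
}

result = mlflow.evaluate(
    model_uri,
    eval_data,
    targets="label",
    model_type="classifier",
    evaluators=["default"],
    evaluator_config=fast_shap_config,
)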

Working with SHAP Explainers

Once logged, you can load and use SHAP explainers on new data:

# Load the saved SHAP explainer
run_id = "your_run_id_here"
explainer_uri = f"runs:/{run_id}/explainer"

# Load explainer
explainer = mlflow.pyfunc.load_model(explainer_uri)

# Generate explanations for new data
new_data = X_test[:10] # Example: first 10 samples
explanations = explainer.predict(new_data)

print(f"Generated explanations shape: {explanations.shape}")
print(f"Feature contributions for first prediction: {explanations[0]}")

# The explanations array contains SHAP values for each feature and prediction

Interpreting SHAP Values

def interpret_shap_explanations(explanations, feature_names, sample_idx=0):
    """Interpret SHAP explanations for a specific prediction."""

    sample_explanations = explanations[sample_idx]

    # Sort features by absolute importance
    feature_importance = list(zip(feature_names, sample_explanations))
    feature_importance.sort(key=lambda x: abs(x[1]), reverse=True)

    print(f"SHAP explanation for sample {sample_idx}:")
    print("Top 5 most important features:")

    for i, (feature, importance) in enumerate(feature_importance[:5]):
        direction = "increases" if importance > 0 else "decreases"
        print(f"  {i+1}. {feature}: {importance:.3f} ({direction} prediction)")

    return feature_importance


# Usage
feature_names = X_test.columns.tolist()
top_features = interpret_shap_explanations(explanations, feature_names, sample_idx=0)

Production SHAP Workflows

Generate explanations for large datasets efficiently:

def batch_shap_explanations(model_uri, data_path, batch_size=1000):
    """Generate SHAP explanations for large datasets in batches."""

    import pandas as pd

    with mlflow.start_run(run_name="Batch_SHAP_Generation"):
        # Load the model once for the whole job
        model = mlflow.pyfunc.load_model(model_uri)

        # pandas.read_parquet has no chunked-read option, so load the file once
        # and slice the DataFrame into batches manually
        data = pd.read_parquet(data_path)

        batch_results = []
        total_samples = 0

        for chunk_idx, start in enumerate(range(0, len(data), batch_size)):
            data_chunk = data.iloc[start : start + batch_size]

            # Generate explanations for this batch
            explanations = generate_explanations(model, data_chunk)

            # Store results
            batch_results.append(
                {
                    "batch_idx": chunk_idx,
                    "explanations": explanations,
                    "sample_count": len(data_chunk),
                }
            )

            total_samples += len(data_chunk)

            # Log progress
            if chunk_idx % 10 == 0:
                print(f"Processed {total_samples} samples...")

        # Log batch processing summary
        mlflow.log_params(
            {
                "total_batches": len(batch_results),
                "total_samples": total_samples,
                "batch_size": batch_size,
            }
        )

    return batch_results


def generate_explanations(model, data):
    """Generate SHAP explanations (placeholder - implement based on your model type)."""
    import numpy as np

    # This would contain your actual SHAP explanation logic;
    # mock data is returned here to keep the example self-contained
    return np.random.random((len(data), data.shape[1]))
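
In practice, the generate_explanations placeholder could wrap a real SHAP explainer. A minimal sketch, assuming the loaded pyfunc model's predict function is cheap enough to call repeatedly and that the shap package is installed:

import shap


def generate_explanations(model, data):
    """Generate SHAP values with a model-agnostic explainer (sketch)."""
    # Use a small background sample to keep the explainer tractable
    background = shap.sample(data, min(100, len(data)))

    # The loaded pyfunc model exposes a predict() callable that SHAP can wrap
    explainer = shap.Explainer(model.predict, background)
    shap_values = explainer(data)

    # shap_values.values has shape (n_samples, n_features)
    return shap_values.values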

Best Practices and Use Cases

When to Use SHAP Integration

SHAP integration provides the most value in these scenarios:

High Interpretability Requirements - Healthcare and medical diagnosis systems, financial services (credit scoring, loan approval), legal and compliance applications, hiring and HR decision systems, and fraud detection and risk assessment.

Complex Model Types - XGBoost, Random Forest, and other ensemble methods, neural networks and deep learning models, custom ensemble approaches, and any model where feature relationships are non-obvious.

Regulatory and Compliance Needs - Models requiring explainability for regulatory approval, systems where decisions must be justified to stakeholders, applications where bias detection is important, and audit trails requiring detailed decision explanations.

Performance Considerations

Dataset Size Guidelines:

  • Small datasets (< 1,000 samples): Use exact SHAP methods for precision
  • Medium datasets (1,000 - 50,000 samples): Standard SHAP analysis works well
  • Large datasets (50,000+ samples): Consider sampling or approximate methods
  • Very large datasets (100,000+ samples): Use batch processing with sampling

Memory Management:

  • Process explanations in batches for large datasets
  • Use approximate SHAP methods when exact precision isn't required
  • Clear intermediate results to manage memory usage
  • Consider model-specific optimizations (e.g., TreeExplainer for tree models), as shown in the sketch below
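
The following sketch illustrates the sampling and model-specific points above. It assumes the XGBoost classifier and X_test from the quick start and uses shap.TreeExplainer directly; treat it as a starting point rather than the only approach.

import shap

# Assumption: reuse the XGBoost `model` and `X_test` from the quick start above.
# Sample large datasets before explaining to keep runtime and memory manageable.
X_sample = X_test.sample(n=min(5000, len(X_test)), random_state=42)

# TreeExplainer exploits the tree structure of XGBoost/Random Forest models,
# making it far faster than model-agnostic explainers
tree_explainer = shap.TreeExplainer(model)
shap_values = tree_explainer.shap_values(X_sample)

print(f"Computed SHAP values for {len(X_sample)} sampled rows")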

Integration with MLflow Model Registry

SHAP explainers can be stored and versioned alongside your models:

def register_model_with_explainer(model_uri, explainer_uri, model_name):
    """Register both model and explainer in MLflow Model Registry."""

    from mlflow.tracking import MlflowClient

    client = MlflowClient()

    # Register the main model
    model_version = mlflow.register_model(model_uri, model_name)

    # Register the explainer as a separate model
    explainer_name = f"{model_name}_explainer"
    explainer_version = mlflow.register_model(explainer_uri, explainer_name)

    # Add tags to link them
    client.set_model_version_tag(
        model_name, model_version.version, "explainer_model", explainer_name
    )

    client.set_model_version_tag(
        explainer_name, explainer_version.version, "base_model", model_name
    )

    return model_version, explainer_version


# Usage
# model_ver, explainer_ver = register_model_with_explainer(
#     model_uri, explainer_uri, "my_classifier"
# )
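
Once both are registered, the explainer can be loaded straight from the registry by name, for example (a sketch using the my_classifier_explainer name produced by the call above; the version number is illustrative):

# Load a registered SHAP explainer by name and version (sketch)
registered_explainer = mlflow.pyfunc.load_model("models:/my_classifier_explainer/1")

# Use it exactly like the run-based explainer shown earlier
explanations = registered_explainer.predict(X_test[:10])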

Conclusion

MLflow's SHAP integration provides automatic model interpretability without additional setup complexity. By enabling SHAP explanations during evaluation, you gain valuable insights into feature importance and model behavior that are essential for building trustworthy ML systems.

Key benefits include:

  • Automatic Generation: SHAP explanations created during standard model evaluation
  • Production Ready: Saved explainers can generate explanations for new data
  • Visual Insights: Automatic generation of feature importance and summary plots
  • Model Comparison: Compare interpretability across different model types

SHAP integration is particularly valuable for regulated industries, high-stakes decisions, and complex models where understanding "why" is as important as "what" the model predicts.