Spark MLlib with MLflow
In this guide, we'll walk through how to use Spark MLlib with MLflow for experiment tracking, model management, and production deployment. We'll cover basic model logging, pipeline tracking, and deployment patterns that will quickly get you productive with distributed machine learning.
Quick Start with Basic Model Logging
The simplest way to get started is by logging your Spark MLlib models directly to MLflow:
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()
# Prepare training data
training = spark.createDataFrame(
[
(0, "a b c d e spark", 1.0),
(1, "b d", 0.0),
(2, "spark f g h", 1.0),
(3, "hadoop mapreduce", 0.0),
],
["id", "text", "label"],
)
# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
# Train and log the model
with mlflow.start_run():
model = pipeline.fit(training)
# Log the entire pipeline
model_info = mlflow.spark.log_model(
spark_model=model, artifact_path="spark-pipeline"
)
# Log parameters manually
mlflow.log_params(
{
"max_iter": lr.getMaxIter(),
"reg_param": lr.getRegParam(),
"num_features": hashingTF.getNumFeatures(),
}
)
print(f"Model logged with URI: {model_info.model_uri}")
This simple example captures:
- The complete Spark ML pipeline with all of its stages
- The key training parameters, logged explicitly with mlflow.log_params
- The trained model in both the native Spark and PyFunc formats
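You can confirm what was captured by reading the run back through the tracking API. The following is a minimal sketch, assuming the model_info object returned by mlflow.spark.log_model above is still in scope:
# Inspect the run that backs the logged model
run = mlflow.get_run(model_info.run_id)
print("Logged params:", run.data.params)
print("Artifact URI:", run.info.artifact_uri)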
Model Formats and Loading
- Native Spark Format
- PyFunc Format
- Format Conversion
The native Spark format preserves the full functionality of your Spark ML pipeline:
# Load as native Spark model (requires Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)
# Use for distributed batch scoring
test_data = spark.createDataFrame(
[(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
["id", "text"],
)
predictions = spark_model.transform(test_data)
predictions.show()
Best for: Large-scale batch processing, existing Spark infrastructure
The PyFunc format enables deployment outside of Spark environments:
import pandas as pd
# Load as PyFunc model (no Spark session required)
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
# Use with pandas DataFrame
test_data = pd.DataFrame(
{"text": ["spark machine learning", "hadoop distributed computing"]}
)
# Predictions work seamlessly
predictions = pyfunc_model.predict(test_data)
print(predictions)
Best for: REST API deployment, edge computing, non-Spark environments
Automatic Conversion Process:
- Input Handling: PyFunc automatically converts pandas DataFrames to Spark DataFrames
- Spark Context: Creates a local Spark session if none exists
- Output Processing: Converts Spark ML vector outputs to arrays for pandas compatibility
- Performance Trade-offs: Initialization overhead vs deployment flexibility
When to Use Each Format:
- Native Spark: Large-scale batch processing, existing Spark infrastructure
- PyFunc: REST API deployment, edge computing, non-Spark environments
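As a concrete example of the REST path, you can serve the logged model with the MLflow scoring server and call it over HTTP. This is a minimal sketch assuming an MLflow 2.x scoring server started separately; the model URI placeholder, host, port, and environment-manager flag may need adjusting for your setup:
# In a separate shell, start a local scoring server for the logged model, e.g.:
#   mlflow models serve -m <model_uri> --port 5000 --env-manager local
import requests

payload = {
    "dataframe_records": [
        {"text": "spark machine learning"},
        {"text": "hadoop distributed computing"},
    ]
}

response = requests.post("http://127.0.0.1:5000/invocations", json=payload, timeout=30)
print(response.json())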
Pipeline Tracking and Management
- Basic Pipeline
- Hyperparameter Tuning
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
# Load your dataset
data = spark.read.csv("path/to/dataset.csv", header=True, inferSchema=True)
with mlflow.start_run(run_name="Feature Pipeline"):
# Create feature engineering pipeline
feature_cols = ["age", "income", "credit_score"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="raw_features")
scaler = StandardScaler(inputCol="raw_features", outputCol="features")
rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=100)
# Create complete pipeline
pipeline = Pipeline(stages=[assembler, scaler, rf])
# Train pipeline
model = pipeline.fit(data)
# Log pipeline parameters
mlflow.log_params(
{
"num_features": len(feature_cols),
"num_trees": rf.getNumTrees(),
"max_depth": rf.getMaxDepth(),
}
)
# Log the complete pipeline
mlflow.spark.log_model(spark_model=model, artifact_path="feature_pipeline")
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
with mlflow.start_run(run_name="Hyperparameter Tuning"):
# Create base pipeline
lr = LogisticRegression(featuresCol="features", labelCol="label")
pipeline = Pipeline(stages=[assembler, scaler, lr])
# Create parameter grid
param_grid = (
ParamGridBuilder()
.addGrid(lr.regParam, [0.01, 0.1, 1.0])
.addGrid(lr.maxIter, [50, 100])
.build()
)
# Create cross-validator
evaluator = BinaryClassificationEvaluator(labelCol="label")
cv = CrossValidator(
estimator=pipeline,
estimatorParamMaps=param_grid,
evaluator=evaluator,
numFolds=3,
)
# Fit cross-validator
cv_model = cv.fit(train_data)
# Log best parameters
best_model = cv_model.bestModel
best_lr_stage = best_model.stages[-1]
mlflow.log_params(
{
"best_regParam": best_lr_stage.getRegParam(),
"best_maxIter": best_lr_stage.getMaxIter(),
}
)
# Evaluate on test set
test_predictions = cv_model.transform(test_data)
test_auc = evaluator.evaluate(test_predictions)
mlflow.log_metric("test_auc", test_auc)
# Log the best model
mlflow.spark.log_model(
spark_model=cv_model.bestModel, artifact_path="best_cv_model"
)
Spark Datasource Autologging
- Enable Autologging
- Setup Requirements
MLflow provides automatic logging of Spark datasource information:
import mlflow.spark
# Enable Spark datasource autologging
mlflow.spark.autolog()
# Now all datasource reads are automatically logged
with mlflow.start_run():
# These datasource operations are automatically tracked
raw_data = spark.read.parquet("s3://my-bucket/training-data/")
processed_data = spark.read.csv(
"hdfs://cluster/processed/features.csv", header=True
)
# Train your model - datasource info is logged automatically
model = pipeline.fit(processed_data)
# Model training and datasource information both captured
mlflow.spark.log_model(model, artifact_path="model_with_datasource_tracking")
Requirements:
- Spark Version: Requires Spark 3.0 or above
- MLflow-Spark JAR: Must be included in Spark session configuration
- Environment: Not supported on Databricks shared/serverless clusters
JAR Configuration:
from pyspark.sql import SparkSession
# Configure Spark session with MLflow JAR
spark = (
SparkSession.builder.appName("MLflowAutologgingApp")
.config("spark.jars.packages", "org.mlflow:mlflow-spark_2.12:2.16.2")
.getOrCreate()
)
What Gets Logged:
- Path Information: Complete paths to data sources
- Format Details: File formats (parquet, delta, csv, etc.)
- Version Information: For versioned sources like Delta Lake
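After a run finishes, the captured datasource information is attached to that run as tags, which you can read back through the tracking API. A minimal sketch follows; because the exact tag key is an implementation detail, it filters tags by name rather than hard-coding one:
# Fetch the run that just completed and print any datasource-related tags
last_run = mlflow.last_active_run()
if last_run is not None:
    for key, value in last_run.data.tags.items():
        if "datasource" in key.lower():
            print(f"{key}: {value}")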
Model Signatures and Schema Management
- Spark ML Vectors
- Manual Signatures
from mlflow.models import infer_signature
from pyspark.ml.linalg import Vectors
from pyspark.ml.functions import array_to_vector
# Create data with vector features
vector_data = spark.createDataFrame(
[([3.0, 4.0], 0), ([5.0, 6.0], 1)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")
# Train model
lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)
# Get predictions for signature
predictions = model.transform(vector_data)
# Infer signature automatically
signature = infer_signature(vector_data, predictions.select("prediction"))
with mlflow.start_run():
mlflow.spark.log_model(
spark_model=model,
artifact_path="vector_model",
signature=signature,
input_example=vector_data.limit(2).toPandas(),
)
from mlflow.types import DataType, Schema, ColSpec
from mlflow.types.schema import SparkMLVector
from mlflow.models.signature import ModelSignature
# Create detailed model signature
input_schema = Schema(
[
ColSpec(DataType.string, "text"),
ColSpec(DataType.double, "numeric_feature"),
ColSpec(SparkMLVector(), "vector_feature"),
]
)
output_schema = Schema([ColSpec(DataType.double, "prediction")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# Log model with explicit signature
with mlflow.start_run():
mlflow.spark.log_model(
spark_model=model, artifact_path="production_model", signature=signature
)
Cross-Platform Deployment
- ONNX Conversion
- Model Registry
Convert Spark MLlib models to ONNX format for cross-platform deployment:
# Note: This requires onnxmltools (Spark ML support is experimental)
# pip install onnxmltools
import onnxmltools
with mlflow.start_run(run_name="ONNX Conversion"):
# Train your Spark ML model
model = pipeline.fit(training_data)
# Log original Spark model
spark_model_info = mlflow.spark.log_model(
spark_model=model, artifact_path="spark_model"
)
try:
# Convert to ONNX using onnxmltools
# Note: Spark ML conversion is experimental and may have limitations
onnx_model = onnxmltools.convert_sparkml(
model, name="SparkMLPipeline", target_opset=None # Use default opset
)
# Save ONNX model as artifact
onnx_model_path = "model.onnx"
onnxmltools.utils.save_model(onnx_model, onnx_model_path)
mlflow.log_artifact(onnx_model_path)
mlflow.log_param("onnx_conversion_successful", True)
# Log ONNX model info
opset_version = onnx_model.opset_import[0].version
mlflow.log_param("onnx_opset_version", opset_version)
except Exception as e:
mlflow.log_param("onnx_conversion_error", str(e))
mlflow.log_param("onnx_conversion_successful", False)
# ONNX conversion for Spark ML is experimental and may not work
# for all model types. Consider using PyFunc format instead.
Note: Spark ML to ONNX conversion is experimental in onnxmltools and may not support all Spark ML operators. For production deployments, consider using MLflow's PyFunc format for broader compatibility.
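If the conversion succeeds, the saved artifact can be scored outside Spark with ONNX Runtime. The sketch below assumes onnxruntime is installed and uses the local model.onnx file written above; the input name and feature width in the commented call are hypothetical and should be taken from the session's reported inputs:
import numpy as np
import onnxruntime as ort

# Open the converted model and inspect the inputs it expects
session = ort.InferenceSession("model.onnx")
for model_input in session.get_inputs():
    print(model_input.name, model_input.shape, model_input.type)

# Hypothetical call once the real input name and feature width are known:
# outputs = session.run(None, {"features": np.array([[3.0, 4.0]], dtype=np.float32)})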
from mlflow import MlflowClient
client = MlflowClient()
# Register model with production metadata
with mlflow.start_run():
# Train and evaluate model
model = pipeline.fit(train_data)
# Log model with registration
model_info = mlflow.spark.log_model(
spark_model=model,
artifact_path="production_candidate",
registered_model_name="CustomerSegmentationModel",
)
# Add production readiness tags
mlflow.set_tags(
{
"validation_passed": "true",
"deployment_target": "batch_scoring",
"model_type": "classification",
}
)
# Promote model through stages
model_version = client.get_latest_versions(
"CustomerSegmentationModel", stages=["None"]
)[0]
# Transition to Staging
client.transition_model_version_stage(
name="CustomerSegmentationModel", version=model_version.version, stage="Staging"
)
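Once a version is in Staging, downstream jobs can load it by registry URI instead of a run-relative path. A small sketch; the commented alias lines reflect the newer alias-based workflow and are optional:
# Load the staged model through the registry URI
staged_model = mlflow.spark.load_model("models:/CustomerSegmentationModel/Staging")

# Newer MLflow releases favor aliases over stages:
# client.set_registered_model_alias(
#     "CustomerSegmentationModel", "champion", model_version.version
# )
# champion_model = mlflow.spark.load_model("models:/CustomerSegmentationModel@champion")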
Production Deployment
- Batch Inference
- Model Evaluation
from pyspark.sql import functions as F  # used below for current_timestamp()


def production_batch_scoring(model_uri, input_path, output_path):
"""Simple production batch scoring pipeline."""
with mlflow.start_run(run_name="Batch_Scoring"):
# Load production model
model = mlflow.spark.load_model(model_uri)
# Load input data
input_data = spark.read.parquet(input_path)
# Generate predictions
predictions = model.transform(input_data)
# Add metadata
predictions_with_metadata = predictions.withColumn(
"prediction_timestamp", F.current_timestamp()
)
# Write predictions
(predictions_with_metadata.write.mode("overwrite").parquet(output_path))
# Log job metrics
record_count = predictions.count()
mlflow.log_metrics({"records_processed": record_count, "job_success": 1})
return output_path
# Usage
production_batch_scoring(
model_uri="models:/CustomerSegmentationModel/Production",
input_path="s3://data-lake/daily-customers/",
output_path="s3://predictions/customer-segments/",
)
def evaluate_spark_model(model, test_data, model_name):
"""Evaluate Spark ML model with comprehensive metrics."""
with mlflow.start_run(run_name=f"Evaluation_{model_name}"):
# Generate predictions
predictions = model.transform(test_data)
# Calculate metrics based on problem type
from pyspark.ml.evaluation import (
BinaryClassificationEvaluator,
MulticlassClassificationEvaluator,
)
# Binary classification metrics
binary_evaluator = BinaryClassificationEvaluator(labelCol="label")
auc = binary_evaluator.evaluate(predictions)
# Multiclass metrics
mc_evaluator = MulticlassClassificationEvaluator(labelCol="label")
accuracy = mc_evaluator.evaluate(
predictions, {mc_evaluator.metricName: "accuracy"}
)
f1 = mc_evaluator.evaluate(predictions, {mc_evaluator.metricName: "f1"})
# Log evaluation metrics
mlflow.log_metrics({"auc": auc, "accuracy": accuracy, "f1_score": f1})
# Feature importance (if available)
if hasattr(model.stages[-1], "featureImportances"):
feature_importance = model.stages[-1].featureImportances.toArray()
# Log top 5 feature importances
for i, importance in enumerate(feature_importance[:5]):
mlflow.log_metric(f"feature_importance_{i}", importance)
return {"auc": auc, "accuracy": accuracy, "f1_score": f1}
# Usage
evaluation_results = evaluate_spark_model(model, test_data, "RandomForest")
Error Handling and Best Practices
- Robust Training
- Troubleshooting
- Performance Tips
def train_spark_model_with_error_handling(data_path, model_config):
"""Production-ready model training with error handling."""
with mlflow.start_run(run_name="Robust_Training"):
try:
# Load and validate data
data = spark.read.parquet(data_path)
record_count = data.count()
if record_count == 0:
raise ValueError("Input dataset is empty")
mlflow.log_metric("input_record_count", record_count)
# Create and train pipeline
pipeline = create_pipeline(model_config)
model = pipeline.fit(data)
# Validate model can make predictions
test_sample = data.limit(10)
predictions = model.transform(test_sample)
prediction_count = predictions.count()
if prediction_count != 10:
raise ValueError("Model validation failed")
# Log successful model
model_info = mlflow.spark.log_model(
spark_model=model, artifact_path="robust_model"
)
mlflow.log_param("training_status", "success")
return model_info
except Exception as e:
# Log error information
mlflow.log_param("training_status", "failed")
mlflow.log_param("error_message", str(e))
raise
def create_pipeline(config):
"""Create ML pipeline from configuration."""
# Simple pipeline creation logic
feature_cols = config.get("feature_columns", [])
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
algorithm = config.get("algorithm", "logistic_regression")
if algorithm == "logistic_regression":
classifier = LogisticRegression(featuresCol="features", labelCol="label")
elif algorithm == "random_forest":
classifier = RandomForestClassifier(featuresCol="features", labelCol="label")
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
return Pipeline(stages=[assembler, classifier])
Common Issues and Solutions:
Serialization Issues:
# Test model serialization
import tempfile


def test_model_serialization(model):
try:
with tempfile.TemporaryDirectory() as temp_dir:
mlflow.spark.save_model(model, temp_dir)
loaded_model = mlflow.spark.load_model(temp_dir)
return True
except Exception as e:
print(f"Serialization failed: {e}")
return False
Memory Issues:
# Configure Spark for large models
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
# Cache strategically
large_dataset.cache() # Only when reused multiple times
Efficient Logging:
def efficient_spark_ml_logging():
"""Configure efficient logging for Spark ML."""
with mlflow.start_run():
# Log parameters early (lightweight)
mlflow.log_params({"algorithm": "random_forest", "num_trees": 100})
# Train model
model = pipeline.fit(large_dataset)
# Log metrics before model (in case model logging fails)
metrics = {"accuracy": 0.95}
mlflow.log_metrics(metrics)
# Log model with minimal examples for large datasets
mlflow.spark.log_model(
spark_model=model,
artifact_path="efficient_model",
input_example=sample_data.limit(3).toPandas(), # Small sample only
)
Spark Configuration:
# Optimize Spark for MLflow operations
def configure_spark_for_mlflow():
"""Configure Spark session for optimal MLflow performance."""
spark_config = {
"spark.sql.adaptive.enabled": "true",
"spark.sql.adaptive.coalescePartitions.enabled": "true",
"spark.serializer": "org.apache.spark.serializer.KryoSerializer",
}
for key, value in spark_config.items():
spark.conf.set(key, value)
Conclusion
MLflow's Spark MLlib integration provides a comprehensive solution for tracking and managing distributed machine learning workflows. Whether you're building simple classification models or complex multi-stage pipelines, MLflow helps you maintain reproducibility and deploy models efficiently.
Key benefits of using MLflow with Spark MLlib include:
- Complete Pipeline Tracking: Automatic logging of multi-stage ML pipelines with all parameters and artifacts
- Flexible Deployment: Deploy as native Spark models for batch processing or PyFunc wrappers for universal compatibility
- Data Lineage: Automatic tracking of data sources through Spark datasource autologging
- Cross-Platform Support: ONNX conversion enables deployment across different environments
- Production Ready: Model registry integration and robust error handling for enterprise deployments
The patterns shown in this guide provide a solid foundation for building scalable, reproducible distributed machine learning systems. Start with basic model logging for immediate experiment tracking benefits, then adopt advanced features like datasource autologging and model registry integration as your needs grow.