Spark MLlib with MLflow

In this comprehensive guide, we'll walk you through how to use Spark MLlib with MLflow for experiment tracking, model management, and production deployment. We'll cover basic model logging, pipeline tracking, and deployment patterns that will get you productive quickly with distributed machine learning.

Quick Start with Basic Model Logging

The simplest way to get started is by logging your Spark MLlib models directly to MLflow:

import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()

# Prepare training data
training = spark.createDataFrame(
    [
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
    ],
    ["id", "text", "label"],
)

# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Train and log the model
with mlflow.start_run():
    model = pipeline.fit(training)

    # Log the entire pipeline
    model_info = mlflow.spark.log_model(
        spark_model=model, artifact_path="spark-pipeline"
    )

    # Log parameters manually
    mlflow.log_params(
        {
            "max_iter": lr.getMaxIter(),
            "reg_param": lr.getRegParam(),
            "num_features": hashingTF.getNumFeatures(),
        }
    )

print(f"Model logged with URI: {model_info.model_uri}")

This simple example captures:

  • The complete Spark ML pipeline with all of its stages
  • The key hyperparameters, logged explicitly with mlflow.log_params
  • The trained model in both the native Spark flavor and the PyFunc flavor

Model Formats and Loading

The native Spark format preserves the full functionality of your Spark ML pipeline:

# Load as native Spark model (requires Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)

# Use for distributed batch scoring
test_data = spark.createDataFrame(
    [(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
    ["id", "text"],
)

predictions = spark_model.transform(test_data)
predictions.show()

Best for: Large-scale batch processing, existing Spark infrastructure
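
The same logged model can also be loaded through the PyFunc flavor, which accepts a pandas DataFrame and returns the predictions. Below is a minimal sketch reusing model_info from the quick start above; the wrapper converts the pandas input to a Spark DataFrame internally, so a local Spark session is still created behind the scenes.

import mlflow.pyfunc
import pandas as pd

# Load the PyFunc wrapper around the logged Spark pipeline
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Score a small pandas DataFrame; column names must match the training schema
pandas_input = pd.DataFrame({"id": [8, 9], "text": ["spark streaming", "mapreduce job"]})
print(pyfunc_model.predict(pandas_input))

Best for: Lightweight scoring in services and notebooks that standardize on the PyFunc interface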

Pipeline Tracking and Management

from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier

# Load your dataset
data = spark.read.csv("path/to/dataset.csv", header=True, inferSchema=True)

with mlflow.start_run(run_name="Feature Pipeline"):
# Create feature engineering pipeline
feature_cols = ["age", "income", "credit_score"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="raw_features")
scaler = StandardScaler(inputCol="raw_features", outputCol="features")
rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=100)

# Create complete pipeline
pipeline = Pipeline(stages=[assembler, scaler, rf])

# Train pipeline
model = pipeline.fit(data)

# Log pipeline parameters
mlflow.log_params(
{
"num_features": len(feature_cols),
"num_trees": rf.getNumTrees(),
"max_depth": rf.getMaxDepth(),
}
)

# Log the complete pipeline
mlflow.spark.log_model(spark_model=model, artifact_path="feature_pipeline")
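
Parameters alone don't make runs comparable; metrics do. A short sketch of logging an evaluation metric within the same run, assuming the data DataFrame above has a binary label column:

from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Hold out a test split, evaluate the fitted pipeline, and log the result
train_df, test_df = data.randomSplit([0.8, 0.2], seed=42)
model = pipeline.fit(train_df)

evaluator = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
auc = evaluator.evaluate(model.transform(test_df))
mlflow.log_metric("test_auc", auc)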

Spark Datasource Autologging

MLflow can automatically log information about the Spark datasources read during a run, including the path, format, and version where available:

import mlflow.spark

# Enable Spark datasource autologging
mlflow.spark.autolog()

# Now all datasource reads are automatically logged
with mlflow.start_run():
    # These datasource operations are automatically tracked
    raw_data = spark.read.parquet("s3://my-bucket/training-data/")
    processed_data = spark.read.csv(
        "hdfs://cluster/processed/features.csv", header=True
    )

    # Train your model - datasource info is logged automatically
    model = pipeline.fit(processed_data)

    # Model training and datasource information both captured
    mlflow.spark.log_model(model, artifact_path="model_with_datasource_tracking")
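
The captured datasource information is stored as tags on the run rather than as parameters. A quick way to check what was recorded is to print the run's tags; this is a sketch, and the exact tag key (commonly sparkDatasourceInfo) may vary across MLflow versions.

from mlflow.tracking import MlflowClient

# Fetch the most recent run and print its tags, including any datasource info
client = MlflowClient()
run = client.get_run(mlflow.last_active_run().info.run_id)
for key, value in run.data.tags.items():
    print(key, "=", value)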

Model Signatures and Schema Management

from mlflow.models import infer_signature
from pyspark.ml.linalg import Vectors
from pyspark.ml.functions import array_to_vector

# Create data with vector features
vector_data = spark.createDataFrame(
    [([3.0, 4.0], 0), ([5.0, 6.0], 1)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")

# Train model
lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)

# Get predictions for signature
predictions = model.transform(vector_data)

# Infer signature automatically
signature = infer_signature(vector_data, predictions.select("prediction"))

with mlflow.start_run():
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="vector_model",
        signature=signature,
        input_example=vector_data.limit(2).toPandas(),
    )
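
Signatures can also be declared by hand when automatic inference is inconvenient, for example for the text pipeline from the quick start. A minimal sketch using MLflow's schema classes:

from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema

# Describe a string "text" input and a double "prediction" output explicitly
manual_signature = ModelSignature(
    inputs=Schema([ColSpec("string", "text")]),
    outputs=Schema([ColSpec("double", "prediction")]),
)

Pass it as the signature argument to mlflow.spark.log_model exactly as you would the inferred one.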

Cross-Platform Deployment

Convert Spark MLlib models to ONNX format for cross-platform deployment:

# Note: This requires onnxmltools (Spark ML support is experimental)
# pip install onnxmltools

import onnxmltools

with mlflow.start_run(run_name="ONNX Conversion"):
# Train your Spark ML model
model = pipeline.fit(training_data)

# Log original Spark model
spark_model_info = mlflow.spark.log_model(
spark_model=model, artifact_path="spark_model"
)

try:
# Convert to ONNX using onnxmltools
# Note: Spark ML conversion is experimental and may have limitations
onnx_model = onnxmltools.convert_sparkml(
model, name="SparkMLPipeline", target_opset=None # Use default opset
)

# Save ONNX model as artifact
onnx_model_path = "model.onnx"
onnxmltools.utils.save_model(onnx_model, onnx_model_path)

mlflow.log_artifact(onnx_model_path)
mlflow.log_param("onnx_conversion_successful", True)

# Log ONNX model info
opset_version = onnx_model.opset_import[0].version
mlflow.log_param("onnx_opset_version", opset_version)

except Exception as e:
mlflow.log_param("onnx_conversion_error", str(e))
mlflow.log_param("onnx_conversion_successful", False)

# ONNX conversion for Spark ML is experimental and may not work
# for all model types. Consider using PyFunc format instead.

Note: Spark ML to ONNX conversion is experimental in onnxmltools and may not support all Spark ML operators. For production deployments, consider using MLflow's PyFunc format for broader compatibility.

Production Deployment

from pyspark.sql import functions as F


def production_batch_scoring(model_uri, input_path, output_path):
    """Simple production batch scoring pipeline."""

    with mlflow.start_run(run_name="Batch_Scoring"):
        # Load production model
        model = mlflow.spark.load_model(model_uri)

        # Load input data
        input_data = spark.read.parquet(input_path)

        # Generate predictions
        predictions = model.transform(input_data)

        # Add metadata
        predictions_with_metadata = predictions.withColumn(
            "prediction_timestamp", F.current_timestamp()
        )

        # Write predictions
        predictions_with_metadata.write.mode("overwrite").parquet(output_path)

        # Log job metrics
        record_count = predictions.count()
        mlflow.log_metrics({"records_processed": record_count, "job_success": 1})

    return output_path


# Usage
production_batch_scoring(
    model_uri="models:/CustomerSegmentationModel/Production",
    input_path="s3://data-lake/daily-customers/",
    output_path="s3://predictions/customer-segments/",
)
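
The models:/CustomerSegmentationModel/Production URI assumes the model has already been registered under that name. One way to do that is at logging time, via the registered_model_name argument; the sketch below uses an illustrative model name and artifact path.

with mlflow.start_run(run_name="Register_Model"):
    model_info = mlflow.spark.log_model(
        spark_model=model,
        artifact_path="customer_segmentation_model",
        registered_model_name="CustomerSegmentationModel",
    )

The new version can then be promoted to the Production stage (or given an alias in newer MLflow versions) from the Model Registry UI or with MlflowClient.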

Error Handling and Best Practices

def train_spark_model_with_error_handling(data_path, model_config):
    """Production-ready model training with error handling."""

    with mlflow.start_run(run_name="Robust_Training"):
        try:
            # Load and validate data
            data = spark.read.parquet(data_path)
            record_count = data.count()

            if record_count == 0:
                raise ValueError("Input dataset is empty")

            mlflow.log_metric("input_record_count", record_count)

            # Create and train pipeline
            pipeline = create_pipeline(model_config)
            model = pipeline.fit(data)

            # Validate model can make predictions
            test_sample = data.limit(10)
            predictions = model.transform(test_sample)
            prediction_count = predictions.count()

            if prediction_count != 10:
                raise ValueError("Model validation failed")

            # Log successful model
            model_info = mlflow.spark.log_model(
                spark_model=model, artifact_path="robust_model"
            )

            mlflow.log_param("training_status", "success")
            return model_info

        except Exception as e:
            # Log error information
            mlflow.log_param("training_status", "failed")
            mlflow.log_param("error_message", str(e))
            raise


def create_pipeline(config):
    """Create ML pipeline from configuration."""

    # Simple pipeline creation logic
    feature_cols = config.get("feature_columns", [])
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

    algorithm = config.get("algorithm", "logistic_regression")
    if algorithm == "logistic_regression":
        classifier = LogisticRegression(featuresCol="features", labelCol="label")
    elif algorithm == "random_forest":
        classifier = RandomForestClassifier(featuresCol="features", labelCol="label")
    else:
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    return Pipeline(stages=[assembler, classifier])
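
A hypothetical invocation, showing the shape of the model_config dictionary that create_pipeline expects (the path and column names are illustrative):

model_config = {
    "feature_columns": ["age", "income", "credit_score"],
    "algorithm": "random_forest",
}

model_info = train_spark_model_with_error_handling(
    data_path="s3://data-lake/training/customers.parquet",
    model_config=model_config,
)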

Conclusion

MLflow's Spark MLlib integration provides a comprehensive solution for tracking and managing distributed machine learning workflows. Whether you're building simple classification models or complex multi-stage pipelines, MLflow helps you maintain reproducibility and deploy models efficiently.

Key benefits of using MLflow with Spark MLlib include:

  • Complete Pipeline Tracking: Automatic logging of multi-stage ML pipelines with all parameters and artifacts
  • Flexible Deployment: Deploy as native Spark models for batch processing or PyFunc wrappers for universal compatibility
  • Data Lineage: Automatic tracking of data sources through Spark datasource autologging
  • Cross-Platform Support: ONNX conversion enables deployment across different environments
  • Production Ready: Model registry integration and robust error handling for enterprise deployments

The patterns shown in this guide provide a solid foundation for building scalable, reproducible distributed machine learning systems. Start with basic model logging for immediate experiment tracking benefits, then adopt advanced features like datasource autologging and model registry integration as your needs grow.