Custom Serving Applications

MLflow's custom serving applications allow you to build sophisticated model serving solutions that go beyond simple prediction endpoints. Using the PyFunc framework, you can create custom applications with complex preprocessing, postprocessing, multi-model inference, and business logic integration.

Overview

Custom serving applications in MLflow are built using the mlflow.pyfunc.PythonModel class, which provides a flexible framework for creating deployable models with custom logic. This approach is ideal when you need to:

  • πŸ”„ Implement advanced preprocessing and postprocessing logic
  • 🧠 Combine multiple models within a single serving pipeline
  • βœ… Apply business rules and custom validation checks
  • πŸ”£ Support diverse input and output data formats
  • 🌐 Integrate seamlessly with external systems or databases
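
Concretely, a custom application is a subclass of mlflow.pyfunc.PythonModel that implements load_context for one-time setup and predict for per-request logic. A minimal sketch (the MinimalCustomApp name and text column are purely illustrative):

import mlflow.pyfunc


class MinimalCustomApp(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        # One-time setup: load artifacts, initialize clients, warm caches
        self.prefix = "Echo:"

    def predict(self, context, model_input):
        # Per-request logic: model_input is typically a pandas DataFrame
        return [f"{self.prefix} {text}" for text in model_input["text"]]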

Custom PyFunc Model

Custom Model

Here's an example of a custom PyFunc model that wraps an LLM client with configurable preprocessing and postprocessing:

import mlflow
import pandas as pd
import json
import openai  # or any other LLM client


class CustomLLMModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """Load LLM configuration and initialize client"""
        # Load model configuration from artifacts
        config_path = context.artifacts.get("config", "config.json")
        with open(config_path, "r") as f:
            self.config = json.load(f)

        # Initialize LLM client
        self.client = openai.OpenAI(api_key=self.config["api_key"])
        self.model_name = self.config["model_name"]
        self.system_prompt = self.config.get(
            "system_prompt", "You are a helpful assistant."
        )

    def predict(self, context, model_input):
        """Core LLM prediction logic"""
        if isinstance(model_input, pd.DataFrame):
            # Handle DataFrame input with prompts
            responses = []
            for _, row in model_input.iterrows():
                user_prompt = row.get("prompt", row.get("input", ""))
                processed_prompt = self._preprocess_prompt(user_prompt)
                response = self._generate_response(processed_prompt)
                post_processed = self._postprocess_response(response)
                responses.append(post_processed)
            return pd.DataFrame({"response": responses})
        elif isinstance(model_input, dict):
            # Handle single prompt
            user_prompt = model_input.get("prompt", model_input.get("input", ""))
            processed_prompt = self._preprocess_prompt(user_prompt)
            response = self._generate_response(processed_prompt)
            return self._postprocess_response(response)
        else:
            # Handle string input
            processed_prompt = self._preprocess_prompt(str(model_input))
            response = self._generate_response(processed_prompt)
            return self._postprocess_response(response)

    def _preprocess_prompt(self, prompt: str) -> str:
        """Custom prompt preprocessing logic"""
        # Example: add context, format prompt, apply templates
        template = self.config.get("prompt_template", "{prompt}")
        return template.format(prompt=prompt)

    def _generate_response(self, prompt: str) -> str:
        """Core LLM inference"""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=self.config.get("temperature", 0.7),
                max_tokens=self.config.get("max_tokens", 1000),
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def _postprocess_response(self, response: str) -> str:
        """Custom response postprocessing logic"""
        # Example: format output, apply filters, extract specific parts
        if self.config.get("strip_whitespace", True):
            response = response.strip()

        max_length = self.config.get("max_response_length")
        if max_length and len(response) > max_length:
            response = response[:max_length] + "..."

        return response


# Example configuration
config = {
    "api_key": "your-api-key",
    "model_name": "gpt-4",
    "system_prompt": "You are an expert data analyst. Provide clear, concise answers.",
    "temperature": 0.3,
    "max_tokens": 500,
    "prompt_template": "Context: Data Analysis Task\n\nQuestion: {prompt}\n\nAnswer:",
    "strip_whitespace": True,
    "max_response_length": 1000,
}

# Save configuration
with open("config.json", "w") as f:
    json.dump(config, f)

# Log the model
with mlflow.start_run():
    # Log configuration as artifact
    mlflow.log_artifact("config.json")

    # Create input example
    input_example = pd.DataFrame(
        {"prompt": ["What is machine learning?", "Explain neural networks"]}
    )

    model_info = mlflow.pyfunc.log_model(
        name="custom_llm_model",
        python_model=CustomLLMModel(),
        artifacts={"config": "config.json"},
        input_example=input_example,
    )

Multi-Model Ensemble

Create a custom application that combines multiple LLMs with different strengths:

import mlflow
import mlflow.pyfunc
import pandas as pd
import json
import openai
import anthropic
from typing import Dict


class MultiLLMEnsemble(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """Load multiple LLM configurations from artifacts"""
        # Load ensemble configuration
        config_path = context.artifacts["ensemble_config"]
        with open(config_path, "r") as f:
            self.config = json.load(f)

        # Initialize multiple LLM clients
        self.llm_clients = {}

        # OpenAI client
        if "openai" in self.config["models"]:
            self.llm_clients["openai"] = openai.OpenAI(
                api_key=self.config["models"]["openai"]["api_key"]
            )

        # Anthropic client
        if "anthropic" in self.config["models"]:
            self.llm_clients["anthropic"] = anthropic.Anthropic(
                api_key=self.config["models"]["anthropic"]["api_key"]
            )

        # Add other LLM clients as needed

        self.voting_strategy = self.config.get(
            "voting_strategy", "weighted_combination"
        )
        self.model_weights = self.config.get("model_weights", {})

    def predict(self, context, model_input):
        """Ensemble prediction with multiple LLMs"""
        if isinstance(model_input, pd.DataFrame):
            responses = []
            for _, row in model_input.iterrows():
                prompt = row.get("prompt", row.get("input", ""))
                task_type = row.get("task_type", "general")
                ensemble_response = self._generate_ensemble_response(prompt, task_type)
                responses.append(ensemble_response)
            return pd.DataFrame({"response": responses})
        elif isinstance(model_input, dict):
            # Handle single prompt with optional task type
            prompt = model_input.get("prompt", str(model_input))
            task_type = model_input.get("task_type", "general")
            return self._generate_ensemble_response(prompt, task_type)
        else:
            # Handle string input
            return self._generate_ensemble_response(str(model_input), "general")

    def _generate_ensemble_response(
        self, prompt: str, task_type: str = "general"
    ) -> str:
        """Generate responses from multiple LLMs and combine them"""
        responses = {}

        # Get task-specific model configuration
        task_config = self.config.get("task_routing", {}).get(task_type, {})
        active_models = task_config.get("models", list(self.llm_clients.keys()))

        # Generate responses from each active model
        for model_name in active_models:
            if model_name in self.llm_clients:
                response = self._generate_single_response(model_name, prompt, task_type)
                responses[model_name] = response

        # Combine responses based on voting strategy
        return self._combine_responses(responses, task_type)

    def _generate_single_response(
        self, model_name: str, prompt: str, task_type: str
    ) -> str:
        """Generate response from a single LLM"""
        model_config = self.config["models"][model_name]

        try:
            if model_name == "openai":
                response = self.llm_clients["openai"].chat.completions.create(
                    model=model_config["model_name"],
                    messages=[
                        {
                            "role": "system",
                            "content": model_config.get("system_prompt", ""),
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=model_config.get("temperature", 0.7),
                    max_tokens=model_config.get("max_tokens", 1000),
                )
                return response.choices[0].message.content

            elif model_name == "anthropic":
                response = self.llm_clients["anthropic"].messages.create(
                    model=model_config["model_name"],
                    max_tokens=model_config.get("max_tokens", 1000),
                    temperature=model_config.get("temperature", 0.7),
                    messages=[{"role": "user", "content": prompt}],
                )
                return response.content[0].text

            # Add other LLM implementations here
            return f"Unsupported model: {model_name}"

        except Exception as e:
            return f"Error from {model_name}: {str(e)}"

    def _combine_responses(self, responses: Dict[str, str], task_type: str) -> str:
        """Combine multiple LLM responses using the specified strategy"""
        if self.voting_strategy == "best_for_task":
            # Route to the best model for the specific task type
            task_config = self.config.get("task_routing", {}).get(task_type, {})
            preferred_model = task_config.get("preferred_model")
            if preferred_model and preferred_model in responses:
                return responses[preferred_model]

        elif self.voting_strategy == "consensus":
            # Return the response if multiple models agree (simplified)
            response_list = list(responses.values())
            if len(set(response_list)) == 1:
                return response_list[0]
            else:
                # If no consensus, return the longest response
                return max(response_list, key=len)

        elif self.voting_strategy == "weighted_combination":
            # Combine responses with weights (simplified text combination)
            combined_response = "Combined insights:\n\n"
            for model_name, response in responses.items():
                weight = self.model_weights.get(model_name, 1.0)
                combined_response += (
                    f"[{model_name.upper()} - Weight: {weight}]: {response}\n\n"
                )
            return combined_response

        # Default: return the first available response
        return list(responses.values())[0] if responses else "No response generated"


# Example ensemble configuration
ensemble_config = {
    "voting_strategy": "best_for_task",
    "models": {
        "openai": {
            "api_key": "your-openai-key",
            "model_name": "gpt-4",
            "system_prompt": "You are a helpful assistant specialized in technical analysis.",
            "temperature": 0.3,
            "max_tokens": 800,
        },
        "anthropic": {
            "api_key": "your-anthropic-key",
            "model_name": "claude-3-sonnet-20240229",
            "temperature": 0.5,
            "max_tokens": 1000,
        },
    },
    "task_routing": {
        "code_analysis": {"models": ["openai"], "preferred_model": "openai"},
        "creative_writing": {"models": ["anthropic"], "preferred_model": "anthropic"},
        "general": {"models": ["openai", "anthropic"], "preferred_model": "openai"},
    },
    "model_weights": {"openai": 0.6, "anthropic": 0.4},
}

# Save configuration
with open("ensemble_config.json", "w") as f:
    json.dump(ensemble_config, f)

# Log the ensemble model
with mlflow.start_run():
    # Log configuration as artifact
    mlflow.log_artifact("ensemble_config.json")

    # Create input example
    input_example = pd.DataFrame(
        {
            "prompt": ["Explain quantum computing", "Write a creative story about AI"],
            "task_type": ["general", "creative_writing"],
        }
    )

    mlflow.pyfunc.log_model(
        name="multi_llm_ensemble",
        python_model=MultiLLMEnsemble(),
        artifacts={"ensemble_config": "ensemble_config.json"},
        input_example=input_example,
    )

Serving Custom Applications

Local Serving

Once you've created and saved your custom application, serve it locally:

# Serve from saved model path
mlflow models serve -m ./path/to/custom/model -p 5000

# Serve from Model Registry
mlflow models serve -m "models:/CustomApp/Production" -p 5000

Docker Deployment

Build a Docker image for your custom application:

# Build Docker image
mlflow models build-docker -m ./path/to/custom/model -n custom-app

# Run the container
docker run -p 5000:8080 custom-app

Testing Custom Applications

Test your custom serving application:

import requests
import pandas as pd
import json

# Prepare test data matching the custom LLM model's expected input
test_data = pd.DataFrame(
    {
        "prompt": [
            "What is machine learning?",
            "Explain neural networks",
            "Summarize the benefits of MLflow",
        ]
    }
)

# Convert to the expected input format
input_data = {"dataframe_records": test_data.to_dict("records")}

# Make prediction request
response = requests.post(
    "http://localhost:5000/invocations",
    headers={"Content-Type": "application/json"},
    data=json.dumps(input_data),
)

print("Response:", response.json())

Best Practices for Custom Applications

Error Handling

Implement comprehensive error handling:

def predict(self, context, model_input):
    try:
        # Validate input
        self._validate_input(model_input)

        # Process and predict
        result = self._process_prediction(model_input)

        return result

    except ValueError as e:
        # Handle validation errors
        return {"error": f"Validation error: {str(e)}"}
    except Exception as e:
        # Handle unexpected errors
        return {"error": f"Prediction failed: {str(e)}"}

Performance Optimization

  • πŸ’€ Lazy Loading: Defer loading large artifacts until they’re needed (see the sketch after this list)
  • πŸ—‚οΈ Caching: Store and reuse results of frequent computations
  • πŸ“¦ Batch Processing: Handle multiple inputs in a single, efficient operation
  • 🧹 Memory Management: Release unused resources after each request or task
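
As a sketch of the first three points, lazy loading, caching, and batch processing can live together in one PythonModel. The client handle, cache policy, and prompt column below are illustrative placeholders, not MLflow APIs:

import mlflow.pyfunc
import pandas as pd


class CachedLLMApp(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        # Lazy loading: remember where the artifact lives, but don't open it yet
        self._config_path = context.artifacts.get("config")
        self._client = None
        self._cache = {}

    def _get_client(self):
        # Heavy initialization deferred until the first request that needs it
        if self._client is None:
            self._client = object()  # placeholder for a real client or model handle
        return self._client

    def _answer(self, prompt: str) -> str:
        # Caching: repeated prompts reuse the stored result
        if prompt not in self._cache:
            client = self._get_client()
            self._cache[prompt] = f"answer for: {prompt}"  # placeholder for a real call
        return self._cache[prompt]

    def predict(self, context, model_input):
        # Batch processing: one pass over the whole DataFrame
        return pd.DataFrame(
            {"response": [self._answer(p) for p in model_input["prompt"]]}
        )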

Testing and Validation

  • πŸ§ͺ Unit Testing: Test individual components of your custom model in isolation (see the example after this list)
  • πŸ”— Integration Testing: Verify the full prediction pipeline end-to-end
  • βœ… Output Validation: Ensure correct output formats and robust error handling
  • πŸš€ Performance Testing: Evaluate scalability using realistic data volumes and loads
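
For example, a unit test can load the logged model back and assert on its output contract. This sketch assumes model_info from the earlier log_model call is in scope and that the external LLM call is stubbed or affordable in your test environment:

import mlflow
import pandas as pd


def test_custom_llm_model_returns_one_response_per_prompt():
    # Load the logged model back as a generic pyfunc model
    model = mlflow.pyfunc.load_model(model_info.model_uri)

    batch = pd.DataFrame({"prompt": ["What is MLflow?", "Define overfitting"]})
    result = model.predict(batch)

    # One response row per input prompt, in the expected column
    assert list(result.columns) == ["response"]
    assert len(result) == len(batch)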

Documentation

Document your custom applications thoroughly:

  • πŸ“₯ Input/Output Specifications: Clearly define expected input formats and output structures (a signature example follows this list)
  • βš™οΈ Business Logic: Document the core logic and decision-making rules
  • ⚑ Performance Characteristics: Describe expected throughput, latency, and resource usage
  • ❗ Error Handling: Specify how errors are detected, managed, and communicated
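
For input/output specifications in particular, an MLflow model signature records the schema alongside the model. A sketch reusing the CustomLLMModel and config.json from earlier:

import mlflow
import pandas as pd
from mlflow.models import infer_signature

input_example = pd.DataFrame({"prompt": ["What is machine learning?"]})
output_example = pd.DataFrame({"response": ["Machine learning is ..."]})

# Capture the input/output contract as part of the logged model
signature = infer_signature(input_example, output_example)

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        name="custom_llm_model",
        python_model=CustomLLMModel(),
        artifacts={"config": "config.json"},
        input_example=input_example,
        signature=signature,
    )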

Integration with Databricks

In Databricks Managed MLflow, custom applications can take advantage of additional features:

  • ☁️ Serverless Compute: Automatic scaling based on demand
  • πŸ” Security Integration: Built-in authentication and authorization
  • πŸ“ˆ Monitoring: Advanced metrics and logging capabilities
  • πŸ—‚οΈ Version Management: Seamless model version management with Unity Catalog

Note that the creation and management of serving endpoints are handled differently in Databricks than in MLflow OSS, with additional UI and API capabilities for enterprise deployment.

Troubleshooting

Common Issues

  1. Import Errors: Ensure all dependencies are specified in the model's environment (e.g., pip_requirements or conda_env at logging time)
  2. Artifact Loading: Verify artifact paths are correct and accessible
  3. Memory Issues: Monitor memory usage with large models or datasets
  4. Serialization: Use the models-from-code feature when logging models that are not picklable (see the sketch after this list)
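
For the serialization point, models-from-code logs a Python file that defines the model instead of pickling an instance. A minimal sketch (the model_code.py file name and stubbed methods are illustrative):

# model_code.py -- defines the model and registers the instance MLflow should load
import mlflow
from mlflow.models import set_model


class CustomLLMModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        ...

    def predict(self, context, model_input):
        ...


set_model(CustomLLMModel())


# logging script -- pass the file path instead of a model instance
import mlflow

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        name="custom_llm_model",
        python_model="model_code.py",
        artifacts={"config": "config.json"},
    )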

Debugging Tips

  • 🧾 Enable tracing to track execution flow (see the sketch after this list)
  • πŸ§ͺ Test components individually before integration
  • πŸ“Š Use small test datasets for initial validation
  • πŸ–₯️ Monitor resource usage during development
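
For tracing, MLflow Tracing can be enabled by decorating the helper methods you want to follow; the resulting traces are viewable in the MLflow UI. A minimal sketch (the TracedLLMModel class and prompt column are illustrative):

import mlflow
import pandas as pd


class TracedLLMModel(mlflow.pyfunc.PythonModel):
    @mlflow.trace
    def _preprocess_prompt(self, prompt: str) -> str:
        # Each call is captured as a span with its inputs and outputs
        return prompt.strip()

    def predict(self, context, model_input):
        # Illustrative predict that only runs the traced preprocessing step
        return pd.DataFrame(
            {"response": [self._preprocess_prompt(p) for p in model_input["prompt"]]}
        )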

Custom serving applications provide the flexibility to build production-ready ML systems that integrate seamlessly with your business requirements while maintaining the reliability and scalability of MLflow's serving infrastructure.