Keras within MLflow

In this guide, we walk through using Keras with MLflow: how to track your Keras experiments and log your Keras models using both autologging and manual logging approaches.

Setting Up Keras Backend

Keras 3.0 is inherently multi-backend, supporting TensorFlow, JAX, and PyTorch. You must set the backend environment variable before importing Keras:

import os

# You can use 'tensorflow', 'torch', or 'jax' as backend
# Make sure to set the environment variable before importing Keras
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import mlflow
Note: The backend must be set before importing Keras. Once Keras is imported, the backend cannot be changed within the same Python session.

This multi-backend architecture means your MLflow tracking code works consistently regardless of which backend you choose, giving you the flexibility to optimize for your specific use case without changing your experiment tracking setup.
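
You can verify which backend is active at runtime: in both Keras 3 and tf.keras, keras.backend.backend() returns the backend name as a string.

import keras

# Assumes KERAS_BACKEND was set before this import, as shown above
print(keras.backend.backend())  # e.g. "tensorflow"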

Logging Keras Experiments to MLflow

Autologging Keras Experiments

MLflow provides seamless autologging integration with Keras/TensorFlow. To enable automatic logging of metrics, parameters, and models, simply call mlflow.tensorflow.autolog() or mlflow.autolog() before your training code.

import mlflow
import mlflow.tensorflow

# Enable autologging for TensorFlow/Keras
mlflow.tensorflow.autolog()

# Your existing Keras training code works unchanged
model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10)
Note: Only tensorflow>=2.3 is supported. The autologging feature captures metrics from both tf.keras and tf.keras.callbacks.EarlyStopping.

What Gets Automatically Logged

Autologging captures comprehensive information about your Keras training:

  • tf.keras: Metrics are training loss, validation loss, and user-specified metrics. Parameters are fit() parameters, optimizer name, learning rate, and epsilon. Artifacts are the model summary on training start, the MLflow Model (Keras model), and TensorBoard logs on training end.
  • tf.keras.callbacks.EarlyStopping: Metrics are those from the EarlyStopping callback (stopped_epoch, restored_epoch, restore_best_weight, etc.). Parameters are the fit() parameters from EarlyStopping (min_delta, patience, baseline, restore_best_weights, etc.). No artifacts are logged.

Automatic Run Management

MLflow intelligently manages runs when using autologging:

  • No Active Run: If no run is active when autologging captures data, MLflow automatically creates one and ends it once training via tf.keras.fit() completes
  • Existing Run: If a run is already active, MLflow logs to that run but does not automatically end it; you must stop the run yourself if needed, as sketched below
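
A minimal sketch of the second case, assuming model, x_train, and y_train are defined as elsewhere in this guide:

import mlflow
import mlflow.tensorflow

mlflow.tensorflow.autolog()

run = mlflow.start_run()  # an explicitly created, already-active run
model.fit(x_train, y_train, epochs=5)  # autologging logs into this run
mlflow.end_run()  # autologging does not end a run it did not create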

Manually Logging Keras Experiments

For more control over what gets logged, you can manually instrument your Keras training code using MLflow's logging APIs:

Best Practices for Manual Logging

When manually logging Keras experiments, follow these best practices:

  • Log training parameters at the beginning of training via mlflow.log_params(), including learning rate, batch size, epochs, etc.
  • Log model architecture at the beginning of training via mlflow.log_artifact(). You can save the model summary as a text file
  • Log training and validation metrics inside your training loop or callbacks via mlflow.log_metric()
  • Log your trained model to MLflow via mlflow.keras.log_model() at the end of training
  • [Optional] Log model checkpoints during training via mlflow.log_artifact() to preserve intermediate training states (see the sketch after this list)
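
For the optional checkpoint step, here is a minimal sketch using a custom Keras callback. The CheckpointLogger name, the checkpoints/ directory, and the filename pattern are all illustrative; newer MLflow versions may also offer built-in checkpoint callbacks, but this hand-rolled version makes the mechanics explicit.

import os

import keras
import mlflow


class CheckpointLogger(keras.callbacks.Callback):
    """Save the model after every epoch and upload it as an MLflow artifact."""

    def on_epoch_end(self, epoch, logs=None):
        os.makedirs("checkpoints", exist_ok=True)
        path = os.path.join("checkpoints", f"epoch_{epoch:03d}.keras")
        self.model.save(path)  # native Keras format
        mlflow.log_artifact(path, artifact_path="checkpoints")

Add CheckpointLogger() to the callbacks list in fit() alongside any metric-logging callbacks.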

Complete Manual Logging Example

Here's an end-to-end example of manually logging a Keras experiment:

import mlflow
import mlflow.keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Load and prepare data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


# Define model architecture
def create_model():
    model = keras.Sequential(
        [
            layers.Dense(512, activation="relu", input_shape=(784,)),
            layers.Dropout(0.2),
            layers.Dense(256, activation="relu"),
            layers.Dropout(0.2),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model


# Training parameters
params = {
    "epochs": 10,
    "batch_size": 128,
    "learning_rate": 0.001,
    "optimizer": "adam",
    "loss_function": "categorical_crossentropy",
    "dropout_rate": 0.2,
    "hidden_units": [512, 256],
}

with mlflow.start_run():
    # Log training parameters
    mlflow.log_params(params)

    # Create and compile model
    model = create_model()
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=params["learning_rate"]),
        loss=params["loss_function"],
        metrics=["accuracy"],
    )

    # Log model architecture
    with open("model_summary.txt", "w") as f:
        model.summary(print_fn=lambda x: f.write(x + "\n"))
    mlflow.log_artifact("model_summary.txt")

    # Custom callback for logging metrics
    class MLflowCallback(keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            if logs:
                mlflow.log_metrics(
                    {
                        "train_loss": logs.get("loss"),
                        "train_accuracy": logs.get("accuracy"),
                        "val_loss": logs.get("val_loss"),
                        "val_accuracy": logs.get("val_accuracy"),
                    },
                    step=epoch,
                )

    # Train model with custom callback
    history = model.fit(
        x_train,
        y_train,
        batch_size=params["batch_size"],
        epochs=params["epochs"],
        validation_data=(x_test, y_test),
        callbacks=[MLflowCallback()],
        verbose=1,
    )

    # Evaluate model
    test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
    mlflow.log_metrics({"test_loss": test_loss, "test_accuracy": test_accuracy})

    # Log the trained model
    mlflow.keras.log_model(model, name="model")

    print(f"Test accuracy: {test_accuracy:.4f}")

If you run this code with a local MLflow server, you'll see comprehensive tracking in the MLflow UI.
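
If you don't have a server running yet, a minimal local setup looks like this (the host, port, and experiment name are illustrative):

# In a terminal, start a local tracking server:
#   mlflow server --host 127.0.0.1 --port 5000

import mlflow

# Point the client at the server and choose an experiment
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("keras-mnist")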

Using MLflow's Keras Callback

MLflow provides a built-in callback for Keras that simplifies experiment tracking. The mlflow.keras.MlflowCallback() integrates seamlessly with your Keras training loop:

import mlflow
import mlflow.keras
from mlflow.keras import MlflowCallback

# Create model and prepare data (same as above)
model = create_model()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

with mlflow.start_run() as run:
    # Use MLflow's built-in callback
    mlflow_callback = MlflowCallback(run)

    history = model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=10,
        validation_data=(x_test, y_test),
        callbacks=[mlflow_callback],
        verbose=1,
    )

Advanced Callback Usage

The MlflowCallback supports various configuration options:

# Log metrics every 5 batches instead of every epoch
mlflow_callback = MlflowCallback(run, log_every_epoch=False, log_every_n_steps=5)


# Custom callback subclass for specialized logging
class CustomMlflowCallback(MlflowCallback):
    def on_epoch_end(self, epoch, logs=None):
        # Call parent method
        super().on_epoch_end(epoch, logs)

        # Add custom logging
        if logs and epoch % 5 == 0:  # Log every 5 epochs
            mlflow.log_metric(
                "learning_rate", self.model.optimizer.learning_rate.numpy()
            )

    def on_train_end(self, logs=None):
        # Log final model weights distribution
        weights = self.model.get_weights()
        avg_weight = np.mean([np.mean(w) for w in weights])
        mlflow.log_metric("avg_final_weight", avg_weight)

Saving Your Keras Model to MLflow

Basic Model Saving

Save your trained Keras model using mlflow.keras.log_model():

import mlflow
import mlflow.keras
import numpy as np

# Train your model
model = create_model()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=5)

model_info = mlflow.keras.log_model(model, name="model")

# Load and use the model
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Make predictions
predictions = loaded_model.predict(x_test[:5])
print("Predictions:", predictions)

Model Signatures

A model signature describes a model's input and output schema. While not required, it's a best practice for better model understanding and validation:

import mlflow
from mlflow.models import infer_signature
import numpy as np

# Prepare sample data for signature inference
sample_input = x_test[:100]
sample_predictions = model.predict(sample_input)

# Infer signature from sample data
signature = infer_signature(sample_input, sample_predictions)

# Log model with signature
model_info = mlflow.keras.log_model(model, name="model", signature=signature)

You can also manually create signatures for more control:

from mlflow.types import Schema, TensorSpec
from mlflow.models import ModelSignature
import numpy as np

# Define input and output schemas
input_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 784))])
output_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 10))])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

model_info = mlflow.keras.log_model(model, name="model", signature=signature)
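
Alternatively, you can pass a representative input_example when logging; MLflow runs it through the model and stores the inferred signature alongside it (a sketch, assuming a recent MLflow version):

# Signature is inferred from the example input and the model's output
model_info = mlflow.keras.log_model(model, name="model", input_example=x_test[:5])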

Multi-Backend Support with Keras 3.0

As mentioned at the start of this guide, Keras 3.0's multi-backend support is one of its most powerful features. MLflow's tracking works seamlessly across all supported backends:

import os
import mlflow

# Switch backends easily - MLflow tracking code remains identical
os.environ["KERAS_BACKEND"] = "jax" # or "torch" or "tensorflow"

import keras
import mlflow.keras

# Enable autologging (works with any backend)
mlflow.tensorflow.autolog()

# Your training code is backend-agnostic
model = keras.Sequential(
    [
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

with mlflow.start_run():
    model.fit(x_train, y_train, epochs=5, validation_split=0.2)

This consistency means you can:

  • Experiment with different backends without changing your MLflow tracking code
  • Optimize performance by choosing the best backend for your hardware (JAX for TPUs, PyTorch for research flexibility, TensorFlow for production)
  • Maintain reproducibility across different computational environments

Advanced Features

Hyperparameter Tuning with Keras and MLflow

Combine Keras with hyperparameter tuning libraries while tracking everything in MLflow:

import keras
import mlflow
import optuna


def objective(trial):
    # Suggest hyperparameters
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    hidden_units = trial.suggest_int("hidden_units", 64, 512)
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)

    with mlflow.start_run(nested=True):
        # Log trial parameters
        mlflow.log_params(
            {
                "learning_rate": lr,
                "batch_size": batch_size,
                "hidden_units": hidden_units,
                "dropout_rate": dropout_rate,
            }
        )

        # Create model with suggested parameters
        model = keras.Sequential(
            [
                keras.layers.Dense(hidden_units, activation="relu", input_shape=(784,)),
                keras.layers.Dropout(dropout_rate),
                keras.layers.Dense(10, activation="softmax"),
            ]
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=lr),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )

        # Train model
        history = model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=10,
            validation_split=0.2,
            verbose=0,
        )

        # Get validation accuracy
        val_accuracy = max(history.history["val_accuracy"])
        mlflow.log_metric("val_accuracy", val_accuracy)

        return val_accuracy


# Run hyperparameter optimization
with mlflow.start_run():
    mlflow.set_tag("optimization", "optuna")
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)

    # Log best parameters
    mlflow.log_params(study.best_params)
    mlflow.log_metric("best_val_accuracy", study.best_value)

Custom Metrics and Artifacts

Log custom visualizations and metrics specific to your use case:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

import mlflow


def log_training_plots(history, run_id):
    """Log training history plots to MLflow."""

    # Plot training history
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(history.history["loss"], label="Training Loss")
    ax1.plot(history.history["val_loss"], label="Validation Loss")
    ax1.set_title("Model Loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()

    ax2.plot(history.history["accuracy"], label="Training Accuracy")
    ax2.plot(history.history["val_accuracy"], label="Validation Accuracy")
    ax2.set_title("Model Accuracy")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")
    ax2.legend()

    plt.tight_layout()
    plt.savefig("training_history.png", dpi=300, bbox_inches="tight")
    mlflow.log_artifact("training_history.png")
    plt.close()


def log_evaluation_metrics(model, x_test, y_test, class_names):
    """Log comprehensive evaluation metrics."""

    # Get predictions
    y_pred = model.predict(x_test)
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_true_classes = np.argmax(y_test, axis=1)

    # Confusion matrix
    cm = confusion_matrix(y_true_classes, y_pred_classes)
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
    mlflow.log_artifact("confusion_matrix.png")
    plt.close()

    # Classification report
    report = classification_report(
        y_true_classes, y_pred_classes, target_names=class_names, output_dict=True
    )

    # Log per-class metrics
    for class_name in class_names:
        if class_name in report:
            mlflow.log_metrics(
                {
                    f"{class_name}_precision": report[class_name]["precision"],
                    f"{class_name}_recall": report[class_name]["recall"],
                    f"{class_name}_f1": report[class_name]["f1-score"],
                }
            )


# Usage example
with mlflow.start_run():
    # Train model
    history = model.fit(
        x_train, y_train, validation_data=(x_test, y_test), epochs=10, verbose=1
    )

    # Log comprehensive results
    log_training_plots(history, mlflow.active_run().info.run_id)
    log_evaluation_metrics(
        model, x_test, y_test, class_names=[str(i) for i in range(10)]
    )

Conclusion

MLflow's integration with Keras provides a comprehensive solution for experiment tracking and model management in deep learning workflows. Whether you choose autologging for simplicity or manual logging for fine-grained control, MLflow captures all the essential information needed for reproducible machine learning research and production deployment.

Key benefits of using MLflow with Keras include:

  • Seamless Integration: One-line autologging setup with comprehensive tracking
  • Multi-Backend Support: Consistent tracking across TensorFlow, JAX, and PyTorch backends
  • Flexible Logging: Choose between automatic and manual logging approaches
  • Production Ready: Built-in model serving and deployment capabilities
  • Collaborative Development: Share experiments and models through MLflow's intuitive UI

Whether you're conducting research experiments or building production ML systems, the MLflow-Keras integration provides the foundation for organized, reproducible, and scalable deep learning workflows.