Keras within MLflow
In this guide we will walk you through how to use Keras with MLflow. We will demonstrate how to track your Keras experiments and log your Keras models to MLflow using both autologging and manual logging approaches.
Setting Up Keras Backend
Keras 3.0 is inherently multi-backend, supporting TensorFlow, JAX, and PyTorch. You must set the backend environment variable before importing Keras:
import os
# You can use 'tensorflow', 'torch', or 'jax' as backend
# Make sure to set the environment variable before importing Keras
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import mlflow
The backend must be set before importing Keras. Once Keras is imported, the backend cannot be changed within the same Python session.
This multi-backend architecture means your MLflow tracking code works consistently regardless of which backend you choose, giving you the flexibility to optimize for your specific use case without changing your experiment tracking setup.
Logging Keras Experiments to MLflow
Autologging Keras Experiments
MLflow provides seamless autologging integration with Keras/TensorFlow. To enable automatic logging of metrics, parameters, and models, simply call mlflow.tensorflow.autolog() or mlflow.autolog() before your training code.
import mlflow
import mlflow.tensorflow
# Enable autologging for TensorFlow/Keras
mlflow.tensorflow.autolog()
# Your existing Keras training code works unchanged
model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10)
Only versions of tensorflow>=2.3 are supported. The autologging feature captures metrics from both tf.keras and tf.keras.callbacks.EarlyStopping.
What Gets Automatically Logged
Autologging captures comprehensive information about your Keras training:
| Framework | Metrics | Parameters | Artifacts |
|---|---|---|---|
| tf.keras | Training loss; validation loss; user-specified metrics | fit() parameters; optimizer name; learning rate; epsilon | Model summary on training start; MLflow Model (Keras model); TensorBoard logs on training end |
| tf.keras.callbacks.EarlyStopping | Metrics from EarlyStopping callbacks: stopped_epoch, restored_epoch, restore_best_weight, etc. | fit() parameters from EarlyStopping: min_delta, patience, baseline, restore_best_weights, etc. | -- |
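For example, here is a minimal sketch of autologging with an EarlyStopping callback, assuming a compiled model and training data like those defined later in this guide:

import mlflow
import tensorflow as tf

mlflow.tensorflow.autolog()

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)

# Autologging captures the callback's parameters (patience, min_delta, ...)
# and metrics (stopped_epoch, restored_epoch, ...) in addition to the usual
# fit() parameters and training metrics
model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    callbacks=[early_stop],
)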
Automatic Run Management
MLflow intelligently manages runs when using autologging:
- No Active Run: If no active run exists when autolog() captures data, MLflow automatically creates a run and ends it once training via tf.keras.fit() completes
- Existing Run: If a run already exists, MLflow logs to that run but does not automatically end it; you must end the run manually if needed

Both behaviors are sketched below.
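A minimal sketch of both behaviors, again assuming a compiled model and training data:

import mlflow

mlflow.tensorflow.autolog()

# No active run: MLflow creates a run and ends it when fit() completes
model.fit(x_train, y_train, epochs=5)

# Existing run: MLflow logs to it but leaves it open
with mlflow.start_run():
    model.fit(x_train, y_train, epochs=5)
    # the run ends here only because the context manager closes it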
Manually Logging Keras Experiments
For more control over what gets logged, you can manually instrument your Keras training code using MLflow's logging APIs:
- mlflow.log_metric() / mlflow.log_metrics(): Log metrics such as accuracy and loss during training
- mlflow.log_param() / mlflow.log_params(): Log parameters such as learning rate and batch size
- mlflow.keras.log_model(): Save your Keras model to MLflow
- mlflow.log_artifact(): Log artifacts such as model checkpoints and plots
Best Practices for Manual Logging
When manually logging Keras experiments, follow these best practices:
- Log training parameters at the beginning of training via mlflow.log_params(), including learning rate, batch size, epochs, etc.
- Log model architecture at the beginning of training via mlflow.log_artifact(); you can save the model summary as a text file
- Log training and validation metrics inside your training loop or callbacks via mlflow.log_metric()
- Log your trained model to MLflow via mlflow.keras.log_model() at the end of training
- [Optional] Log model checkpoints during training via mlflow.log_artifact() to preserve intermediate training states
Complete Manual Logging Example
Here's an end-to-end example of manually logging a Keras experiment:
import mlflow
import mlflow.keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# Load and prepare data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
# Define model architecture
def create_model():
model = keras.Sequential(
[
layers.Dense(512, activation="relu", input_shape=(784,)),
layers.Dropout(0.2),
layers.Dense(256, activation="relu"),
layers.Dropout(0.2),
layers.Dense(10, activation="softmax"),
]
)
return model
# Training parameters
params = {
"epochs": 10,
"batch_size": 128,
"learning_rate": 0.001,
"optimizer": "adam",
"loss_function": "categorical_crossentropy",
"dropout_rate": 0.2,
"hidden_units": [512, 256],
}
with mlflow.start_run():
# Log training parameters
mlflow.log_params(params)
# Create and compile model
model = create_model()
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=params["learning_rate"]),
loss=params["loss_function"],
metrics=["accuracy"],
)
# Log model architecture
with open("model_summary.txt", "w") as f:
model.summary(print_fn=lambda x: f.write(x + "\n"))
mlflow.log_artifact("model_summary.txt")
# Custom callback for logging metrics
class MLflowCallback(keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
if logs:
mlflow.log_metrics(
{
"train_loss": logs.get("loss"),
"train_accuracy": logs.get("accuracy"),
"val_loss": logs.get("val_loss"),
"val_accuracy": logs.get("val_accuracy"),
},
step=epoch,
)
# Train model with custom callback
history = model.fit(
x_train,
y_train,
batch_size=params["batch_size"],
epochs=params["epochs"],
validation_data=(x_test, y_test),
callbacks=[MLflowCallback()],
verbose=1,
)
# Evaluate model
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
mlflow.log_metrics({"test_loss": test_loss, "test_accuracy": test_accuracy})
# Log the trained model
mlflow.keras.log_model(model, name="model")
print(f"Test accuracy: {test_accuracy:.4f}")
If you run this code with a local MLflow server, you'll see comprehensive tracking in the MLflow UI.
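If you don't have a tracking server running yet, a minimal local setup might look like the following; the host, port, and experiment name are example values:

import mlflow

# Start a local tracking server beforehand, e.g.:
#   mlflow server --host 127.0.0.1 --port 5000
# (or run `mlflow ui` for a lightweight local viewer)
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("keras-mnist")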
Using MLflow's Keras Callback
MLflow provides a built-in callback for Keras that simplifies experiment tracking. The mlflow.keras.MlflowCallback integrates seamlessly with your Keras training loop:
import mlflow
import mlflow.keras
from mlflow.keras import MlflowCallback
# Create model and prepare data (same as above)
model = create_model()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
with mlflow.start_run() as run:
# Use MLflow's built-in callback
mlflow_callback = MlflowCallback(run)
history = model.fit(
x_train,
y_train,
batch_size=128,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[mlflow_callback],
verbose=1,
)
Advanced Callback Usage
The MlflowCallback supports various configuration options:
# Log metrics every 5 batches instead of every epoch
mlflow_callback = MlflowCallback(run, log_every_epoch=False, log_every_n_steps=5)
# Custom callback subclass for specialized logging
class CustomMlflowCallback(MlflowCallback):
def on_epoch_end(self, epoch, logs=None):
# Call parent method
super().on_epoch_end(epoch, logs)
# Add custom logging
if logs and epoch % 5 == 0: # Log every 5 epochs
mlflow.log_metric(
"learning_rate", self.model.optimizer.learning_rate.numpy()
)
def on_train_end(self, logs=None):
# Log final model weights distribution
weights = self.model.get_weights()
avg_weight = np.mean([np.mean(w) for w in weights])
mlflow.log_metric("avg_final_weight", avg_weight)
Saving Your Keras Model to MLflow
Basic Model Saving
Save your trained Keras model using mlflow.keras.log_model():
import mlflow
import mlflow.keras
import numpy as np
# Train your model
model = create_model()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=5)
model_info = mlflow.keras.log_model(model, name="model")
# Load and use the model
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
# Make predictions
predictions = loaded_model.predict(x_test[:5])
print("Predictions:", predictions)
Model Signatures
A model signature describes a model's input and output schema. While not required, it's a best practice for better model understanding and validation:
import mlflow
from mlflow.models import infer_signature
import numpy as np
# Prepare sample data for signature inference
sample_input = x_test[:100]
sample_predictions = model.predict(sample_input)
# Infer signature from sample data
signature = infer_signature(sample_input, sample_predictions)
# Log model with signature
model_info = mlflow.keras.log_model(model, name="model", signature=signature)
You can also manually create signatures for more control:
from mlflow.types import Schema, TensorSpec
from mlflow.models import ModelSignature
import numpy as np
# Define input and output schemas
input_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 784))])
output_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 10))])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
model_info = mlflow.keras.log_model(model, name="model", signature=signature)
Multi-Backend Support with Keras 3.0
As mentioned at the start of this guide, Keras 3.0's multi-backend support is one of its most powerful features. MLflow's tracking works seamlessly across all supported backends:
import os
import mlflow
# Switch backends easily - MLflow tracking code remains identical
os.environ["KERAS_BACKEND"] = "jax" # or "torch" or "tensorflow"
import keras
import mlflow.keras
# Enable autologging (works with any backend)
mlflow.tensorflow.autolog()
# Your training code is backend-agnostic
model = keras.Sequential(
[
keras.layers.Dense(64, activation="relu"),
keras.layers.Dense(10, activation="softmax"),
]
)
model.compile(
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
with mlflow.start_run():
model.fit(x_train, y_train, epochs=5, validation_split=0.2)
This consistency means you can:
- Experiment with different backends without changing your MLflow tracking code
- Optimize performance by choosing the best backend for your hardware (JAX for TPUs, PyTorch for research flexibility, TensorFlow for production)
- Maintain reproducibility across different computational environments
Advanced Features
Hyperparameter Tuning with Keras and MLflow
Combine Keras with hyperparameter tuning libraries while tracking everything in MLflow:
import keras
import mlflow
import optuna
def objective(trial):
# Suggest hyperparameters
lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
hidden_units = trial.suggest_int("hidden_units", 64, 512)
dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
with mlflow.start_run(nested=True):
# Log trial parameters
mlflow.log_params(
{
"learning_rate": lr,
"batch_size": batch_size,
"hidden_units": hidden_units,
"dropout_rate": dropout_rate,
}
)
# Create model with suggested parameters
model = keras.Sequential(
[
keras.layers.Dense(hidden_units, activation="relu", input_shape=(784,)),
keras.layers.Dropout(dropout_rate),
keras.layers.Dense(10, activation="softmax"),
]
)
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=lr),
loss="categorical_crossentropy",
metrics=["accuracy"],
)
# Train model
history = model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=10,
validation_split=0.2,
verbose=0,
)
# Get validation accuracy
val_accuracy = max(history.history["val_accuracy"])
mlflow.log_metric("val_accuracy", val_accuracy)
return val_accuracy
# Run hyperparameter optimization
with mlflow.start_run():
mlflow.set_tag("optimization", "optuna")
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
# Log best parameters
mlflow.log_params(study.best_params)
mlflow.log_metric("best_val_accuracy", study.best_value)
Custom Metrics and Artifacts
Log custom visualizations and metrics specific to your use case:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import mlflow
from sklearn.metrics import confusion_matrix, classification_report
def log_training_plots(history):
    """Log training history plots to MLflow."""
# Plot training history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history.history["loss"], label="Training Loss")
ax1.plot(history.history["val_loss"], label="Validation Loss")
ax1.set_title("Model Loss")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.legend()
ax2.plot(history.history["accuracy"], label="Training Accuracy")
ax2.plot(history.history["val_accuracy"], label="Validation Accuracy")
ax2.set_title("Model Accuracy")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.legend()
plt.tight_layout()
plt.savefig("training_history.png", dpi=300, bbox_inches="tight")
mlflow.log_artifact("training_history.png")
plt.close()
def log_evaluation_metrics(model, x_test, y_test, class_names):
"""Log comprehensive evaluation metrics."""
# Get predictions
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test, axis=1)
# Confusion matrix
cm = confusion_matrix(y_true_classes, y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=class_names,
yticklabels=class_names,
)
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
mlflow.log_artifact("confusion_matrix.png")
plt.close()
# Classification report
report = classification_report(
y_true_classes, y_pred_classes, target_names=class_names, output_dict=True
)
# Log per-class metrics
for class_name in class_names:
if class_name in report:
mlflow.log_metrics(
{
f"{class_name}_precision": report[class_name]["precision"],
f"{class_name}_recall": report[class_name]["recall"],
f"{class_name}_f1": report[class_name]["f1-score"],
}
)
# Usage example
with mlflow.start_run():
# Train model
history = model.fit(
x_train, y_train, validation_data=(x_test, y_test), epochs=10, verbose=1
)
# Log comprehensive results
log_training_plots(history)
log_evaluation_metrics(
model, x_test, y_test, class_names=[str(i) for i in range(10)]
)
Conclusion
MLflow's integration with Keras provides a comprehensive solution for experiment tracking and model management in deep learning workflows. Whether you choose autologging for simplicity or manual logging for fine-grained control, MLflow captures all the essential information needed for reproducible machine learning research and production deployment.
Key benefits of using MLflow with Keras include:
- Seamless Integration: One-line autologging setup with comprehensive tracking
- Multi-Backend Support: Consistent tracking across TensorFlow, JAX, and PyTorch backends
- Flexible Logging: Choose between automatic and manual logging approaches
- Production Ready: Built-in model serving and deployment capabilities
- Collaborative Development: Share experiments and models through MLflow's intuitive UI
Whether you're conducting research experiments or building production ML systems, the MLflow-Keras integration provides the foundation for organized, reproducible, and scalable deep learning workflows.