TensorFlow within MLflow
TensorFlow is a powerful end-to-end open source platform for machine learning that has revolutionized how developers build and deploy ML solutions. With its comprehensive ecosystem of tools and libraries, TensorFlow empowers everyone from beginners to experts to create sophisticated models for diverse applications.
TensorFlow's Keras API provides an intuitive interface for building and training deep learning models, while its powerful backend enables efficient computation on CPUs, GPUs, and TPUs.
Why TensorFlow Leads the Industry
Complete ML Ecosystem
- Production-Ready: End-to-end platform from experimentation to deployment
- Multi-Platform Deployment: Run models on browsers, mobile devices, edge hardware, and servers
- Research Flexibility: High-level and low-level APIs for both beginners and experts
- TensorBoard Integration: Rich visualization of model architecture and training metrics
Powerful Core Features
- Graph Execution: Optimized execution for maximum performance
- Eager Execution: Immediate evaluation for intuitive debugging
- Modular Design: Customize any part of your ML pipeline
- Global Community: Extensive resources, tutorials, and pre-trained models
Why MLflow + TensorFlow?
The integration of MLflow with TensorFlow creates a powerful workflow for machine learning practitioners:
- One-Line Autologging: Enable comprehensive tracking with just mlflow.tensorflow.autolog()
- Zero-Code Integration: Your existing TensorFlow training code works unchanged
- Complete Reproducibility: Every parameter, metric, and model is captured automatically
- Training Visualization: Monitor performance through the MLflow UI
- Collaborative Development: Share experiments and results with team members
- Streamlined Deployment: Package models for deployment across different environments
Autologging TensorFlow Experiments
MLflow can automatically log metrics, parameters, and models from your TensorFlow training runs. Simply call mlflow.tensorflow.autolog() or mlflow.autolog() before your training code:
import mlflow
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Enable autologging
mlflow.tensorflow.autolog()

# Prepare sample data
data = np.random.uniform(size=[20, 28, 28, 3])
label = np.random.randint(2, size=20)

# Define model
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 2),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
        keras.layers.Softmax(),
    ]
)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Training with automatic logging
with mlflow.start_run():
    model.fit(data, label, batch_size=5, epochs=2)
Autologging Requirements and Limitations
Requirements
- TensorFlow Version: Only TensorFlow >= 2.3.0 is supported
- Training API: Must use the model.fit() Keras API
- Run Context: Works both with and without an active MLflow run
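As a rough illustration of the version requirement, the sketch below compares a version string against the 2.3.0 minimum. The helper names `parse_version` and `supports_autolog` and the simple parsing are assumptions made for this example; MLflow performs its own compatibility checks internally and this is not its implementation.

```python
import re

MIN_TF_VERSION = (2, 3, 0)  # minimum version stated in the requirements


def parse_version(version: str) -> tuple:
    """Extract the leading numeric major.minor[.patch] parts of a version string."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part) if part is not None else 0 for part in match.groups())


def supports_autolog(tf_version: str) -> bool:
    """Hypothetical helper: True if the version meets the documented minimum."""
    return parse_version(tf_version) >= MIN_TF_VERSION


for candidate in ["2.2.0", "2.3.0", "2.15.0rc1"]:
    print(candidate, supports_autolog(candidate))
```

In a real environment the string to check would typically be tf.__version__.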
Limitations
- Custom Training Loops: Not supported (use manual logging instead)
- Older TensorFlow Versions: Not supported (use manual logging instead)
- Non-Keras TensorFlow: Not supported (use manual logging instead)
Autologging is only supported when you are using the model.fit() Keras API to train the model. Additionally, only TensorFlow >= 2.3.0 is supported. If you are using an older version of TensorFlow or TensorFlow without Keras, please use manual logging.
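For custom training loops, manual logging usually means calling a logging function once per step or epoch. The following is a minimal sketch under stated assumptions: the training loop is a toy gradient descent on a single scalar (not a TensorFlow model), and the `log_metric` argument stands in for mlflow.log_metric, which accepts a key, a value, and a `step` keyword. The names `train_with_manual_logging` and `record` are hypothetical.

```python
from typing import Callable, List, Tuple


def train_with_manual_logging(
    log_metric: Callable[..., None],
    epochs: int = 5,
    learning_rate: float = 0.1,
) -> float:
    """Minimize (w - 3)^2 by gradient descent, logging the loss each epoch.

    `log_metric` is called as log_metric(key, value, step=epoch), which matches
    how mlflow.log_metric can be called.
    """
    w = 0.0
    for epoch in range(epochs):
        loss = (w - 3.0) ** 2
        grad = 2.0 * (w - 3.0)
        w -= learning_rate * grad
        log_metric("loss", loss, step=epoch)
    return w


# Record calls in a list here; in a real run, pass mlflow.log_metric instead.
recorded: List[Tuple[str, float, int]] = []


def record(key: str, value: float, step: int = 0) -> None:
    recorded.append((key, value, step))


final_w = train_with_manual_logging(record, epochs=5)
print(len(recorded), recorded[0])
```

With MLflow installed, the same loop could be run as train_with_manual_logging(mlflow.log_metric) inside a with mlflow.start_run(): block, so that each loss value is attached to the active run.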
What Gets Automatically Logged
Comprehensive Autologging Details
Model Information
- Model Summary: Complete architecture overview as returned by model.summary()
- Layer Configuration: Details of each layer in the model
- Parameter Count: Total number of trainable and non-trainable parameters
Training Parameters
- Batch Size: Number of samples per gradient update
- Epochs: Number of complete passes through the training dataset
- Steps Per Epoch: Number of batch iterations per epoch
- Validation Steps: Number of batch iterations for validation
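To make the relationship between these parameters concrete, here is a small arithmetic sketch. It assumes array inputs with an explicit batch size, where the last batch of an epoch may be smaller than the rest; the helper name `steps_per_epoch` is chosen for this example only.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches per epoch; the last batch may be smaller."""
    return math.ceil(num_samples / batch_size)


# Values from the autologging example: 20 samples, batch_size=5, epochs=2.
steps = steps_per_epoch(20, 5)
total_steps = steps * 2
print(steps, total_steps)  # 4 8
```

The total step count matters when metrics are logged per step rather than per epoch, since it determines how many data points each metric accumulates.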
Optimizer Configuration
- Optimizer Name: Type of optimizer used (Adam, SGD, etc.)
- Learning Rate: Step size for gradient updates
- Epsilon: Small constant for numerical stability
- Other Optimizer Parameters: Beta values, momentum, etc.
Dataset Information
- Dataset Shape: Input and output dimensions
- Sample Count: Number of training and validation samples
Training Metrics
- Training Loss: Loss value for each epoch
- Validation Loss: Loss on validation data
- Custom Metrics: Any metrics specified in model.compile()
- Early Stopping Metrics: stopped_epoch, restored_epoch, etc.
Artifacts
- Saved Model: Complete model in TensorFlow SavedModel format
- TensorBoard Logs: Training and validation metrics
You can customize autologging behavior by passing arguments to mlflow.tensorflow.autolog():
mlflow.tensorflow.autolog(
    log_models=True,
    log_input_examples=True,
    log_model_signatures=True,
    log_every_n_steps=1,
)
How TensorFlow Autologging Works
MLflow's TensorFlow autologging uses a custom Keras callback attached to your model via monkey patching. This callback:
- Captures Initial State: At training start, logs model architecture, hyperparameters, and optimizer settings
- Monitors Training: Tracks metrics at each epoch or at specified intervals
- Records Completion: Saves the final trained model when training completes
This approach integrates seamlessly with TensorFlow's existing callback system, ensuring compatibility with your other callbacks like early stopping or learning rate scheduling.
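The general idea of attaching a callback through monkey patching can be sketched without TensorFlow at all. Everything in the block below is a toy stand-in: `ToyModel`, `RecordingCallback`, and `patch_fit` are hypothetical names, and this is not MLflow's actual implementation. It only illustrates how a patched fit method can add a logging callback while keeping the caller's own callbacks.

```python
import functools


class ToyModel:
    """Stand-in for a Keras model; NOT a real TensorFlow class."""

    def fit(self, epochs=1, callbacks=None):
        callbacks = callbacks or []
        for cb in callbacks:
            cb.on_train_begin()
        for epoch in range(epochs):
            for cb in callbacks:
                cb.on_epoch_end(epoch)
        for cb in callbacks:
            cb.on_train_end()


class RecordingCallback:
    """Hypothetical callback that records the three phases described above."""

    def __init__(self):
        self.events = []

    def on_train_begin(self):
        self.events.append("initial_state")

    def on_epoch_end(self, epoch):
        self.events.append(f"epoch_{epoch}")

    def on_train_end(self):
        self.events.append("final_model")


def patch_fit(model_cls, make_callback):
    """Replace model_cls.fit with a wrapper that appends an extra callback."""
    original_fit = model_cls.fit
    attached = []  # callbacks created by the patch, kept for inspection

    @functools.wraps(original_fit)
    def patched_fit(self, epochs=1, callbacks=None):
        extra = make_callback()
        attached.append(extra)
        # Keep the caller's callbacks and add the logging one at the end.
        combined = list(callbacks or []) + [extra]
        return original_fit(self, epochs=epochs, callbacks=combined)

    model_cls.fit = patched_fit
    return attached


attached = patch_fit(ToyModel, RecordingCallback)
user_callback = RecordingCallback()
ToyModel().fit(epochs=2, callbacks=[user_callback])  # calling code is unchanged
print(attached[0].events)
```

The caller still invokes fit exactly as before, which is why this style of integration needs no changes to existing training code.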
Logging to MLflow with Keras Callback
For more control over what gets logged, you can use MLflow's built-in Keras callback or create your own custom callback.
Using the Predefined Callback
MLflow provides mlflow.tensorflow.MlflowCallback, which offers the same functionality as autologging but with more explicit control:
import mlflow
from tensorflow import keras

# Define and compile your model
model = keras.Sequential([...])
model.compile(...)

# Create an MLflow run and add the callback
with mlflow.start_run() as run:
    model.fit(
        data,
        labels,
        batch_size=32,
        epochs=10,
        callbacks=[mlflow.tensorflow.MlflowCallback(run)],
    )
Callback Configuration Options
The MlflowCallback accepts several parameters to customize logging behavior:
mlflow.tensorflow.MlflowCallback(
    log_every_epoch=True,  # Log metrics at the end of each epoch
    log_every_n_steps=None,  # Log metrics every N steps (overrides log_every_epoch)
)
- Epoch-based Logging: Set log_every_epoch=True (default) to log at the end of each epoch
- Batch-based Logging: Set log_every_n_steps=N to log every N batches
- Selective Model Logging: Set log_models=False to disable model saving
Customizing MLflow Logging
You can create your own callback by subclassing keras.callbacks.Callback to implement custom logging logic:
from tensorflow import keras
import math
import mlflow


class CustomMlflowCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        mlflow.log_metric("current_epoch", epoch)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Log metrics in log scale
        for k, v in logs.items():
            if v > 0:  # Avoid log(0) or log(negative)
                mlflow.log_metric(f"log_{k}", math.log(v), step=epoch)
            mlflow.log_metric(k, v, step=epoch)

    def on_train_end(self, logs=None):
        # Log final model weights statistics
        weights = self.model.get_weights()
        mlflow.log_metric("total_parameters", sum(w.size for w in weights))
        mlflow.log_metric(
            "average_weight",
            sum(w.sum() for w in weights) / sum(w.size for w in weights),
        )
Keras Callback Lifecycle Hooks
Keras callbacks provide various hooks into the training process:
- Training Setup: on_train_begin, on_train_end
- Epoch Progress: on_epoch_begin, on_epoch_end
- Batch Progress: on_batch_begin, on_batch_end
- Validation: on_test_begin, on_test_end
- Prediction: on_predict_begin, on_predict_end
The logs dictionary passed to these methods contains metrics like:
- loss: Training loss
- val_loss: Validation loss
- Any custom metrics defined in model.compile()
For full documentation, see keras.callbacks.Callback.
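As a small illustration of working with the logs dictionary, the sketch below separates training metrics from validation metrics using the val_ prefix that Keras applies to validation results. The helper name `split_logs` and the metric values are made up for this example.

```python
def split_logs(logs):
    """Separate training metrics from validation metrics by the `val_` prefix."""
    prefix = "val_"
    train = {k: v for k, v in logs.items() if not k.startswith(prefix)}
    val = {k[len(prefix):]: v for k, v in logs.items() if k.startswith(prefix)}
    return train, val


# Hypothetical values, shaped like the dict Keras passes to on_epoch_end.
example_logs = {
    "loss": 0.42,
    "accuracy": 0.81,
    "val_loss": 0.57,
    "val_accuracy": 0.76,
}
train_metrics, val_metrics = split_logs(example_logs)
print(train_metrics)
print(val_metrics)
```

A custom callback could use a split like this to log training and validation metrics under different names or at different frequencies.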
Saving Your TensorFlow Model to MLflow
Basic Model Saving
If you haven't enabled autologging (which saves models automatically), you can manually save your TensorFlow model using mlflow.tensorflow.log_model():
import mlflow
import tensorflow as tf
from tensorflow import keras

# Define model
model = keras.Sequential(
    [
        keras.Input([28, 28, 3]),
        keras.layers.Conv2D(8, 2),
        keras.layers.MaxPool2D(2),
        keras.layers.Flatten(),
        keras.layers.Dense(2),
        keras.layers.Softmax(),
    ]
)

# Train model (code omitted for brevity)

# Log the model to MLflow
model_info = mlflow.tensorflow.log_model(model, name="model")

# Later, load the model for inference
loaded_model = mlflow.tensorflow.load_model(model_info.model_uri)
predictions = loaded_model.predict(tf.random.uniform([1, 28, 28, 3]))
Understanding MLflow Model Saving
When you save a TensorFlow model with MLflow:
- Format Conversion: The model is converted to a generic MLflow pyfunc model to support deployment, loaded via mlflow.pyfunc.load_model()
- Preservation of Original Format: The model is still capable of being loaded as a native TensorFlow object via mlflow.tensorflow.load_model()
- Metadata Creation: Model metadata is stored, including dependencies and signature
- Artifact Storage: The model is saved to the MLflow artifact store
- Loading Capability: The model can be loaded back as either a native TensorFlow model or a generic pyfunc model
This approach enables consistent model management regardless of the framework used.
Model Formats
By default, MLflow saves TensorFlow models in the TensorFlow SavedModel format (compiled graph), which is ideal for deployment. You can also save in other formats:
# Save in H5 format
mlflow.tensorflow.log_model(
    model, name="model", keras_model_kwargs={"save_format": "h5"}
)

# Save in native Keras format
mlflow.tensorflow.log_model(
    model, name="model", keras_model_kwargs={"save_format": "keras"}
)
Comparing Model Formats
TensorFlow SavedModel (Default)
- Complete Serialization: Includes model architecture, weights, and compilation information
- Deployment Ready: Optimized for production environments