Deep Learning with MLflow 3.0
In this example, we demonstrate how to use MLflow 3.0 to track and evaluate deep learning models with a PyTorch-based Iris classifier.
The example showcases how to log checkpoints, link metrics to datasets, and rank models based on performance metrics. With mlflow.search_logged_models() you can easily find the best model by metric value.
import pandas as pd
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import mlflow
import mlflow.pytorch
from mlflow.entities import Dataset
# Helper function to prepare data
def prepare_data(df):
    # Features are every column except the last; the last column is the label
    X = torch.tensor(df.iloc[:, :-1].values, dtype=torch.float32)
    y = torch.tensor(df.iloc[:, -1].values, dtype=torch.long)
    return X, y

# Helper function to compute accuracy
def compute_accuracy(model, X, y):
    with torch.no_grad():
        outputs = model(X)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == y).sum().item() / y.size(0)
    return accuracy
# Define a basic PyTorch classifier
class IrisClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Load Iris dataset and prepare the DataFrame
iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df["target"] = iris.target
# Split into training and testing datasets
train_df, test_df = train_test_split(iris_df, test_size=0.2, random_state=42)
# Load the training dataset with MLflow. We will link training metrics to this dataset.
train_dataset: Dataset = mlflow.data.from_pandas(train_df, name="train")
X_train, y_train = prepare_data(train_dataset.df)
# Define the PyTorch model and move it to the device
input_size = X_train.shape[1]
hidden_size = 16
output_size = len(iris.target_names)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = IrisClassifier(input_size, hidden_size, output_size).to(device)
scripted_model = torch.jit.script(model)
# Start a run to represent the training job
with mlflow.start_run() as run:
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(scripted_model.parameters(), lr=0.01)

    # Move the training data to the device once, before the loop
    X_train, y_train = X_train.to(device), y_train.to(device)

    for epoch in range(101):
        out = scripted_model(X_train)
        loss = criterion(out, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log a checkpoint with metrics every 10 epochs
        if epoch % 10 == 0:
            # Each newly created LoggedModel checkpoint is linked with its name and step
            model_info = mlflow.pytorch.log_model(
                pytorch_model=scripted_model,
                name=f"torch-iris-{epoch}",
                step=epoch,
                # Move the example back to CPU; .numpy() fails on CUDA tensors
                input_example=X_train.cpu().numpy(),
            )
            # Log params to the run; each LoggedModel inherits those params
            mlflow.log_params(
                params={
                    "n_layers": 3,
                    "activation": "ReLU",
                    "criterion": "CrossEntropyLoss",
                    "optimizer": "Adam",
                }
            )
            # Log the metric on the training dataset at this step and link it to the LoggedModel
            mlflow.log_metric(
                key="accuracy",
                value=compute_accuracy(scripted_model, X_train, y_train),
                step=epoch,
                model_id=model_info.model_id,
                dataset=train_dataset,
            )
On the run page, you can see the generated logged models; their names follow the pattern torch-iris-<epoch>. The logged models also show up in the Models tab of the experiment, along with their dataset, parameters, and metrics.
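The same details are available programmatically. A minimal sketch, assuming model_info from the last iteration of the training loop is still in scope:

# Fetch a checkpoint by its model ID and inspect what the UI displays
last_checkpoint = mlflow.get_logged_model(model_info.model_id)
print(last_checkpoint.name)     # e.g. "torch-iris-100"
print(last_checkpoint.params)   # params inherited from the run
print(last_checkpoint.metrics)  # accuracy metric linked to the "train" dataset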
Use mlflow.search_logged_models() to search the logged models attached to the run, ordering them by the accuracy metric value to easily fetch the best and worst checkpoints:
ranked_checkpoints = mlflow.search_logged_models(
    filter_string=f"source_run_id='{run.info.run_id}'",
    order_by=[{"field_name": "metrics.accuracy", "ascending": False}],
    output_format="list",
)
best_checkpoint = ranked_checkpoints[0]
print(f"Best model: {best_checkpoint}")
print(best_checkpoint.metrics)
# Best model: <LoggedModel: artifact_location='file:///Users/serena.ruan/Documents/repos/mlflow-3-doc/mlruns/0/models/41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf/artifacts', creation_timestamp=1743734069924, experiment_id='0', last_updated_timestamp=1743734075018, metrics=[<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=100, timestamp=1743734075029, value=0.975>], model_id='41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', model_type='', model_uri='models:/41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', name='torch-iris-100', params={'activation': 'ReLU',
# 'criterion': 'CrossEntropyLoss',
# 'n_layers': '3',
# 'optimizer': 'Adam'}, source_run_id='12f143a7fda1461e9240d7ffad4ea5bd', status=<LoggedModelStatus.READY: 'READY'>, status_message='', tags={'mlflow.source.git.commit': '7324c807f07a1766d4b951733e3d723504b4576e',
# 'mlflow.source.name': 'a.py',
# 'mlflow.source.type': 'LOCAL',
# 'mlflow.user': 'serena.ruan'}>
# [<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=100, timestamp=1743734075029, value=0.975>]
worst_checkpoint = ranked_checkpoints[-1]
print(f"Worst model: {worst_checkpoint}")
print(worst_checkpoint.metrics)
# Worst model: <LoggedModel: artifact_location='file:///Users/serena.ruan/Documents/repos/mlflow-3-doc/mlruns/0/models/0d789084-9a3b-4b85-9d43-6a148c014b7e/artifacts', creation_timestamp=1743734016290, experiment_id='0', last_updated_timestamp=1743734022728, metrics=[<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='0d789084-9a3b-4b85-9d43-6a148c014b7e', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=0, timestamp=1743734022737, value=0.3>], model_id='0d789084-9a3b-4b85-9d43-6a148c014b7e', model_type='', model_uri='models:/0d789084-9a3b-4b85-9d43-6a148c014b7e', name='torch-iris-0', params={}, source_run_id='12f143a7fda1461e9240d7ffad4ea5bd', status=<LoggedModelStatus.READY: 'READY'>, status_message='', tags={'mlflow.source.git.commit': '7324c807f07a1766d4b951733e3d723504b4576e',
# 'mlflow.source.name': 'a.py',
# 'mlflow.source.type': 'LOCAL',
# 'mlflow.user': 'serena.ruan'}>
# [<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='0d789084-9a3b-4b85-9d43-6a148c014b7e', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=0, timestamp=1743734022737, value=0.3>]
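To check that the ranking generalizes, you can load the best checkpoint back and score it on the held-out test split. A minimal sketch, assuming the variables from the script above (including the so-far-unused test_df) are still in scope:

# Prepare the held-out test split, reusing the helpers defined above
test_dataset = mlflow.data.from_pandas(test_df, name="test")
X_test, y_test = prepare_data(test_dataset.df)

# Load the best checkpoint by its model URI and evaluate it on the test data
best_model = mlflow.pytorch.load_model(best_checkpoint.model_uri).to(device)
test_accuracy = compute_accuracy(best_model, X_test.to(device), y_test.to(device))
print(f"Test accuracy of {best_checkpoint.name}: {test_accuracy:.3f}")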
The model's artifacts can be viewed on the Artifacts tab of the model page.