from typing import Any, Optional, Union
import mlflow.protos.service_pb2 as pb2
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.logged_model_parameter import LoggedModelParameter
from mlflow.entities.logged_model_status import LoggedModelStatus
from mlflow.entities.logged_model_tag import LoggedModelTag
from mlflow.entities.metric import Metric
class LoggedModel(_MlflowObject):
    """
    MLflow entity representing a Model logged to an MLflow Experiment.

    Args:
        experiment_id: ID of the experiment the model belongs to.
        model_id: Unique ID of the model.
        name: Name of the model.
        artifact_location: Location of the model artifacts.
        creation_timestamp: Creation time, in milliseconds since the Unix epoch.
        last_updated_timestamp: Time of the last update, in milliseconds since the Unix
            epoch.
        model_type: Optional type of the model.
        source_run_id: Optional ID of the MLflow run that generated the model.
        status: Model status, either as a ``LoggedModelStatus`` or as a value accepted by
            ``LoggedModelStatus.from_int``.
        status_message: Optional descriptive message for error status conditions.
        tags: Tags, either as a list of ``LoggedModelTag`` or as a ``key -> value`` dict.
        params: Parameters, either as a list of ``LoggedModelParameter`` or as a
            ``key -> value`` dict.
        metrics: Optional list of metrics associated with the model.
    """

    def __init__(
        self,
        experiment_id: str,
        model_id: str,
        name: str,
        artifact_location: str,
        creation_timestamp: int,
        last_updated_timestamp: int,
        model_type: Optional[str] = None,
        source_run_id: Optional[str] = None,
        status: Union[LoggedModelStatus, int] = LoggedModelStatus.READY,
        status_message: Optional[str] = None,
        tags: Optional[Union[list[LoggedModelTag], dict[str, str]]] = None,
        params: Optional[Union[list[LoggedModelParameter], dict[str, str]]] = None,
        metrics: Optional[list[Metric]] = None,
    ):
        super().__init__()
        self._experiment_id: str = experiment_id
        self._model_id: str = model_id
        self._name: str = name
        self._artifact_location: str = artifact_location
        self._creation_time: int = creation_timestamp
        self._last_updated_timestamp: int = last_updated_timestamp
        self._model_type: Optional[str] = model_type
        self._source_run_id: Optional[str] = source_run_id
        self._status: LoggedModelStatus = (
            status if isinstance(status, LoggedModelStatus) else LoggedModelStatus.from_int(status)
        )
        self._status_message: Optional[str] = status_message

        # Tags and params are always stored as plain dicts. Dict inputs are shallow-copied
        # so that later mutation (e.g. via ``_add_tag``) never leaks into the caller's dict.
        self._tags: dict[str, str]
        if isinstance(tags, list):
            self._tags = {tag.key: tag.value for tag in tags}
        else:
            self._tags = dict(tags) if tags else {}

        self._params: dict[str, str]
        if isinstance(params, list):
            self._params = {param.key: param.value for param in params}
        else:
            self._params = dict(params) if params else {}

        self._metrics: Optional[list[Metric]] = metrics

    def __repr__(self) -> str:
        """Return ``LoggedModel(k=v, ...)`` with properties sorted by name."""
        # These fields can be large and take up space on the notebook or terminal
        hidden = ("tags", "params", "metrics")
        # Iterating the entity yields (property name, value) pairs.
        fields = ", ".join(
            f"{key}={value!r}"
            for key, value in sorted(self, key=lambda item: item[0])
            if key not in hidden
        )
        return f"LoggedModel({fields})"

    @property
    def experiment_id(self) -> str:
        """String. Experiment ID associated with this Model."""
        return self._experiment_id

    @experiment_id.setter
    def experiment_id(self, new_experiment_id: str):
        self._experiment_id = new_experiment_id

    @property
    def model_id(self) -> str:
        """String. Unique ID for this Model."""
        return self._model_id

    @model_id.setter
    def model_id(self, new_model_id: str):
        self._model_id = new_model_id

    @property
    def name(self) -> str:
        """String. Name for this Model."""
        return self._name

    @name.setter
    def name(self, new_name: str):
        self._name = new_name

    @property
    def artifact_location(self) -> str:
        """String. Location of the model artifacts."""
        return self._artifact_location

    @artifact_location.setter
    def artifact_location(self, new_artifact_location: str):
        self._artifact_location = new_artifact_location

    @property
    def creation_timestamp(self) -> int:
        """Integer. Model creation timestamp (milliseconds since the Unix epoch)."""
        return self._creation_time

    @property
    def last_updated_timestamp(self) -> int:
        """Integer. Timestamp of last update for this Model (milliseconds since the Unix
        epoch).
        """
        return self._last_updated_timestamp

    @last_updated_timestamp.setter
    def last_updated_timestamp(self, updated_timestamp: int):
        self._last_updated_timestamp = updated_timestamp

    @property
    def model_type(self) -> Optional[str]:
        """String. Type of the model."""
        return self._model_type

    @model_type.setter
    def model_type(self, new_model_type: Optional[str]):
        self._model_type = new_model_type

    @property
    def source_run_id(self) -> Optional[str]:
        """String. MLflow run ID that generated this model."""
        return self._source_run_id

    @property
    def status(self) -> LoggedModelStatus:
        """Current status of this Model."""
        return self._status

    @status.setter
    def status(self, updated_status: Union[LoggedModelStatus, int]):
        # Mirror the constructor: a raw integer code is normalised to the enum, because
        # ``to_dictionary`` and ``to_proto`` call methods on the stored status. Any other
        # value is stored unchanged.
        if isinstance(updated_status, int) and not isinstance(updated_status, LoggedModelStatus):
            updated_status = LoggedModelStatus.from_int(updated_status)
        self._status = updated_status

    @property
    def status_message(self) -> Optional[str]:
        """String. Descriptive message for error status conditions."""
        return self._status_message

    @property
    def tags(self) -> dict[str, str]:
        """Dictionary of tag key (string) -> tag value for this Model."""
        return self._tags

    @property
    def params(self) -> dict[str, str]:
        """Model parameters."""
        return self._params

    @property
    def metrics(self) -> Optional[list[Metric]]:
        """List of metrics associated with this Model."""
        return self._metrics

    @metrics.setter
    def metrics(self, new_metrics: Optional[list[Metric]]):
        self._metrics = new_metrics

    @property
    def model_uri(self) -> str:
        """URI of the model."""
        # Derived from the current model ID on every access so that it cannot go stale
        # after ``model_id`` is reassigned.
        return f"models:/{self._model_id}"

    @classmethod
    def _properties(cls) -> list[str]:
        # aggregate with base class properties since cls.__dict__ does not do it automatically
        return sorted(cls._get_properties_helper())

    def _add_tag(self, tag):
        """Set ``tag.key`` to ``tag.value`` on this Model, overwriting any existing value."""
        self._tags[tag.key] = tag.value

    def to_dictionary(self) -> dict[str, Any]:
        """Return this Model's properties as a dictionary.

        The status is stored as the result of ``to_int()`` and the derived ``model_uri``
        entry is omitted.
        """
        model_dict = dict(self)
        model_dict["status"] = self.status.to_int()
        # Remove the model_uri field from the dictionary since it is a derived field
        del model_dict["model_uri"]
        return model_dict

    def to_proto(self):
        """Convert this entity into a ``pb2.LoggedModel`` message."""
        info = pb2.LoggedModelInfo(
            experiment_id=self.experiment_id,
            model_id=self.model_id,
            name=self.name,
            artifact_uri=self.artifact_location,
            creation_timestamp_ms=self.creation_timestamp,
            last_updated_timestamp_ms=self.last_updated_timestamp,
            model_type=self.model_type,
            source_run_id=self.source_run_id,
            status=self.status.to_proto(),
            # Serialised so that a to_proto -> from_proto round trip keeps the message;
            # from_proto reads this same field.
            status_message=self.status_message,
            tags=[pb2.LoggedModelTag(key=k, value=v) for k, v in self.tags.items()],
        )
        data = pb2.LoggedModelData(
            params=[pb2.LoggedModelParameter(key=k, value=v) for k, v in self.params.items()],
            metrics=[m.to_proto() for m in self.metrics] if self.metrics else [],
        )
        return pb2.LoggedModel(info=info, data=data)

    @classmethod
    def from_proto(cls, proto):
        """Build an entity from a ``pb2.LoggedModel`` message.

        Scalar fields are read from ``proto.info``; params and metrics are read from
        ``proto.data``.
        """
        # NOTE(review): optional string fields that are unset on the message presumably
        # read back as empty strings rather than None -- confirm against the proto schema.
        info = proto.info
        data = proto.data
        return cls(
            experiment_id=info.experiment_id,
            model_id=info.model_id,
            name=info.name,
            artifact_location=info.artifact_uri,
            creation_timestamp=info.creation_timestamp_ms,
            last_updated_timestamp=info.last_updated_timestamp_ms,
            model_type=info.model_type,
            source_run_id=info.source_run_id,
            status=LoggedModelStatus.from_proto(info.status),
            status_message=info.status_message,
            tags=[LoggedModelTag.from_proto(tag) for tag in info.tags],
            params=[LoggedModelParameter.from_proto(param) for param in data.params],
            metrics=[Metric.from_proto(metric) for metric in data.metrics],
        )