from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.exceptions import MlflowException
from mlflow.protos import service_pb2 as pb
@dataclass
class MlflowExperimentLocation(_MlflowObject):
    """
    Location of a trace that is stored inside an MLflow experiment.

    Args:
        experiment_id: ID of the MLflow experiment that holds the trace.
    """

    experiment_id: str

    def to_proto(self):
        """Build the ``pb.TraceLocation.MlflowExperimentLocation`` message for this location."""
        message_cls = pb.TraceLocation.MlflowExperimentLocation
        return message_cls(experiment_id=self.experiment_id)

    @classmethod
    def from_proto(cls, proto) -> "MlflowExperimentLocation":
        """Create an instance from a message that exposes an ``experiment_id`` field."""
        experiment_id = proto.experiment_id
        return cls(experiment_id=experiment_id)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict with the single key ``experiment_id``."""
        return dict(experiment_id=self.experiment_id)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MlflowExperimentLocation":
        """
        Inverse of :meth:`to_dict`.

        Raises:
            KeyError: If ``d`` has no ``experiment_id`` entry.
        """
        experiment_id = d["experiment_id"]
        return cls(experiment_id=experiment_id)
@dataclass
class InferenceTableLocation(_MlflowObject):
    """
    Location of a trace that is stored in a Databricks inference table.

    Args:
        full_table_name: Fully qualified name of the inference table holding the
            trace, formatted as `<catalog>.<schema>.<table>`.
    """

    full_table_name: str

    def to_proto(self):
        """Build the ``pb.TraceLocation.InferenceTableLocation`` message for this location."""
        message_cls = pb.TraceLocation.InferenceTableLocation
        return message_cls(full_table_name=self.full_table_name)

    @classmethod
    def from_proto(cls, proto) -> "InferenceTableLocation":
        """Create an instance from a message that exposes a ``full_table_name`` field."""
        table_name = proto.full_table_name
        return cls(full_table_name=table_name)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict with the single key ``full_table_name``."""
        return dict(full_table_name=self.full_table_name)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "InferenceTableLocation":
        """
        Inverse of :meth:`to_dict`.

        Raises:
            KeyError: If ``d`` has no ``full_table_name`` entry.
        """
        table_name = d["full_table_name"]
        return cls(full_table_name=table_name)
class TraceLocationType(str, Enum):
    """
    Kinds of storage a trace location can refer to.

    Every member is a ``str`` whose value equals its name, so members compare
    equal to the plain strings used in dict representations.
    """

    TRACE_LOCATION_TYPE_UNSPECIFIED = "TRACE_LOCATION_TYPE_UNSPECIFIED"
    MLFLOW_EXPERIMENT = "MLFLOW_EXPERIMENT"
    INFERENCE_TABLE = "INFERENCE_TABLE"

    def to_proto(self):
        """Return the proto enum number registered under this member's name."""
        proto_enum = pb.TraceLocation.TraceLocationType
        return proto_enum.Value(self)

    @classmethod
    def from_proto(cls, proto: int) -> "TraceLocationType":
        """Map a proto enum number back to the member with the matching name."""
        proto_name = pb.TraceLocation.TraceLocationType.Name(proto)
        return cls(proto_name)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "TraceLocationType":
        """
        Look up the member whose value is stored under ``d["type"]``.

        Raises:
            KeyError: If ``d`` has no ``type`` entry.
            ValueError: If the stored value is not a member value.
        """
        type_value = d["type"]
        return cls(type_value)
@dataclass
class TraceLocation(_MlflowObject):
    """
    Represents the location where the trace is stored.

    Currently, MLflow supports two types of trace locations:

    - MLflow experiment: The trace is stored in an MLflow experiment.
    - Inference table: The trace is stored in a Databricks inference table.

    Args:
        type: The type of the trace location, should be one of the
            :py:class:`TraceLocationType` enum values. A plain string equal to
            one of those values is also accepted and converted to the enum.
        mlflow_experiment: The MLflow experiment location. Set this when the
            location type is MLflow experiment.
        inference_table: The inference table location. Set this when the
            location type is Databricks Inference table.
    """

    type: TraceLocationType
    mlflow_experiment: Optional[MlflowExperimentLocation] = None
    inference_table: Optional[InferenceTableLocation] = None

    def __post_init__(self) -> None:
        """
        Validate and normalize the fields after dataclass initialization.

        Raises:
            MlflowException: If both locations are provided, if ``type`` is not a
                valid :py:class:`TraceLocationType` value, or if ``type`` does not
                match the location that was provided.
        """
        if self.mlflow_experiment is not None and self.inference_table is not None:
            raise MlflowException.invalid_parameter_value(
                "Only one of mlflow_experiment or inference_table can be provided."
            )

        # The enum is a ``str`` subclass, so callers may pass a plain string such as
        # "MLFLOW_EXPERIMENT". Normalize it so ``to_dict``/``to_proto`` can rely on
        # ``.value`` / ``.to_proto()``. Passing an enum member is a no-op.
        try:
            self.type = TraceLocationType(self.type)
        except ValueError as e:
            raise MlflowException.invalid_parameter_value(
                f"Invalid trace location type: {self.type!r}."
            ) from e

        if (
            self.mlflow_experiment is not None
            and self.type != TraceLocationType.MLFLOW_EXPERIMENT
        ) or (
            self.inference_table is not None
            and self.type != TraceLocationType.INFERENCE_TABLE
        ):
            # Report the actual field value (previously this interpolated the
            # builtin ``type``, producing "<class 'type'>").
            raise MlflowException.invalid_parameter_value(
                f"Trace location type {self.type.value} does not match the provided location "
                f"{self.mlflow_experiment or self.inference_table}."
            )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize to a dict.

        The result always has a ``type`` key holding the enum value string, plus
        ``mlflow_experiment`` or ``inference_table`` for whichever location is set.
        """
        d = {"type": self.type.value}
        if self.mlflow_experiment is not None:
            d["mlflow_experiment"] = self.mlflow_experiment.to_dict()
        elif self.inference_table is not None:
            d["inference_table"] = self.inference_table.to_dict()
        return d

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "TraceLocation":
        """
        Inverse of :meth:`to_dict`.

        A location entry that is missing or falsy (e.g. ``None`` or ``{}``) is
        treated as absent.

        Raises:
            KeyError: If ``d`` has no ``type`` entry.
        """
        return cls(
            type=TraceLocationType(d["type"]),
            mlflow_experiment=(
                MlflowExperimentLocation.from_dict(v) if (v := d.get("mlflow_experiment")) else None
            ),
            inference_table=(
                InferenceTableLocation.from_dict(v) if (v := d.get("inference_table")) else None
            ),
        )

    def to_proto(self):
        """Convert to a ``pb.TraceLocation`` message, filling in whichever location is set."""
        location_fields = {}
        if self.mlflow_experiment is not None:
            location_fields["mlflow_experiment"] = self.mlflow_experiment.to_proto()
        elif self.inference_table is not None:
            location_fields["inference_table"] = self.inference_table.to_proto()
        return pb.TraceLocation(type=self.type.to_proto(), **location_fields)

    @classmethod
    def from_proto(cls, proto) -> "TraceLocation":
        """
        Build an instance from a ``pb.TraceLocation`` message.

        The location sub-message is read from the ``identifier`` oneof. If no
        location is set, only ``type`` is populated.
        """
        type_ = TraceLocationType.from_proto(proto.type)
        which = proto.WhichOneof("identifier")
        if which == "mlflow_experiment":
            return cls(
                type=type_,
                mlflow_experiment=MlflowExperimentLocation.from_proto(proto.mlflow_experiment),
            )
        if which == "inference_table":
            return cls(
                type=type_,
                inference_table=InferenceTableLocation.from_proto(proto.inference_table),
            )
        return cls(type=type_)

    @classmethod
    def from_experiment_id(cls, experiment_id: str) -> "TraceLocation":
        """Convenience constructor for a trace stored in the given MLflow experiment."""
        return cls(
            type=TraceLocationType.MLFLOW_EXPERIMENT,
            mlflow_experiment=MlflowExperimentLocation(experiment_id=experiment_id),
        )