Source code for mlflow.entities.trace_info_v2

from dataclasses import asdict, dataclass, field
from typing import Any, Optional

from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.assessment import Assessment
from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_location import TraceLocation
from mlflow.entities.trace_status import TraceStatus
from mlflow.protos.service_pb2 import TraceInfo as ProtoTraceInfo
from mlflow.protos.service_pb2 import TraceRequestMetadata as ProtoTraceRequestMetadata
from mlflow.protos.service_pb2 import TraceTag as ProtoTraceTag


def _truncate_request_metadata(d: dict[str, Any]) -> dict[str, str]:
    """Return a copy of ``d`` with every key and value clipped to the metadata limit.

    Values are converted with ``str()`` before being clipped; keys are sliced
    as-is. Both use ``MAX_CHARS_IN_TRACE_INFO_METADATA`` as the maximum length.
    If two keys become identical after clipping, the later entry's value wins.

    Args:
        d: Request metadata mapping to truncate. It is not modified.

    Returns:
        A new dictionary of truncated keys mapped to truncated string values.
    """
    # Imported at call time rather than at module level, as the original does
    # (presumably to avoid an import cycle -- TODO confirm).
    from mlflow.tracing.constant import MAX_CHARS_IN_TRACE_INFO_METADATA as limit

    truncated: dict[str, str] = {}
    for raw_key, raw_value in d.items():
        short_key = raw_key[:limit]
        truncated[short_key] = str(raw_value)[:limit]
    return truncated


def _truncate_tags(d: dict[str, Any]) -> dict[str, str]:
    """Return a copy of ``d`` with tag keys and values clipped to their limits.

    Keys are sliced to ``MAX_CHARS_IN_TRACE_INFO_TAGS_KEY`` characters. Values
    are converted with ``str()`` and then sliced to
    ``MAX_CHARS_IN_TRACE_INFO_TAGS_VALUE`` characters. If two keys become
    identical after clipping, the later entry's value wins.

    Args:
        d: Tag mapping to truncate. It is not modified.

    Returns:
        A new dictionary of truncated keys mapped to truncated string values.
    """
    # Imported at call time rather than at module level, as the original does
    # (presumably to avoid an import cycle -- TODO confirm).
    from mlflow.tracing.constant import (
        MAX_CHARS_IN_TRACE_INFO_TAGS_KEY as key_limit,
        MAX_CHARS_IN_TRACE_INFO_TAGS_VALUE as value_limit,
    )

    truncated: dict[str, str] = {}
    for raw_key, raw_value in d.items():
        short_key = raw_key[:key_limit]
        truncated[short_key] = str(raw_value)[:value_limit]
    return truncated


@dataclass
class TraceInfoV2(_MlflowObject):
    """Metadata about a trace.

    Args:
        request_id: id of the trace.
        experiment_id: id of the experiment.
        timestamp_ms: start time of the trace, in milliseconds.
        execution_time_ms: duration of the trace, in milliseconds.
        status: status of the trace.
        request_metadata: Key-value pairs associated with the trace. Request metadata are designed
            for immutable values like run ID associated with the trace.
        tags: Tags associated with the trace. Tags are designed for mutable values like trace name,
            that can be updated by the users after the trace is created, unlike request_metadata.
        assessments: Assessments attached to the trace. They are not written by
            :py:meth:`to_proto`; :py:meth:`from_proto` takes them as a separate argument.
        client_request_id: Client request ID carried for the V3 TraceInfo. It is excluded from
            the dict and proto conversions (see the note on the field below).
    """

    request_id: str
    experiment_id: str
    timestamp_ms: int
    execution_time_ms: Optional[int]
    status: TraceStatus
    request_metadata: dict[str, str] = field(default_factory=dict)
    tags: dict[str, str] = field(default_factory=dict)
    assessments: list[Assessment] = field(default_factory=list)
    # NB: This field corresponds to the client request ID field in the V3 TraceInfo data model.
    # This is not a part of the V2 TraceInfo and not included in the dict/proto conversion,
    # but only added for storing the client request ID during the trace generation process,
    # until the internal logic e.g. InMemoryTraceManager migrates to use the V3 TraceInfo.
    client_request_id: Optional[str] = None

    def __eq__(self, other):
        """Return True only when ``other`` has the exact same type and equal instance attributes."""
        if type(other) is type(self):
            return self.__dict__ == other.__dict__
        return False

    @property
    def trace_id(self) -> str:
        """Returns the trace ID of the trace info."""
        return self.request_id

    def to_proto(self):
        """Convert this trace info to a ``ProtoTraceInfo`` message.

        Request metadata and tags are truncated with ``_truncate_request_metadata`` and
        ``_truncate_tags`` before being written. ``assessments`` and ``client_request_id``
        are not written to the proto.

        Returns:
            The populated ``ProtoTraceInfo`` message.
        """
        proto = ProtoTraceInfo()
        proto.request_id = self.request_id
        proto.experiment_id = self.experiment_id
        proto.timestamp_ms = self.timestamp_ms
        # NB: Proto setter does not support nullable fields (even with 'optional' keyword),
        # so we substitute None with 0 for execution_time_ms. This should be not too confusing
        # as we only put None when starting a trace i.e. the execution time is actually 0.
        proto.execution_time_ms = self.execution_time_ms or 0
        proto.status = self.status.to_proto()

        request_metadata = []
        for key, value in _truncate_request_metadata(self.request_metadata).items():
            attr = ProtoTraceRequestMetadata()
            attr.key = key
            attr.value = value
            request_metadata.append(attr)
        proto.request_metadata.extend(request_metadata)

        tags = []
        for key, value in _truncate_tags(self.tags).items():
            tag = ProtoTraceTag()
            tag.key = key
            # _truncate_tags already returns string values, so no extra str() is needed.
            tag.value = value
            tags.append(tag)
        proto.tags.extend(tags)

        return proto

    @classmethod
    def from_proto(cls, proto, assessments=None):
        """Build a ``TraceInfoV2`` from a ``ProtoTraceInfo`` message.

        Args:
            proto: The proto message to read the fields from.
            assessments: Optional assessments to attach. ``None`` (or any falsy value)
                results in an empty list.

        Returns:
            A new ``TraceInfoV2`` instance.
        """
        return cls(
            request_id=proto.request_id,
            experiment_id=proto.experiment_id,
            timestamp_ms=proto.timestamp_ms,
            execution_time_ms=proto.execution_time_ms,
            status=TraceStatus.from_proto(proto.status),
            request_metadata={attr.key: attr.value for attr in proto.request_metadata},
            tags={tag.key: tag.value for tag in proto.tags},
            assessments=assessments or [],
        )

    def to_dict(self):
        """
        Convert trace info to a dictionary for persistence.
        Update status field to the string value for serialization.
        """
        trace_info_dict = asdict(self)
        trace_info_dict["status"] = self.status.value
        # Client request ID field is only added for internal use, and should not be
        # serialized for V2 TraceInfo.
        trace_info_dict.pop("client_request_id", None)
        return trace_info_dict

    @classmethod
    def from_dict(cls, trace_info_dict):
        """
        Convert trace info dictionary to a TraceInfoV2 object.

        The input dictionary is not modified.

        Args:
            trace_info_dict: Dictionary whose keys are the constructor field names. The
                ``status`` entry is required and is converted with ``TraceStatus(...)``.

        Returns:
            A new ``TraceInfoV2`` instance.

        Raises:
            ValueError: If ``status`` is missing from the dictionary.
        """
        if "status" not in trace_info_dict:
            raise ValueError("status is required in trace info dictionary.")
        # Work on a shallow copy so that replacing the status string with the enum
        # does not leak back into the caller's dictionary.
        init_kwargs = dict(trace_info_dict)
        init_kwargs["status"] = TraceStatus(init_kwargs["status"])
        return cls(**init_kwargs)

    def to_v3(self, request: Optional[str] = None, response: Optional[str] = None) -> TraceInfo:
        """Convert this V2 trace info to a V3 ``TraceInfo``.

        Args:
            request: Value passed through as ``request_preview``.
            response: Value passed through as ``response_preview``.

        Returns:
            A ``TraceInfo`` populated from this object's fields.
        """
        # NOTE: request_metadata, tags and assessments are handed over as the same
        # objects; no copy is made here.
        return TraceInfo(
            trace_id=self.request_id,
            client_request_id=self.client_request_id,
            trace_location=TraceLocation.from_experiment_id(self.experiment_id),
            request_preview=request,
            response_preview=response,
            request_time=self.timestamp_ms,
            execution_duration=self.execution_time_ms,
            state=self.status.to_state(),
            trace_metadata=self.request_metadata,
            tags=self.tags,
            assessments=self.assessments,
        )