from dataclasses import dataclass, field
from typing import Any, Optional
from mlflow.entities import Span
from mlflow.tracing.constant import SpanAttributeKey
[docs]@dataclass
class TraceData:
"""A container object that holds the spans data of a trace.
Args:
spans: List of spans that are part of the trace.
"""
spans: list[Span] = field(default_factory=list)
# NB: Custom constructor to allow passing additional kwargs for backward compatibility for
# DBX agent evaluator. Once they migrates to trace V3 schema, we can remove this.
def __init__(self, spans: Optional[list[Span]] = None, **kwargs):
self.spans = spans or []
[docs] @classmethod
def from_dict(cls, d):
if not isinstance(d, dict):
raise TypeError(f"TraceData.from_dict() expects a dictionary. Got: {type(d).__name__}")
return cls(spans=[Span.from_dict(span) for span in d.get("spans", [])])
[docs] def to_dict(self) -> dict[str, Any]:
return {"spans": [span.to_dict() for span in self.spans]}
@property
def intermediate_outputs(self) -> Optional[dict[str, Any]]:
"""
Returns intermediate outputs produced by the model or agent while handling the request.
There are mainly two flows to return intermediate outputs:
1. When a trace is generate by the `mlflow.log_trace` API,
return `intermediate_outputs` attribute of the span.
2. When a trace is created normally with a tree of spans,
aggregate the outputs of non-root spans.
"""
root_span = self._get_root_span()
if root_span and root_span.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS):
return root_span.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS)
if len(self.spans) > 1:
return {
span.name: span.outputs
for span in self.spans
if span.parent_id and span.outputs is not None
}
def _get_root_span(self) -> Optional[Span]:
for span in self.spans:
if span.parent_id is None:
return span
# `request` and `response` are preserved for backward compatibility with v2
@property
def request(self) -> Optional[str]:
if span := self._get_root_span():
# Accessing the OTel span directly get serialized value directly.
return span._span.attributes.get(SpanAttributeKey.INPUTS)
return None
@property
def response(self) -> Optional[str]:
if span := self._get_root_span():
# Accessing the OTel span directly get serialized value directly.
return span._span.attributes.get(SpanAttributeKey.OUTPUTS)
return None