import json
import logging
from functools import cached_property
from inspect import isclass
from typing import Any, Final, Optional, TypedDict, Union

import polars as pl
from polars.datatypes.classes import DataType as PolarsDataType
from polars.datatypes.classes import DataTypeClass as PolarsDataTypeClass

from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.types.schema import Array, ColSpec, DataType, Object, Property, Schema

_logger = logging.getLogger(__name__)

def hash_polars_df(df: pl.DataFrame) -> str:
    """Compute a digest for a DataFrame as the stringified sum of its row hashes."""
    # Probably not the best way to hash; see also:
    # https://github.com/pola-rs/polars/issues/9743
    # https://stackoverflow.com/q/76678160
    return str(df.hash_rows().sum())
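
# For illustration, a hypothetical session (the exact digit strings will differ):
#
#     df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]})
#     hash_polars_df(df)            # e.g. "12345678901234567890"
#     hash_polars_df(df.reverse())  # same digest: summing per-row hashes is
#                                   # insensitive to row order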
ColSpecType = Union[DataType, Array, Object, str]
TYPE_MAP: Final[dict[PolarsDataTypeClass, DataType]] = {
pl.Binary: DataType.binary,
pl.Boolean: DataType.boolean,
pl.Datetime: DataType.datetime,
pl.Float32: DataType.float,
pl.Float64: DataType.double,
pl.Int8: DataType.integer,
pl.Int16: DataType.integer,
pl.Int32: DataType.integer,
pl.Int64: DataType.long,
pl.String: DataType.string,
pl.Utf8: DataType.string,
}
CLOSE_MAP: Final[dict[PolarsDataTypeClass, DataType]] = {
pl.Categorical: DataType.string,
pl.Enum: DataType.string,
pl.Date: DataType.datetime,
pl.UInt8: DataType.integer,
pl.UInt16: DataType.integer,
pl.UInt32: DataType.long,
}
# Types not covered by either map (they fall through to _handle_unknown_dtype):
# pl.Decimal
# pl.UInt64
# pl.Duration
# pl.Time
# pl.Null
# pl.Object
# pl.Unknown
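
# A sketch of how the two maps are consulted (hypothetical REPL session):
#
#     TYPE_MAP[pl.Int64]      # DataType.long (exact match)
#     CLOSE_MAP[pl.Date]      # DataType.datetime (lossy but close; infer_dtype
#                             # emits a warning for these)
#     TYPE_MAP.get(pl.Time)   # None -- unmapped types are handled by
#                             # _handle_unknown_dtype below
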
def infer_schema(df: pl.DataFrame) -> Schema:
return Schema([infer_colspec(df[col]) for col in df.columns])

def infer_colspec(col: pl.Series, *, allow_unknown: bool = True) -> ColSpec:
    return ColSpec(
        type=infer_dtype(col.dtype, col.name, allow_unknown=allow_unknown),
        name=col.name,
        # only columns with no non-null values at all are marked as not required
        required=col.count() > 0,
    )
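
# For example, roughly (a sketch; assumes the default allow_unknown=True):
#
#     infer_colspec(pl.Series("age", [1, 2, None]))
#     # ColSpec(type=DataType.long, name="age", required=True)
#     infer_colspec(pl.Series("empty", [None, None], dtype=pl.String))
#     # ColSpec(type=DataType.string, name="empty", required=False)
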
def infer_dtype(
dtype: Union[PolarsDataType, PolarsDataTypeClass], col_name: str, *, allow_unknown: bool
) -> ColSpecType:
cls: PolarsDataTypeClass = dtype if isinstance(dtype, PolarsDataTypeClass) else type(dtype)
mapped = TYPE_MAP.get(cls)
if mapped is not None:
return mapped
mapped = CLOSE_MAP.get(cls)
if mapped is not None:
        _logger.warning(
            "Column '%s' has dtype=%s, which will be mapped to %s."
            " This is not an exact match, but it is close.",
            col_name,
            dtype,
            mapped,
        )
return mapped
if not isinstance(dtype, PolarsDataType):
return _handle_unknown_dtype(dtype=dtype, col_name=col_name, allow_unknown=allow_unknown)
if isinstance(dtype, (pl.Array, pl.List)):
# cannot check inner if not instantiated
if isclass(dtype):
if not allow_unknown:
_raise_unknown_type(dtype)
return Array("Unknown")
inner = (
"Unknown"
if dtype.inner is None
else infer_dtype(dtype.inner, f"{col_name}.[]", allow_unknown=allow_unknown)
)
return Array(inner)
if isinstance(dtype, pl.Struct):
# cannot check fields if not instantiated
if isclass(dtype):
if not allow_unknown:
_raise_unknown_type(dtype)
return Object([])
return Object(
[
Property(
name=field.name,
dtype=infer_dtype(
field.dtype, f"{col_name}.{field.name}", allow_unknown=allow_unknown
),
)
for field in dtype.fields
]
)
return _handle_unknown_dtype(dtype=dtype, col_name=col_name, allow_unknown=allow_unknown)
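
# Nested dtypes recurse into their inner types; a sketch (hypothetical values):
#
#     s = pl.Series("pts", [[{"x": 1.0}]], dtype=pl.List(pl.Struct({"x": pl.Float64})))
#     infer_dtype(s.dtype, s.name, allow_unknown=True)
#     # Array(Object([Property(name="x", dtype=DataType.double)]))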

def _handle_unknown_dtype(dtype: Any, col_name: str, *, allow_unknown: bool) -> str:
    if not allow_unknown:
        _raise_unknown_type(dtype)
    _logger.warning(
        "Column '%s' has dtype=%s, which cannot be mapped to any DataType",
        col_name,
        dtype,
    )
    return str(dtype)
def _raise_unknown_type(dtype: Any) -> None:
msg = f"Unknown type: {dtype!r}"
raise ValueError(msg)

class PolarsDataset(Dataset, PyFuncConvertibleDatasetMixin):
"""A polars DataFrame for use with MLflow Tracking."""
def __init__(
self,
df: pl.DataFrame,
source: DatasetSource,
targets: Optional[str] = None,
name: Optional[str] = None,
digest: Optional[str] = None,
predictions: Optional[str] = None,
) -> None:
"""
Args:
df: A polars DataFrame.
source: Source of the DataFrame.
targets: Name of the target column. Optional.
name: Name of the dataset. E.g. "wiki_train". If unspecified, a name is automatically
generated.
digest: Digest (hash, fingerprint) of the dataset. If unspecified, a digest is
automatically computed.
predictions: Name of the column containing model predictions, if the dataset contains
model predictions. Optional. If specified, this column must be present in ``df``.
"""
if targets is not None and targets not in df.columns:
raise MlflowException(
f"DataFrame does not contain specified targets column: '{targets}'",
INVALID_PARAMETER_VALUE,
)
if predictions is not None and predictions not in df.columns:
raise MlflowException(
f"DataFrame does not contain specified predictions column: '{predictions}'",
INVALID_PARAMETER_VALUE,
)
# _df needs to be set before super init, as it is used in _compute_digest
# see Dataset.__init__()
self._df = df
super().__init__(source=source, name=name, digest=digest)
self._targets = targets
self._predictions = predictions
    def _compute_digest(self) -> str:
        """Compute a digest for the dataset.

        Called if the user doesn't supply a digest when constructing the dataset.
        """
return hash_polars_df(self._df)

    class PolarsDatasetConfig(TypedDict):
name: str
digest: str
source: str
source_type: str
schema: str
profile: str

    def to_dict(self) -> PolarsDatasetConfig:
        """Create config dictionary for the dataset.

        Return a string dictionary containing the following fields: name, digest, source,
        source type, schema, and profile.
        """
schema = json.dumps({"mlflow_colspec": self.schema.to_dict()} if self.schema else None)
return {
"name": self.name,
"digest": self.digest,
"source": self.source.to_json(),
"source_type": self.source._get_source_type(),
"schema": schema,
"profile": json.dumps(self.profile),
}
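
    # The resulting config looks roughly like this (hypothetical values; the
    # "schema" and "profile" fields are JSON-encoded strings):
    #
    #     {
    #         "name": "wiki_train",
    #         "digest": "1234567890",
    #         "source": "...",
    #         "source_type": "code",
    #         "schema": '{"mlflow_colspec": [...]}',
    #         "profile": '{"num_rows": 3, "num_elements": 12}',
    #     }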
@property
def df(self) -> pl.DataFrame:
"""Underlying DataFrame."""
return self._df
@property
def source(self) -> DatasetSource:
"""Source of the dataset."""
return self._source
@property
def targets(self) -> Optional[str]:
"""Name of the target column.
May be ``None`` if no target column is available.
"""
return self._targets
@property
def predictions(self) -> Optional[str]:
"""Name of the predictions column.
May be ``None`` if no predictions column is available.
"""
return self._predictions

    class PolarsDatasetProfile(TypedDict):
num_rows: int
num_elements: int
@property
def profile(self) -> PolarsDatasetProfile:
"""Profile of the dataset."""
return {
"num_rows": self._df.height,
"num_elements": self._df.height * self._df.width,
}
@cached_property
def schema(self) -> Optional[Schema]:
"""Instance of :py:class:`mlflow.types.Schema` representing the tabular dataset.
May be ``None`` if the schema cannot be inferred from the dataset.
"""
try:
return infer_schema(self._df)
except Exception as e:
_logger.warning("Failed to infer schema for PolarsDataset. Exception: %s", e)
return None
def to_pyfunc(self) -> PyFuncInputsOutputs:
"""Convert dataset to a collection of pyfunc inputs and outputs for model evaluation."""
        if self._targets:
            # drop the single target column from the inputs; ``self._targets`` is
            # a column name, so it must not be star-unpacked into characters
            inputs = self._df.drop(self._targets)
            outputs = self._df.select(self._targets).to_series()
            return PyFuncInputsOutputs([inputs.to_pandas()], [outputs.to_pandas()])
        else:
            return PyFuncInputsOutputs([self._df.to_pandas()])
def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
"""Convert dataset to an EvaluationDataset for model evaluation."""
return EvaluationDataset(
data=self._df.to_pandas(),
targets=self._targets,
path=path,
feature_names=feature_names,
predictions=self._predictions,
)

def from_polars(
df: pl.DataFrame,
source: Union[str, DatasetSource, None] = None,
targets: Optional[str] = None,
name: Optional[str] = None,
digest: Optional[str] = None,
predictions: Optional[str] = None,
) -> PolarsDataset:
"""Construct a :py:class:`PolarsDataset <mlflow.data.polars_dataset.PolarsDataset>` instance.
Args:
df: A polars DataFrame.
source: Source from which the DataFrame was derived, e.g. a filesystem
path, an S3 URI, an HTTPS URL, a delta table name with version, or
spark table etc. ``source`` may be specified as a URI, a path-like string,
or an instance of
:py:class:`DatasetSource <mlflow.data.dataset_source.DatasetSource>`.
If unspecified, the source is assumed to be the code location
(e.g. notebook cell, script, etc.) where
:py:func:`from_polars <mlflow.data.from_polars>` is being called.
targets: An optional target column name for supervised training. This column
must be present in ``df``.
name: Name of the dataset. If unspecified, a name is generated.
digest: Dataset digest (hash). If unspecified, a digest is computed
automatically.
predictions: An optional predictions column name for model evaluation. This column
must be present in ``df``.
.. code-block:: python
:test:
:caption: Example
import mlflow
import polars as pl
x = pl.DataFrame(
[["tom", 10, 1, 1], ["nick", 15, 0, 1], ["julie", 14, 1, 1]],
schema=["Name", "Age", "Label", "ModelOutput"],
)
dataset = mlflow.data.from_polars(x, targets="Label", predictions="ModelOutput")
"""
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.dataset_source_registry import resolve_dataset_source
from mlflow.tracking.context import registry
if source is not None:
if isinstance(source, DatasetSource):
resolved_source = source
else:
resolved_source = resolve_dataset_source(source)
else:
context_tags = registry.resolve_tags()
resolved_source = CodeDatasetSource(tags=context_tags)
return PolarsDataset(
df=df,
source=resolved_source,
targets=targets,
name=name,
digest=digest,
predictions=predictions,
)
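
# A minimal sketch of logging the dataset to a run; ``mlflow.log_input`` is the
# standard entry point for dataset inputs (values hypothetical):
#
#     import mlflow
#     import polars as pl
#
#     df = pl.DataFrame({"Name": ["tom"], "Age": [10], "Label": [1]})
#     with mlflow.start_run():
#         mlflow.log_input(from_polars(df, targets="Label"), context="training")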