Source code for mlflow.data.polars_dataset

import json
import logging
from functools import cached_property
from inspect import isclass
from typing import Any, Final, Optional, TypedDict, Union

import polars as pl
from polars.datatypes.classes import DataType as PolarsDataType
from polars.datatypes.classes import DataTypeClass as PolarsDataTypeClass

from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.types.schema import Array, ColSpec, DataType, Object, Property, Schema

_logger = logging.getLogger(__name__)


def hash_polars_df(df: pl.DataFrame) -> str:
    # probably not the best way to hash, also see:
    # https://github.com/pola-rs/polars/issues/9743
    # https://stackoverflow.com/q/76678160
    return str(df.hash_rows().sum())
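
# Illustrative sketch (not part of the module): the digest is a pure function of
# the frame's row contents, so an identical copy hashes to the same value.
#
#   df = pl.DataFrame({"a": [1, 2, 3]})
#   assert hash_polars_df(df) == hash_polars_df(df.clone())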


ColSpecType = Union[DataType, Array, Object, str]
TYPE_MAP: Final[dict[PolarsDataTypeClass, DataType]] = {
    pl.Binary: DataType.binary,
    pl.Boolean: DataType.boolean,
    pl.Datetime: DataType.datetime,
    pl.Float32: DataType.float,
    pl.Float64: DataType.double,
    pl.Int8: DataType.integer,
    pl.Int16: DataType.integer,
    pl.Int32: DataType.integer,
    pl.Int64: DataType.long,
    pl.String: DataType.string,
    pl.Utf8: DataType.string,
}
CLOSE_MAP: Final[dict[PolarsDataTypeClass, DataType]] = {
    pl.Categorical: DataType.string,
    pl.Enum: DataType.string,
    pl.Date: DataType.datetime,
    pl.UInt8: DataType.integer,
    pl.UInt16: DataType.integer,
    pl.UInt32: DataType.long,
}
# Remaining types:
# pl.Decimal
# pl.UInt64
# pl.Duration
# pl.Time
# pl.Null
# pl.Object
# pl.Unknown
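
# Map lookups, illustrated (comment sketch, not part of the module):
#
#   TYPE_MAP[pl.Int64]      # DataType.long      (exact match)
#   CLOSE_MAP[pl.Date]      # DataType.datetime  (close match: dates widen)
#   TYPE_MAP.get(pl.Time)   # None -> falls through to _handle_unknown_dtype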


def infer_schema(df: pl.DataFrame) -> Schema:
    return Schema([infer_colspec(df[col]) for col in df.columns])
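
# For example (illustrative, not part of the module), a mixed frame infers one
# ColSpec per column:
#
#   df = pl.DataFrame({"x": [1.5], "y": ["a"], "tags": [[1, 2]]})
#   infer_schema(df)
#   # Schema: x -> double, y -> string, tags -> Array(long)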


def infer_colspec(col: pl.Series, *, allow_unknown: bool = True) -> ColSpec:
    return ColSpec(
        type=infer_dtype(col.dtype, col.name, allow_unknown=allow_unknown),
        name=col.name,
        required=col.count() > 0,
    )


def infer_dtype(
    dtype: Union[PolarsDataType, PolarsDataTypeClass], col_name: str, *, allow_unknown: bool
) -> ColSpecType:
    cls: PolarsDataTypeClass = dtype if isinstance(dtype, PolarsDataTypeClass) else type(dtype)
    mapped = TYPE_MAP.get(cls)
    if mapped is not None:
        return mapped

    mapped = CLOSE_MAP.get(cls)
    if mapped is not None:
        _logger.warning(
            "Column '%s' has dtype=%s, which will be mapped to %s. This is not an"
            " exact match, but it is close enough",
            col_name,
            dtype,
            mapped,
        )
        return mapped

    if not isinstance(dtype, PolarsDataType):
        return _handle_unknown_dtype(dtype=dtype, col_name=col_name, allow_unknown=allow_unknown)

    if isinstance(dtype, (pl.Array, pl.List)):
        # cannot check inner if not instantiated
        if isclass(dtype):
            if not allow_unknown:
                _raise_unknown_type(dtype)
            return Array("Unknown")

        inner = (
            "Unknown"
            if dtype.inner is None
            else infer_dtype(dtype.inner, f"{col_name}.[]", allow_unknown=allow_unknown)
        )
        return Array(inner)

    if isinstance(dtype, pl.Struct):
        # cannot check fields if not instantiated
        if isclass(dtype):
            if not allow_unknown:
                _raise_unknown_type(dtype)
            return Object([])

        return Object(
            [
                Property(
                    name=field.name,
                    dtype=infer_dtype(
                        field.dtype, f"{col_name}.{field.name}", allow_unknown=allow_unknown
                    ),
                )
                for field in dtype.fields
            ]
        )

    return _handle_unknown_dtype(dtype=dtype, col_name=col_name, allow_unknown=allow_unknown)
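
# Nested dtypes recurse (illustrative, not part of the module):
#
#   infer_dtype(pl.List(pl.Int32), "tags", allow_unknown=True)
#   # -> Array(DataType.integer)
#   infer_dtype(pl.Struct({"id": pl.Int64}), "meta", allow_unknown=True)
#   # -> Object([Property("id", DataType.long)])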


def _handle_unknown_dtype(dtype: Any, col_name: str, *, allow_unknown: bool) -> str:
    if not allow_unknown:
        _raise_unknown_type(dtype)

    _logger.warning(
        "Column '%s' has dtype=%s, which cannot be mapped to any DataType; falling"
        " back to its string representation",
        col_name,
        dtype,
    )
    return str(dtype)
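
# Fallback behavior, illustrated (not part of the module): dtypes outside both
# maps either become their string representation or raise, depending on
# ``allow_unknown``:
#
#   infer_dtype(pl.Time, "t", allow_unknown=True)   # -> str(pl.Time) fallback
#   infer_dtype(pl.Time, "t", allow_unknown=False)  # raises ValueError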


def _raise_unknown_type(dtype: Any) -> None:
    msg = f"Unknown type: {dtype!r}"
    raise ValueError(msg)


class PolarsDataset(Dataset, PyFuncConvertibleDatasetMixin):
    """A polars DataFrame for use with MLflow Tracking."""

    def __init__(
        self,
        df: pl.DataFrame,
        source: DatasetSource,
        targets: Optional[str] = None,
        name: Optional[str] = None,
        digest: Optional[str] = None,
        predictions: Optional[str] = None,
    ) -> None:
        """
        Args:
            df: A polars DataFrame.
            source: Source of the DataFrame.
            targets: Name of the target column. Optional.
            name: Name of the dataset. E.g. "wiki_train". If unspecified, a name is
                automatically generated.
            digest: Digest (hash, fingerprint) of the dataset. If unspecified, a digest
                is automatically computed.
            predictions: Name of the column containing model predictions, if the dataset
                contains model predictions. Optional. If specified, this column must be
                present in ``df``.
        """
        if targets is not None and targets not in df.columns:
            raise MlflowException(
                f"DataFrame does not contain specified targets column: '{targets}'",
                INVALID_PARAMETER_VALUE,
            )
        if predictions is not None and predictions not in df.columns:
            raise MlflowException(
                f"DataFrame does not contain specified predictions column: '{predictions}'",
                INVALID_PARAMETER_VALUE,
            )

        # _df needs to be set before super().__init__(), as it is used in
        # _compute_digest; see Dataset.__init__()
        self._df = df
        super().__init__(source=source, name=name, digest=digest)
        self._targets = targets
        self._predictions = predictions

    def _compute_digest(self) -> str:
        """Compute a digest for the dataset.

        Called if the user doesn't supply a digest when constructing the dataset.
        """
        return hash_polars_df(self._df)

    class PolarsDatasetConfig(TypedDict):
        name: str
        digest: str
        source: str
        source_type: str
        schema: str
        profile: str

    def to_dict(self) -> PolarsDatasetConfig:
        """Create config dictionary for the dataset.

        Returns a string dictionary containing the following fields: name, digest,
        source, source type, schema, and profile.
        """
        schema = json.dumps({"mlflow_colspec": self.schema.to_dict()} if self.schema else None)
        return {
            "name": self.name,
            "digest": self.digest,
            "source": self.source.to_json(),
            "source_type": self.source._get_source_type(),
            "schema": schema,
            "profile": json.dumps(self.profile),
        }

    @property
    def df(self) -> pl.DataFrame:
        """Underlying DataFrame."""
        return self._df

    @property
    def source(self) -> DatasetSource:
        """Source of the dataset."""
        return self._source

    @property
    def targets(self) -> Optional[str]:
        """Name of the target column.

        May be ``None`` if no target column is available.
        """
        return self._targets

    @property
    def predictions(self) -> Optional[str]:
        """Name of the predictions column.

        May be ``None`` if no predictions column is available.
        """
        return self._predictions

    class PolarsDatasetProfile(TypedDict):
        num_rows: int
        num_elements: int

    @property
    def profile(self) -> PolarsDatasetProfile:
        """Profile of the dataset."""
        return {
            "num_rows": self._df.height,
            "num_elements": self._df.height * self._df.width,
        }

    @cached_property
    def schema(self) -> Optional[Schema]:
        """Instance of :py:class:`mlflow.types.Schema` representing the tabular dataset.

        May be ``None`` if the schema cannot be inferred from the dataset.
        """
        try:
            return infer_schema(self._df)
        except Exception as e:
            _logger.warning("Failed to infer schema for PolarsDataset. Exception: %s", e)
            return None

    def to_pyfunc(self) -> PyFuncInputsOutputs:
        """Convert dataset to a collection of pyfunc inputs and outputs for model evaluation."""
        if self._targets:
            # ``targets`` is a single column name; pass it unsplatted so only that
            # column is dropped (splatting the string would drop one column per
            # character)
            inputs = self._df.drop(self._targets)
            outputs = self._df.select(self._targets).to_series()
            return PyFuncInputsOutputs([inputs.to_pandas()], [outputs.to_pandas()])
        else:
            return PyFuncInputsOutputs([self._df.to_pandas()])

    def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
        """Convert dataset to an EvaluationDataset for model evaluation."""
        return EvaluationDataset(
            data=self._df.to_pandas(),
            targets=self._targets,
            path=path,
            feature_names=feature_names,
            predictions=self._predictions,
        )
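
# For instance (illustrative, not part of the module), with ``targets`` set the
# pyfunc conversion splits features from labels:
#
#   ds = PolarsDataset(
#       df=pl.DataFrame({"x": [1, 2], "y": [0, 1]}),
#       source=some_source,  # any DatasetSource; hypothetical placeholder
#       targets="y",
#   )
#   ds.profile      # {"num_rows": 2, "num_elements": 4}
#   ds.to_pyfunc()  # inputs: pandas frame with "x"; outputs: pandas Series "y"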


def from_polars(
    df: pl.DataFrame,
    source: Union[str, DatasetSource, None] = None,
    targets: Optional[str] = None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
    predictions: Optional[str] = None,
) -> PolarsDataset:
    """Construct a :py:class:`PolarsDataset <mlflow.data.polars_dataset.PolarsDataset>` instance.

    Args:
        df: A polars DataFrame.
        source: Source from which the DataFrame was derived, e.g. a filesystem path, an S3
            URI, an HTTPS URL, a delta table name with version, a spark table, etc.
            ``source`` may be specified as a URI, a path-like string, or an instance of
            :py:class:`DatasetSource <mlflow.data.dataset_source.DatasetSource>`. If
            unspecified, the source is assumed to be the code location (e.g. notebook
            cell, script, etc.) where
            :py:func:`from_polars <mlflow.data.from_polars>` is being called.
        targets: An optional target column name for supervised training. This column must
            be present in ``df``.
        name: Name of the dataset. If unspecified, a name is generated.
        digest: Dataset digest (hash). If unspecified, a digest is computed automatically.
        predictions: An optional predictions column name for model evaluation. This column
            must be present in ``df``.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow
        import polars as pl

        x = pl.DataFrame(
            [["tom", 10, 1, 1], ["nick", 15, 0, 1], ["julie", 14, 1, 1]],
            schema=["Name", "Age", "Label", "ModelOutput"],
        )
        dataset = mlflow.data.from_polars(x, targets="Label", predictions="ModelOutput")
    """
    from mlflow.data.code_dataset_source import CodeDatasetSource
    from mlflow.data.dataset_source_registry import resolve_dataset_source
    from mlflow.tracking.context import registry

    if source is not None:
        if isinstance(source, DatasetSource):
            resolved_source = source
        else:
            resolved_source = resolve_dataset_source(source)
    else:
        context_tags = registry.resolve_tags()
        resolved_source = CodeDatasetSource(tags=context_tags)
    return PolarsDataset(
        df=df,
        source=resolved_source,
        targets=targets,
        name=name,
        digest=digest,
        predictions=predictions,
    )
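
# Illustrative end-to-end usage (sketch; assumes a configured MLflow tracking
# backend, and is not part of the module):
#
#   import mlflow
#   import polars as pl
#
#   df = pl.DataFrame({"x": [1, 2, 3], "y": [0, 1, 0]})
#   dataset = mlflow.data.from_polars(df, targets="y", name="toy")
#   with mlflow.start_run():
#       mlflow.log_input(dataset, context="training")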