import json
import logging
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from packaging.version import Version

from import Dataset
from import DatasetSource
from import DeltaDatasetSource
from import get_normalized_md5_digest
from import EvaluationDataset
from import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from import SparkDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE
from mlflow.types import Schema
from mlflow.types.utils import _infer_schema

    import pyspark

_logger = logging.getLogger(__name__)

class SparkDataset(Dataset, PyFuncConvertibleDatasetMixin):
    """
    Represents a Spark dataset (e.g. data derived from a Spark Table / file directory or Delta
    Table) for use with MLflow Tracking.
    """

    # Maximum number of rows collected to the driver when converting the Spark DataFrame to
    # pandas (see ``to_pyfunc`` and ``to_evaluation_dataset``), to avoid overusing driver memory.
    _PANDAS_ROW_LIMIT = 10000

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",
        source: DatasetSource,
        targets: Optional[str] = None,
        name: Optional[str] = None,
        digest: Optional[str] = None,
        predictions: Optional[str] = None,
    ):
        """
        Args:
            df: The Spark DataFrame backing this dataset.
            source: The source from which the DataFrame was derived.
            targets: Optional name of the column in ``df`` containing targets (labels).
            name: Optional name of the dataset.
            digest: Optional digest of the dataset. If omitted, ``_compute_digest`` is
                expected to be used to compute one.
            predictions: Optional name of the column in ``df`` containing model predictions.

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``targets`` or
                ``predictions`` is specified but is not a column of ``df``.
        """
        if targets is not None and targets not in df.columns:
            raise MlflowException(
                "The specified Spark dataset does not contain the specified targets column"
                f" '{targets}'.",
                INVALID_PARAMETER_VALUE,
            )
        if predictions is not None and predictions not in df.columns:
            raise MlflowException(
                "The specified Spark dataset does not contain the specified predictions column"
                f" '{predictions}'.",
                INVALID_PARAMETER_VALUE,
            )
        self._df = df
        self._targets = targets
        self._predictions = predictions
        # State is assigned before calling the base constructor, presumably because it may call
        # ``_compute_digest`` (which reads ``self._df``) when no digest is supplied.
        super().__init__(source=source, name=name, digest=digest)

    def _compute_digest(self) -> str:
        """
        Computes a digest for the dataset. Called if the user doesn't supply
        a digest when constructing the dataset.
        """
        # Retrieve a semantic hash of the DataFrame's logical plan, which is much more efficient
        # and deterministic than hashing DataFrame records
        import numpy as np
        import pyspark

        # Spark 3.1.0+ has a semanticHash() method on DataFrame
        if Version(pyspark.__version__) >= Version("3.1.0"):
            semantic_hash = self._df.semanticHash()
        else:
            # Older Spark: reach into the JVM query execution for the analyzed plan's hash.
            semantic_hash = self._df._jdf.queryExecution().analyzed().semanticHash()
        return get_normalized_md5_digest([np.int64(semantic_hash)])

    def to_dict(self) -> Dict[str, str]:
        """Create config dictionary for the dataset.

        Returns a string dictionary containing the following fields: name, digest, source, source
        type, schema, and profile.
        """
        schema = json.dumps({"mlflow_colspec": self.schema.to_dict()}) if self.schema else None
        config = super().to_dict()
        config.update(
            {
                "schema": schema,
                "profile": json.dumps(self.profile),
            }
        )
        return config

    @property
    def df(self) -> "pyspark.sql.DataFrame":
        """The Spark DataFrame instance.

        Returns:
            The Spark DataFrame instance.

        """
        return self._df

    @property
    def targets(self) -> Optional[str]:
        """The name of the Spark DataFrame column containing targets (labels) for supervised
        learning.

        Returns:
            The string name of the Spark DataFrame column containing targets.
        """
        return self._targets

    @property
    def predictions(self) -> Optional[str]:
        """
        The name of the predictions column. May be ``None`` if no predictions column
        was specified when the dataset was created.
        """
        return self._predictions

    @property
    def source(self) -> Union[SparkDatasetSource, DeltaDatasetSource]:
        """
        Spark dataset source information.

        Returns:
            An instance of :py:class:`SparkDatasetSource` or :py:class:`DeltaDatasetSource`.
        """
        return self._source

    @property
    def profile(self) -> Optional[Any]:
        """
        A profile of the dataset. May be None if no profile is available.

        The profile currently contains only ``approx_count``: an approximate row count, or the
        string ``"unknown"`` when no positive estimate could be obtained within the timeout.
        """
        try:
            from pyspark.rdd import BoundedFloat

            # Use Spark RDD countApprox to get approximate count since count() may be expensive.
            # Note that we call the Scala RDD API because the PySpark API does not respect the
            # specified timeout. This mirrors PySpark's own ``RDD.countApprox`` implementation
            # (pyspark/rdd.py) and is confirmed to work in all Spark 3.x versions.
            py_rdd = self.df.rdd
            drdd = py_rdd.mapPartitions(lambda it: [float(sum(1 for i in it))])
            jrdd = drdd.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
            jdrdd = drdd.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
            timeout_millis = 5000
            confidence = 0.9

            approx_count_operation = jdrdd.sumApprox(timeout_millis, confidence)
            approx_count_result = approx_count_operation.initialValue()
            approx_count_float = BoundedFloat(
                mean=approx_count_result.mean(),
                confidence=approx_count_result.confidence(),
                low=approx_count_result.low(),
                high=approx_count_result.high(),
            )
            approx_count = int(approx_count_float)
            if approx_count <= 0:
                # An approximate count of zero likely indicates that the count timed
                # out before an estimate could be made. In this case, we use the value
                # "unknown" so that users don't think the dataset is empty
                approx_count = "unknown"

            return {
                "approx_count": approx_count,
            }
        except Exception as e:
            # Profiling is best-effort: log and report "no profile" instead of failing.
            _logger.warning(
                "Encountered an unexpected exception while computing Spark dataset profile."
                " Exception: %s",
                e,
            )
            return None

    @cached_property
    def schema(self) -> Optional[Schema]:
        """
        The MLflow ColSpec schema of the Spark dataset, or ``None`` if it could not be inferred.
        """
        try:
            return _infer_schema(self._df)
        except Exception as e:
            _logger.warning("Failed to infer schema for Spark dataset. Exception: %s", e)
            return None

    def to_pyfunc(self) -> PyFuncInputsOutputs:
        """
        Converts the Spark DataFrame to pandas and splits the resulting
        :py:class:`pandas.DataFrame` into: 1. a :py:class:`pandas.DataFrame` of features and
        2. a :py:class:`pandas.Series` of targets.

        To avoid overuse of driver memory, only the first 10,000 DataFrame rows are selected.

        Raises:
            MlflowException: With ``INTERNAL_ERROR`` if the targets column is missing from the
                pandas representation.
        """
        df = self._df.limit(self._PANDAS_ROW_LIMIT).toPandas()

        if self._targets is None:
            return PyFuncInputsOutputs(inputs=df, outputs=None)

        if self._targets not in df.columns:
            raise MlflowException(
                "Failed to convert Spark dataset to pyfunc inputs and outputs because"
                " the pandas representation of the Spark dataset does not contain the"
                f" specified targets column '{self._targets}'.",
                # This is an internal error because the presence of the targets column in the
                # Spark dataset was already validated at construction time (see __init__)
                INTERNAL_ERROR,
            )
        inputs = df.drop(columns=self._targets)
        outputs = df[self._targets]
        return PyFuncInputsOutputs(inputs=inputs, outputs=outputs)

    def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
        """
        Converts the dataset to an EvaluationDataset for model evaluation. Required
        for use with mlflow.evaluate().

        Only the first 10,000 DataFrame rows are collected to the driver.
        """
        return EvaluationDataset(
            data=self._df.limit(self._PANDAS_ROW_LIMIT).toPandas(),
            targets=self._targets,
            path=path,
            feature_names=feature_names,
            predictions=self._predictions,
        )
def load_delta(
    path: Optional[str] = None,
    table_name: Optional[str] = None,
    version: Optional[str] = None,
    targets: Optional[str] = None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
    predictions: Optional[str] = None,
) -> SparkDataset:
    """
    Loads a :py:class:`SparkDataset` from a Delta table for use with MLflow Tracking.

    Args:
        path: The path to the Delta table. Either ``path`` or ``table_name`` must be specified.
        table_name: The name of the Delta table. Either ``path`` or ``table_name`` must be
            specified.
        version: The Delta table version. If not specified, the version will be inferred.
        targets: Optional. The name of the Delta table column containing targets (labels) for
            supervised learning.
        name: The name of the dataset. E.g. "wiki_train". If unspecified, a name is
            automatically generated (``"<table_name>@v<version>"`` when loading by table name).
        digest: The digest (hash, fingerprint) of the dataset. If unspecified, a digest is
            automatically computed.
        predictions: Optional. The name of the column containing model predictions, if the
            dataset contains model predictions. If specified, this column must be present in
            the loaded table.

    Returns:
        An instance of :py:class:`SparkDataset`.

    Raises:
        MlflowException: If not exactly one of ``path`` or ``table_name`` is specified, or if
            ``targets`` / ``predictions`` is not a column of the loaded table.
    """
    from mlflow.data.spark_delta_utils import (
        _try_get_delta_table_latest_version_from_path,
        _try_get_delta_table_latest_version_from_table_name,
    )

    if (path, table_name).count(None) != 1:
        raise MlflowException(
            "Must specify exactly one of `table_name` or `path`.",
            INVALID_PARAMETER_VALUE,
        )

    # Pin the latest version when none was requested so the dataset source is reproducible.
    if version is None:
        if path is not None:
            version = _try_get_delta_table_latest_version_from_path(path)
        else:
            version = _try_get_delta_table_latest_version_from_table_name(table_name)

    if name is None and table_name is not None:
        name = table_name + (f"@v{version}" if version is not None else "")

    source = DeltaDatasetSource(path=path, delta_table_name=table_name, delta_table_version=version)
    df = source.load()

    return SparkDataset(
        df=df,
        source=source,
        targets=targets,
        name=name,
        digest=digest,
        predictions=predictions,
    )
def from_spark(
    df: "pyspark.sql.DataFrame",
    path: Optional[str] = None,
    table_name: Optional[str] = None,
    version: Optional[str] = None,
    sql: Optional[str] = None,
    targets: Optional[str] = None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
    predictions: Optional[str] = None,
) -> SparkDataset:
    """
    Given a Spark DataFrame, constructs a :py:class:`SparkDataset` object for use with
    MLflow Tracking.

    Args:
        df: The Spark DataFrame from which to construct a SparkDataset.
        path: The path of the Spark or Delta source that the DataFrame originally came from.
            Note that the path does not have to match the DataFrame exactly, since the
            DataFrame may have been modified by Spark operations. This is used to reload the
            dataset upon request via ``SparkDataset.source.load()``. If none of ``path``,
            ``table_name``, or ``sql`` are specified, a CodeDatasetSource is used, which will
            source information from the run context.
        table_name: The name of the Spark or Delta table that the DataFrame originally came
            from. Note that the table does not have to match the DataFrame exactly, since the
            DataFrame may have been modified by Spark operations. This is used to reload the
            dataset upon request via ``SparkDataset.source.load()``. If none of ``path``,
            ``table_name``, or ``sql`` are specified, a CodeDatasetSource is used, which will
            source information from the run context.
        version: If the DataFrame originally came from a Delta table, specifies the version
            of the Delta table. This is used to reload the dataset upon request via
            ``SparkDataset.source.load()``. If omitted for a Delta source, the latest version
            is looked up. ``version`` cannot be specified if ``sql`` is specified.
        sql: The Spark SQL statement that was originally used to construct the DataFrame.
            Note that the Spark SQL statement does not have to match the DataFrame exactly,
            since the DataFrame may have been modified by Spark operations. This is used to
            reload the dataset upon request via ``SparkDataset.source.load()``. If none of
            ``path``, ``table_name``, or ``sql`` are specified, a CodeDatasetSource is used,
            which will source information from the run context.
        targets: Optional. The name of the DataFrame column containing targets (labels) for
            supervised learning.
        name: The name of the dataset. E.g. "wiki_train". If unspecified, a name is
            automatically generated.
        digest: The digest (hash, fingerprint) of the dataset. If unspecified, a digest is
            automatically computed.
        predictions: Optional. The name of the column containing model predictions,
            if the dataset contains model predictions. If specified, this column
            must be present in the dataframe (``df``).

    Returns:
        An instance of :py:class:`SparkDataset`.

    Raises:
        MlflowException: If more than one of ``path``, ``table_name``, ``sql`` is given; if
            ``version`` is given together with ``sql``; or if ``version`` is given for a
            ``path`` / ``table_name`` that is not a Delta table.
    """
    from mlflow.data.code_dataset_source import CodeDatasetSource
    from mlflow.data.spark_delta_utils import (
        _is_delta_table,
        _is_delta_table_path,
        _try_get_delta_table_latest_version_from_path,
        _try_get_delta_table_latest_version_from_table_name,
    )
    from mlflow.tracking.context import registry

    if (path, table_name, sql).count(None) < 2:
        raise MlflowException(
            "Must specify at most one of `path`, `table_name`, or `sql`.",
            INVALID_PARAMETER_VALUE,
        )

    if (sql, version).count(None) == 0:
        raise MlflowException(
            "`version` may not be specified when `sql` is specified. `version` may only be"
            " specified when `table_name` or `path` is specified.",
            INVALID_PARAMETER_VALUE,
        )

    if sql is not None:
        source = SparkDatasetSource(sql=sql)
    elif path is not None:
        if _is_delta_table_path(path):
            # Compare against None explicitly (not `version or ...`) so that a falsy
            # caller-supplied version such as the integer 0 is honored instead of being
            # silently replaced by the latest version.
            if version is None:
                version = _try_get_delta_table_latest_version_from_path(path)
            source = DeltaDatasetSource(path=path, delta_table_version=version)
        elif version is None:
            source = SparkDatasetSource(path=path)
        else:
            raise MlflowException(
                f"Version '{version}' was specified, but the path '{path}' does not refer"
                " to a Delta table.",
                INVALID_PARAMETER_VALUE,
            )
    elif table_name is not None:
        if _is_delta_table(table_name):
            # See the note above: only look up the latest version when none was given.
            if version is None:
                version = _try_get_delta_table_latest_version_from_table_name(table_name)
            source = DeltaDatasetSource(
                delta_table_name=table_name,
                delta_table_version=version,
            )
        elif version is None:
            source = SparkDatasetSource(table_name=table_name)
        else:
            raise MlflowException(
                f"Version '{version}' was specified, but could not find a Delta table with name"
                f" '{table_name}'.",
                INVALID_PARAMETER_VALUE,
            )
    else:
        # No explicit origin: record where the DataFrame was produced using run-context tags.
        context_tags = registry.resolve_tags()
        source = CodeDatasetSource(tags=context_tags)

    return SparkDataset(
        df=df,
        source=source,
        targets=targets,
        name=name,
        digest=digest,
        predictions=predictions,
    )