Source code for mlflow.data.delta_dataset_source

import logging
from typing import Any, Dict, Optional

from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_managed_catalog_messages_pb2 import (
    GetTable,
    GetTableResponse,
)
from mlflow.protos.databricks_managed_catalog_service_pb2 import DatabricksUnityCatalogService
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils._spark_utils import _get_active_spark_session
from mlflow.utils._unity_catalog_utils import get_full_name_from_sc
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils.rest_utils import (
    _REST_API_PATH_PREFIX,
    call_endpoint,
    extract_api_info_for_service,
)

DATABRICKS_HIVE_METASTORE_NAME = "hive_metastore"
# These two catalog names both point to the workspace-local default HMS (Hive metastore).
DATABRICKS_LOCAL_METASTORE_NAMES = [DATABRICKS_HIVE_METASTORE_NAME, "spark_catalog"]
# The "samples" catalog is managed by Databricks for hosting public datasets such as
# the NYC taxi dataset; it is neither a Unity Catalog catalog nor a local metastore catalog.
DATABRICKS_SAMPLES_CATALOG_NAME = "samples"

_logger = logging.getLogger(__name__)


class DeltaDatasetSource(DatasetSource):
    """
    Represents the source of a dataset stored in a Delta table.
    """

    def __init__(
        self,
        path: Optional[str] = None,
        delta_table_name: Optional[str] = None,
        delta_table_version: Optional[int] = None,
        delta_table_id: Optional[str] = None,
    ):
        if (path, delta_table_name).count(None) != 1:
            raise MlflowException(
                'Must specify exactly one of "path" or "delta_table_name"',
                INVALID_PARAMETER_VALUE,
            )
        self._path = path
        if delta_table_name is not None:
            self._delta_table_name = get_full_name_from_sc(
                delta_table_name, _get_active_spark_session()
            )
        else:
            self._delta_table_name = delta_table_name
        self._delta_table_version = delta_table_version
        self._delta_table_id = delta_table_id

    @staticmethod
    def _get_source_type() -> str:
        return "delta_table"
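    # Illustrative usage sketch (not part of the original module): exactly one of
    # "path" or "delta_table_name" must be supplied; passing both or neither raises
    # MlflowException(INVALID_PARAMETER_VALUE). Table names are resolved to their
    # fully qualified form via the active Spark session. The path and table name
    # below are hypothetical:
    #
    #   DeltaDatasetSource(path="dbfs:/delta/events")          # path-based source
    #   DeltaDatasetSource(delta_table_name="main.db.events",  # table-based source,
    #                      delta_table_version=3)              # pinned to a version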
    def load(self, **kwargs):
        """
        Loads the dataset source as a Spark DataFrame.

        Returns:
            An instance of ``pyspark.sql.DataFrame``.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()

        spark_read_op = spark.read.format("delta")
        if self._delta_table_version is not None:
            spark_read_op = spark_read_op.option("versionAsOf", self._delta_table_version)
        if self._path:
            return spark_read_op.load(self._path)
        else:
            return spark_read_op.table(self._delta_table_name)
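    # Usage sketch for load() (assumes a Spark session with the Delta Lake
    # connector available; the path is hypothetical):
    #
    #   df = DeltaDatasetSource(path="dbfs:/delta/events", delta_table_version=2).load()
    #
    # which is equivalent to:
    #
    #   spark.read.format("delta").option("versionAsOf", 2).load("dbfs:/delta/events")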
    @property
    def path(self) -> Optional[str]:
        return self._path

    @property
    def delta_table_name(self) -> Optional[str]:
        return self._delta_table_name

    @property
    def delta_table_id(self) -> Optional[str]:
        return self._delta_table_id

    @property
    def delta_table_version(self) -> Optional[int]:
        return self._delta_table_version

    @staticmethod
    def _can_resolve(raw_source: Any):
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "DeltaDatasetSource":
        raise NotImplementedError

    # Check whether the table is in the Databricks Unity Catalog.
    def _is_databricks_uc_table(self):
        if self._delta_table_name is not None:
            catalog_name = self._delta_table_name.split(".", 1)[0]
            return (
                catalog_name not in DATABRICKS_LOCAL_METASTORE_NAMES
                and catalog_name != DATABRICKS_SAMPLES_CATALOG_NAME
            )
        else:
            return False

    def _lookup_table_id(self, table_name):
        try:
            req_body = message_to_json(GetTable(full_name_arg=table_name))
            _METHOD_TO_INFO = extract_api_info_for_service(
                DatabricksUnityCatalogService, _REST_API_PATH_PREFIX
            )
            db_creds = get_databricks_host_creds()
            endpoint, method = _METHOD_TO_INFO[GetTable]
            # We need to replace the full_name_arg in the endpoint definition with
            # the actual table name for the REST API to work.
            final_endpoint = endpoint.replace("{full_name_arg}", table_name)
            resp = call_endpoint(
                host_creds=db_creds,
                endpoint=final_endpoint,
                method=method,
                json_body=req_body,
                response_proto=GetTableResponse,
            )
            return resp.table_id
        except Exception:
            return None
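    # The catalog-name heuristic used by _is_databricks_uc_table, illustrated
    # with hypothetical table names:
    #
    #   "hive_metastore.db.t"   -> False (workspace-local metastore)
    #   "spark_catalog.db.t"    -> False (workspace-local metastore)
    #   "samples.nyctaxi.trips" -> False (Databricks-managed samples catalog)
    #   "main.db.t"             -> True  (assumed to be a Unity Catalog table)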
    def to_dict(self) -> Dict[Any, Any]:
        info = {}
        if self._path:
            info["path"] = self._path
        if self._delta_table_name:
            info["delta_table_name"] = self._delta_table_name
        if self._delta_table_version:
            info["delta_table_version"] = self._delta_table_version
        if self._is_databricks_uc_table():
            info["is_databricks_uc_table"] = True
            if self._delta_table_id:
                info["delta_table_id"] = self._delta_table_id
            else:
                info["delta_table_id"] = self._lookup_table_id(self._delta_table_name)
        return info
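    # Example to_dict() output for a Unity Catalog table source (all values
    # hypothetical; "delta_table_id" may be None if the GetTable lookup fails):
    #
    #   {
    #       "delta_table_name": "main.db.events",
    #       "delta_table_version": 3,
    #       "is_databricks_uc_table": True,
    #       "delta_table_id": "abc-123",
    #   }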
    @classmethod
    def from_dict(cls, source_dict: Dict[Any, Any]) -> "DeltaDatasetSource":
        return cls(
            path=source_dict.get("path"),
            delta_table_name=source_dict.get("delta_table_name"),
            delta_table_version=source_dict.get("delta_table_version"),
            delta_table_id=source_dict.get("delta_table_id"),
        )
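# Round-trip sketch: from_dict(source.to_dict()) reconstructs an equivalent
# source. Note that "is_databricks_uc_table" is derived from the table name
# rather than stored, and the constructor re-resolves table names via the
# active Spark session:
#
#   restored = DeltaDatasetSource.from_dict(source.to_dict())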