import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List
import yaml
DEFAULT_API_VERSION = "1"
[docs]class ResourceType(Enum):
"""
Enum to define the different types of resources needed to serve a model.
"""
VECTOR_SEARCH_INDEX = "vector_search_index"
SERVING_ENDPOINT = "serving_endpoint"
SQL_WAREHOUSE = "sql_warehouse"
UC_FUNCTION = "uc_function"
[docs]@dataclass
class Resource(ABC):
"""
Base class for defining the resources needed to serve a model.
Args:
type (ResourceType): The resource type.
target_uri (str): The target URI where these resources are hosted.
"""
type: ResourceType
target_uri: str
[docs] @abstractmethod
def to_dict(self):
"""
Convert the resource to a dictionary.
Subclasses must implement this method.
"""
[docs] @classmethod
def from_dict(cls, data):
"""
Convert the dictionary to a Resource.
Subclasses must implement this method.
"""
@dataclass
class DatabricksResource(Resource, ABC):
"""
Base class to define all the Databricks resources to serve a model.
"""
target_uri: str = "databricks"
@dataclass
class DatabricksServingEndpoint(DatabricksResource):
"""
Define Databricks LLM endpoint resource to serve a model.
Args:
endpoint_name (str): The name of all the databricks endpoints used by the model.
"""
type: ResourceType = ResourceType.SERVING_ENDPOINT
endpoint_name: str = None
def to_dict(self):
return {self.type.value: [{"name": self.endpoint_name}]} if self.endpoint_name else {}
@classmethod
def from_dict(cls, data: Dict[str, str]):
return cls(endpoint_name=data["name"])
@dataclass
class DatabricksVectorSearchIndex(DatabricksResource):
"""
Define Databricks vector search index name resource to serve a model.
Args:
index_name (str): The name of all the databricks vector search index names
used by the model.
"""
type: ResourceType = ResourceType.VECTOR_SEARCH_INDEX
index_name: str = None
def to_dict(self):
return {self.type.value: [{"name": self.index_name}]} if self.index_name else {}
@classmethod
def from_dict(cls, data: Dict[str, str]):
return cls(index_name=data["name"])
@dataclass
class DatabricksSQLWarehouse(DatabricksResource):
"""
Define Databricks sql warehouse resource to serve a model.
Args:
warehouse_id (str): The id of the sql warehouse used by the model
"""
type: ResourceType = ResourceType.SQL_WAREHOUSE
warehouse_id: str = None
def to_dict(self):
return {self.type.value: [{"name": self.warehouse_id}]} if self.warehouse_id else {}
@classmethod
def from_dict(cls, data: Dict[str, str]):
return cls(warehouse_id=data["name"])
@dataclass
class DatabricksUCFunction(DatabricksResource):
"""
Define Databricks UC Function to serve a model.
Args:
function_name (str): The name of the function used by the model
"""
type: ResourceType = ResourceType.UC_FUNCTION
function_name: str = None
def to_dict(self):
return {self.type.value: [{"name": self.function_name}]} if self.function_name else {}
@classmethod
def from_dict(cls, data: Dict[str, str]):
return cls(function_name=data["name"])
def _get_resource_class_by_type(target_uri: str, resource_type: ResourceType):
resource_classes = {
"databricks": {
ResourceType.SERVING_ENDPOINT.value: DatabricksServingEndpoint,
ResourceType.VECTOR_SEARCH_INDEX.value: DatabricksVectorSearchIndex,
ResourceType.SQL_WAREHOUSE.value: DatabricksSQLWarehouse,
ResourceType.UC_FUNCTION.value: DatabricksUCFunction,
}
}
resource = resource_classes.get(target_uri)
if resource is None:
raise ValueError(f"Unsupported target URI: {target_uri}")
return resource.get(resource_type)
class _ResourceBuilder:
"""
Private builder class to build the resources dictionary.
"""
@staticmethod
def from_resources(
resources: List[Resource], api_version: str = DEFAULT_API_VERSION
) -> Dict[str, Dict[ResourceType, List[Dict]]]:
resource_dict = {}
for resource in resources:
resource_data = resource.to_dict()
for resource_type, values in resource_data.items():
target_dict = resource_dict.setdefault(resource.target_uri, {})
target_list = target_dict.setdefault(resource_type, [])
target_list.extend(values)
resource_dict["api_version"] = api_version
return resource_dict
@staticmethod
def from_dict(data) -> Dict[str, Dict[ResourceType, List[Dict]]]:
resources = []
api_version = data.pop("api_version")
if api_version == "1":
for target_uri, config in data.items():
for resource_type, values in config.items():
resource_class = _get_resource_class_by_type(target_uri, resource_type)
if resource_class:
resources.extend(resource_class.from_dict(value) for value in values)
else:
raise ValueError(f"Unsupported resource type: {resource_type}")
else:
raise ValueError(f"Unsupported API version: {api_version}")
return _ResourceBuilder.from_resources(resources, api_version)
@staticmethod
def from_yaml_file(path: str) -> Dict[str, Dict[ResourceType, List[Dict]]]:
if not os.path.exists(path):
raise OSError(f"No such file or directory: '{path}'")
path = os.path.abspath(path)
with open(path) as file:
data = yaml.safe_load(file)
return _ResourceBuilder.from_dict(data)