# Source code for mlflow.models.dependencies_schemas

from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional

from mlflow.utils.annotations import experimental

# Keys under which set_retriever_schema() stores the user-defined retriever
# schema in this module's globals(); _get_retriever_schema() reads them back
# and _clear_retriever_schema() removes them.
_RETRIEVER_PRIMARY_KEY = "__retriever_primary_key__"
_RETRIEVER_TEXT_COLUMN = "__retriever_text_column__"
_RETRIEVER_DOC_URI = "__retriever_doc_uri__"
_RETRIEVER_OTHER_COLUMNS = "__retriever_other_columns__"
_RETRIEVER_NAME = "__retriever_name__"


class DependenciesSchemasType(Enum):
    """
    Enum of the kinds of dependencies schemas a model can declare.

    Each member's value is the string key used for that kind of schema in the
    dictionaries produced by the schemas' ``to_dict()`` methods.
    """

    # Schemas describing the fields of documents returned by a retriever /
    # vector index (see set_retriever_schema()).
    RETRIEVERS = "retrievers"


@experimental
def set_retriever_schema(
    *,
    primary_key: str,
    text_column: str,
    doc_uri: Optional[str] = None,
    other_columns: Optional[List[str]] = None,
    name: Optional[str] = "retriever",
):
    """
    Register the schema of the documents returned by your retriever.

    Call this after defining your vector store in a Python file or notebook. When MLflow
    retrieves documents during model inference, it then knows which field of each
    retrieved document holds the document text, which holds the document URI, and so on.

    Args:
        primary_key: The primary key of the retriever or vector index.
        text_column: The name of the text column to use for the embeddings.
        doc_uri: The name of the column that contains the document URI.
        other_columns: A list of other columns that are part of the vector index and need
            to be retrieved during trace logging. ``None`` is stored as an empty list.
        name: The name of the retriever or vector store.

    .. code-block:: Python
        :caption: Example

        from mlflow.models import set_retriever_schema

        set_retriever_schema(
            primary_key="chunk_id",
            text_column="chunk_text",
            doc_uri="doc_uri",
            other_columns=["title"],
        )
    """
    # The schema lives in this module's globals until _get_retriever_schema() picks it up.
    module_state = globals()
    module_state.update(
        {
            _RETRIEVER_PRIMARY_KEY: primary_key,
            _RETRIEVER_TEXT_COLUMN: text_column,
            _RETRIEVER_DOC_URI: doc_uri,
            _RETRIEVER_OTHER_COLUMNS: other_columns if other_columns else [],
            _RETRIEVER_NAME: name,
        }
    )
def _get_retriever_schema():
    """
    Get the retriever schema registered via ``set_retriever_schema()``.

    Returns:
        List[RetrieverSchema]: A one-element list with the registered schema. Returns an
        empty list if the primary key or the text column is unset or empty.
    """
    module_globals = globals()
    primary_key = module_globals.get(_RETRIEVER_PRIMARY_KEY)
    text_column = module_globals.get(_RETRIEVER_TEXT_COLUMN)
    # Both fields are required to interpret retrieved documents.
    if not primary_key or not text_column:
        return []
    return [
        RetrieverSchema(
            name=module_globals.get(_RETRIEVER_NAME),
            primary_key=primary_key,
            text_column=text_column,
            doc_uri=module_globals.get(_RETRIEVER_DOC_URI),
            other_columns=module_globals.get(_RETRIEVER_OTHER_COLUMNS),
        )
    ]


def _clear_retriever_schema():
    """
    Remove the retriever schema registered via ``set_retriever_schema()``, if any.
    """
    module_globals = globals()
    for key in (
        _RETRIEVER_PRIMARY_KEY,
        _RETRIEVER_TEXT_COLUMN,
        _RETRIEVER_DOC_URI,
        _RETRIEVER_OTHER_COLUMNS,
        _RETRIEVER_NAME,
    ):
        # Default of None makes clearing a no-op for keys that were never set.
        module_globals.pop(key, None)


def _clear_dependencies_schemas():
    """
    Clear all the dependencies schemas defined by the user.
    """
    # Retriever schemas are currently the only kind of dependencies schema.
    _clear_retriever_schema()


@contextmanager
def _get_dependencies_schemas():
    """
    Yield a ``DependenciesSchemas`` built from the user-defined schemas, then clear them.

    The registered schemas are cleared when the ``with`` block exits, including when it
    exits with an exception.
    """
    dependencies_schemas = DependenciesSchemas(retriever_schemas=_get_retriever_schema())
    try:
        yield dependencies_schemas
    finally:
        _clear_dependencies_schemas()


@dataclass
class Schema(ABC):
    """
    Base class for the dependencies schemas of a model.

    Args:
        type (DependenciesSchemasType): The type of the schema.
    """

    type: DependenciesSchemasType

    @abstractmethod
    def to_dict(self):
        """
        Convert the schema to a dictionary.
        Subclasses must implement this method.
        """

    @classmethod
    @abstractmethod
    def from_dict(cls, data: Dict[str, str]):
        """
        Build a schema from a dictionary.
        Subclasses must implement this method.
        """


@dataclass
class RetrieverSchema(Schema):
    """
    Schema of the documents returned by a retriever / vector search index.

    Args:
        name (str): The name of the retriever or vector search index.
        primary_key (str): The primary key for the index.
        text_column (str): The main text column for the index.
        doc_uri (Optional[str]): The document URI column for the index.
        other_columns (Optional[List[str]]): Additional columns in the index.
    """

    # Declared as dataclass fields so that the generated __eq__ and __repr__ include
    # them. Without these declarations, the only field is the inherited ``type``, so
    # every RetrieverSchema compared equal. @dataclass keeps the explicit __init__ below
    # because it is defined in the class body.
    name: str
    primary_key: str
    text_column: str
    doc_uri: Optional[str] = None
    other_columns: List[str] = field(default_factory=list)

    def __init__(
        self,
        name: str,
        primary_key: str,
        text_column: str,
        doc_uri: Optional[str] = None,
        other_columns: Optional[List[str]] = None,
    ):
        super().__init__(type=DependenciesSchemasType.RETRIEVERS)
        self.name = name
        self.primary_key = primary_key
        self.text_column = text_column
        self.doc_uri = doc_uri
        self.other_columns = other_columns or []

    def to_dict(self):
        """
        Serialize this schema as ``{<type value>: [<schema fields dict>]}``.
        """
        return {
            self.type.value: [
                {
                    "name": self.name,
                    "primary_key": self.primary_key,
                    "text_column": self.text_column,
                    "doc_uri": self.doc_uri,
                    "other_columns": self.other_columns,
                }
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, str]):
        """
        Build a RetrieverSchema from a single schema-fields dictionary.

        ``name``, ``primary_key`` and ``text_column`` are required; a missing key raises
        ``KeyError``. ``doc_uri`` and ``other_columns`` are optional.
        """
        return cls(
            name=data["name"],
            primary_key=data["primary_key"],
            text_column=data["text_column"],
            doc_uri=data.get("doc_uri"),
            other_columns=data.get("other_columns", []),
        )


@dataclass
class DependenciesSchemas:
    """
    Container for all dependencies schemas collected for a model.

    Args:
        retriever_schemas (List[RetrieverSchema]): The retriever schemas (empty by default).
    """

    retriever_schemas: List[RetrieverSchema] = field(default_factory=list)

    def to_dict(self) -> Optional[Dict[str, Dict[str, List[Dict]]]]:
        """
        Serialize the schemas under a top-level ``"dependencies_schemas"`` key.

        Returns:
            ``None`` if there are no retriever schemas. Otherwise a dictionary of the form
            ``{"dependencies_schemas": {"retrievers": [<retriever dict>, ...]}}``.
        """
        if not self.retriever_schemas:
            return None
        retrievers_key = DependenciesSchemasType.RETRIEVERS.value
        return {
            "dependencies_schemas": {
                # RetrieverSchema.to_dict() wraps its entry in a one-element list under the
                # same key. Unwrap it so that all retrievers share a single list.
                retrievers_key: [
                    schema.to_dict()[retrievers_key][0] for schema in self.retriever_schemas
                ],
            }
        }