Source code for mlflow.gateway.config

import json
import logging
import os
import pathlib
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import pydantic
import yaml
from packaging import version
from packaging.version import Version
from pydantic import ConfigDict, Field, ValidationError, root_validator, validator
from pydantic.json import pydantic_encoder

from mlflow.exceptions import MlflowException
from mlflow.gateway.base_models import ConfigModel, LimitModel, ResponseModel
from mlflow.gateway.constants import (
    MLFLOW_AI_GATEWAY_MOSAICML_CHAT_SUPPORTED_MODEL_PREFIXES,
    MLFLOW_GATEWAY_ROUTE_BASE,
    MLFLOW_QUERY_SUFFIX,
)
from mlflow.gateway.utils import (
    check_configuration_deprecated_fields,
    check_configuration_route_name_collisions,
    is_valid_ai21labs_model,
    is_valid_endpoint_name,
    is_valid_mosiacml_chat_model,
)

_logger = logging.getLogger(__name__)

IS_PYDANTIC_V2 = version.parse(pydantic.version.VERSION) >= version.parse("2.0")

if IS_PYDANTIC_V2:
    from pydantic import SerializeAsAny


[docs]class Provider(str, Enum): OPENAI = "openai" ANTHROPIC = "anthropic" COHERE = "cohere" AI21LABS = "ai21labs" MLFLOW_MODEL_SERVING = "mlflow-model-serving" MOSAICML = "mosaicml" HUGGINGFACE_TEXT_GENERATION_INFERENCE = "huggingface-text-generation-inference" PALM = "palm" BEDROCK = "bedrock" AMAZON_BEDROCK = "amazon-bedrock" # an alias for bedrock # Note: The following providers are only supported on Databricks DATABRICKS_MODEL_SERVING = "databricks-model-serving" DATABRICKS = "databricks" MISTRAL = "mistral" TOGETHERAI = "togetherai"
[docs] @classmethod def values(cls): return {p.value for p in cls}
[docs]class TogetherAIConfig(ConfigModel): togetherai_api_key: str
[docs] @validator("togetherai_api_key", pre=True) def validate_togetherai_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class RouteType(str, Enum): LLM_V1_COMPLETIONS = "llm/v1/completions" LLM_V1_CHAT = "llm/v1/chat" LLM_V1_EMBEDDINGS = "llm/v1/embeddings"
[docs]class CohereConfig(ConfigModel): cohere_api_key: str
[docs] @validator("cohere_api_key", pre=True) def validate_cohere_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class AI21LabsConfig(ConfigModel): ai21labs_api_key: str
[docs] @validator("ai21labs_api_key", pre=True) def validate_ai21labs_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class MosaicMLConfig(ConfigModel): mosaicml_api_key: str mosaicml_api_base: Optional[str] = None
[docs] @validator("mosaicml_api_key", pre=True) def validate_mosaicml_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class OpenAIAPIType(str, Enum): OPENAI = "openai" AZURE = "azure" AZUREAD = "azuread" @classmethod def _missing_(cls, value): """ Implements case-insensitive matching of API type strings """ for api_type in cls: if api_type.value == value.lower(): return api_type raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{value}'")
[docs]class OpenAIConfig(ConfigModel): openai_api_key: str openai_api_type: OpenAIAPIType = OpenAIAPIType.OPENAI openai_api_base: Optional[str] = None openai_api_version: Optional[str] = None openai_deployment_name: Optional[str] = None openai_organization: Optional[str] = None
[docs] @validator("openai_api_key", pre=True) def validate_openai_api_key(cls, value): return _resolve_api_key_from_input(value)
@classmethod def _validate_field_compatibility(cls, info: Dict[str, Any]): if not isinstance(info, dict): return info api_type = (info.get("openai_api_type") or OpenAIAPIType.OPENAI).lower() if api_type == OpenAIAPIType.OPENAI: if info.get("openai_deployment_name") is not None: raise MlflowException.invalid_parameter_value( f"OpenAI route configuration can only specify a value for " f"'openai_deployment_name' if 'openai_api_type' is '{OpenAIAPIType.AZURE}' " f"or '{OpenAIAPIType.AZUREAD}'. Found type: '{api_type}'" ) if info.get("openai_api_base") is None: info["openai_api_base"] = "https://api.openai.com/v1" elif api_type in (OpenAIAPIType.AZURE, OpenAIAPIType.AZUREAD): if info.get("openai_organization") is not None: raise MlflowException.invalid_parameter_value( f"OpenAI route configuration can only specify a value for " f"'openai_organization' if 'openai_api_type' is '{OpenAIAPIType.OPENAI}'" ) base_url = info.get("openai_api_base") deployment_name = info.get("openai_deployment_name") api_version = info.get("openai_api_version") if (base_url, deployment_name, api_version).count(None) > 0: raise MlflowException.invalid_parameter_value( f"OpenAI route configuration must specify 'openai_api_base', " f"'openai_deployment_name', and 'openai_api_version' if 'openai_api_type' is " f"'{OpenAIAPIType.AZURE}' or '{OpenAIAPIType.AZUREAD}'." ) else: raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{api_type}'") return info if IS_PYDANTIC_V2: from pydantic import model_validator as _model_validator @_model_validator(mode="before") def validate_field_compatibility(cls, info: Dict[str, Any]): return cls._validate_field_compatibility(info) else: from pydantic import root_validator as _root_validator
[docs] @_root_validator(pre=False) def validate_field_compatibility(cls, config: Dict[str, Any]): return cls._validate_field_compatibility(config)
[docs]class AnthropicConfig(ConfigModel): anthropic_api_key: str anthropic_version: str = "2023-06-01"
[docs] @validator("anthropic_api_key", pre=True) def validate_anthropic_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class PaLMConfig(ConfigModel): palm_api_key: str
[docs] @validator("palm_api_key", pre=True) def validate_palm_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class MlflowModelServingConfig(ConfigModel): model_server_url: str # Workaround to suppress warning that Pydantic raises when a field name starts with "model_". # https://github.com/mlflow/mlflow/issues/10335 model_config = pydantic.ConfigDict(protected_namespaces=())
[docs]class HuggingFaceTextGenerationInferenceConfig(ConfigModel): hf_server_url: str
[docs]class AWSBaseConfig(pydantic.BaseModel): aws_region: Optional[str] = None
[docs]class AWSRole(AWSBaseConfig): aws_role_arn: str session_length_seconds: int = 15 * 60
[docs]class AWSIdAndKey(AWSBaseConfig): aws_access_key_id: str aws_secret_access_key: str aws_session_token: Optional[str] = None
[docs]class AmazonBedrockConfig(ConfigModel): # order here is important, at least for pydantic<2 aws_config: Union[AWSRole, AWSIdAndKey, AWSBaseConfig]
[docs]class MistralConfig(ConfigModel): mistral_api_key: str
[docs] @validator("mistral_api_key", pre=True) def validate_mistral_api_key(cls, value): return _resolve_api_key_from_input(value)
[docs]class ModelInfo(ResponseModel): name: Optional[str] = None provider: Provider
def _resolve_api_key_from_input(api_key_input): """ Resolves the provided API key. Input formats accepted: - Path to a file as a string which will have the key loaded from it - environment variable name that stores the api key - the api key itself """ if not isinstance(api_key_input, str): raise MlflowException.invalid_parameter_value( "The api key provided is not a string. Please provide either an environment " "variable key, a path to a file containing the api key, or the api key itself" ) # try reading as an environment variable if api_key_input.startswith("$"): env_var_name = api_key_input[1:] if env_var := os.getenv(env_var_name): return env_var else: raise MlflowException.invalid_parameter_value( f"Environment variable {env_var_name!r} is not set" ) # try reading from a local path file = pathlib.Path(api_key_input) if file.is_file(): return file.read_text() # if the key itself is passed, return return api_key_input
[docs]class Model(ConfigModel): name: Optional[str] = None provider: Union[str, Provider] if IS_PYDANTIC_V2: config: Optional[SerializeAsAny[ConfigModel]] = None else: config: Optional[ConfigModel] = None
[docs] @validator("provider", pre=True) def validate_provider(cls, value): from mlflow.gateway.provider_registry import provider_registry if isinstance(value, Provider): return value formatted_value = value.replace("-", "_").upper() if formatted_value in Provider.__members__: return Provider[formatted_value] if value in provider_registry.keys(): return value raise MlflowException.invalid_parameter_value(f"The provider '{value}' is not supported.")
@classmethod def _validate_config(cls, info, values): from mlflow.gateway.provider_registry import provider_registry if provider := values.get("provider"): config_type = provider_registry.get(provider).CONFIG_TYPE return config_type(**info) raise MlflowException.invalid_parameter_value( "A provider must be provided for each gateway route." ) if IS_PYDANTIC_V2: @validator("config", pre=True) def validate_config(cls, info, values): return cls._validate_config(info, values) else:
[docs] @validator("config", pre=True) def validate_config(cls, config, values): return cls._validate_config(config, values)
[docs]class AliasedConfigModel(ConfigModel): """ Enables use of field aliases in a configuration model for backwards compatibility """ if Version(pydantic.__version__) >= Version("2.0"): model_config = ConfigDict(populate_by_name=True) else: class Config: allow_population_by_field_name = True
[docs]class Limit(LimitModel): calls: int key: Optional[str] = None renewal_period: str
[docs]class LimitsConfig(ConfigModel): limits: Optional[List[Limit]] = []
[docs]class RouteConfig(AliasedConfigModel): name: str route_type: RouteType = Field(alias="endpoint_type") model: Model limit: Optional[Limit] = None
[docs] @validator("name") def validate_endpoint_name(cls, route_name): if not is_valid_endpoint_name(route_name): raise MlflowException.invalid_parameter_value( "The route name provided contains disallowed characters for a url endpoint. " f"'{route_name}' is invalid. Names cannot contain spaces or any non " "alphanumeric characters other than hyphen and underscore." ) return route_name
[docs] @validator("model", pre=True) def validate_model(cls, model): if model: model_instance = Model(**model) if model_instance.provider in Provider.values() and model_instance.config is None: raise MlflowException.invalid_parameter_value( "A config must be supplied when setting a provider. The provider entry for " f"{model_instance.provider} is incorrect." ) return model
[docs] @root_validator(skip_on_failure=True) def validate_route_type_and_model_name(cls, values): route_type = values.get("route_type") model = values.get("model") if ( model and model.provider == "mosaicml" and route_type == RouteType.LLM_V1_CHAT and not is_valid_mosiacml_chat_model(model.name) ): raise MlflowException.invalid_parameter_value( f"An invalid model has been specified for the chat route. '{model.name}'. " f"Ensure the model selected starts with one of: " f"{MLFLOW_AI_GATEWAY_MOSAICML_CHAT_SUPPORTED_MODEL_PREFIXES}" ) if model and model.provider == "ai21labs" and not is_valid_ai21labs_model(model.name): raise MlflowException.invalid_parameter_value( f"An Unsupported AI21Labs model has been specified: '{model.name}'. " f"Please see documentation for supported models." ) return values
[docs] @validator("route_type", pre=True) def validate_route_type(cls, value): if value in RouteType._value2member_map_: return value raise MlflowException.invalid_parameter_value(f"The route_type '{value}' is not supported.")
[docs] @validator("limit", pre=True) def validate_limit(cls, value): from limits import parse if value: limit = Limit(**value) try: parse(f"{limit.calls}/{limit.renewal_period}") except ValueError: raise MlflowException.invalid_parameter_value( "Failed to parse the rate limit configuration." "Please make sure limit.calls is a positive number and" "limit.renewal_period is a right granularity" ) return value
[docs] def to_route(self) -> "Route": return Route( name=self.name, route_type=self.route_type, model=RouteModelInfo( name=self.model.name, provider=self.model.provider, ), route_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}", limit=self.limit, )
[docs]class RouteModelInfo(ResponseModel): name: Optional[str] = None # Use `str` instead of `Provider` enum to allow gateway backends such as Databricks to # support new providers without breaking the gateway client. provider: str
_ROUTE_EXTRA_SCHEMA = { "example": { "name": "openai-completions", "route_type": "llm/v1/completions", "model": { "name": "gpt-4o-mini", "provider": "openai", }, "route_url": "/gateway/routes/completions/invocations", } }
[docs]class Route(ConfigModel): name: str route_type: str model: RouteModelInfo route_url: str limit: Optional[Limit] = None
[docs] class Config: if IS_PYDANTIC_V2: json_schema_extra = _ROUTE_EXTRA_SCHEMA else: schema_extra = _ROUTE_EXTRA_SCHEMA
[docs] def to_endpoint(self): from mlflow.deployments.server.config import Endpoint return Endpoint( name=self.name, endpoint_type=self.route_type, model=self.model, endpoint_url=self.route_url, limit=self.limit, )
[docs]class GatewayConfig(AliasedConfigModel): routes: List[RouteConfig] = Field(alias="endpoints")
def _load_route_config(path: Union[str, Path]) -> GatewayConfig: """ Reads the gateway configuration yaml file from the storage location and returns an instance of the configuration RouteConfig class """ if isinstance(path, str): path = Path(path) try: configuration = yaml.safe_load(path.read_text()) except Exception as e: raise MlflowException.invalid_parameter_value( f"The file at {path} is not a valid yaml file" ) from e check_configuration_deprecated_fields(configuration) check_configuration_route_name_collisions(configuration) try: return GatewayConfig(**configuration) except ValidationError as e: raise MlflowException.invalid_parameter_value( f"The gateway configuration is invalid: {e}" ) from e def _save_route_config(config: GatewayConfig, path: Union[str, Path]) -> None: if isinstance(path, str): path = Path(path) path.write_text(yaml.safe_dump(json.loads(json.dumps(config.dict(), default=pydantic_encoder)))) def _validate_config(config_path: str) -> GatewayConfig: if not os.path.exists(config_path): raise MlflowException.invalid_parameter_value(f"{config_path} does not exist") try: return _load_route_config(config_path) except ValidationError as e: raise MlflowException.invalid_parameter_value(f"Invalid gateway configuration: {e}") from e