"""
The ``mlflow.openai`` module provides an API for logging and loading OpenAI models.
Credential management for OpenAI on Databricks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. warning::
Specifying secrets for model serving with ``MLFLOW_OPENAI_SECRET_SCOPE`` is deprecated.
Use `secrets-based environment variables <https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html>`_
instead.
When this flavor logs a model on Databricks, it saves a YAML file with the following contents as
``openai.yaml`` if the ``MLFLOW_OPENAI_SECRET_SCOPE`` environment variable is set.
.. code-block:: yaml
OPENAI_API_BASE: {scope}:openai_api_base
OPENAI_API_KEY: {scope}:openai_api_key
OPENAI_API_KEY_PATH: {scope}:openai_api_key_path
OPENAI_API_TYPE: {scope}:openai_api_type
OPENAI_ORGANIZATION: {scope}:openai_organization
- ``{scope}`` is the value of the ``MLFLOW_OPENAI_SECRET_SCOPE`` environment variable.
- The keys are the environment variables that the ``openai-python`` package uses to
configure the API client.
- The values are the references to the secrets that store the values of the environment
variables.
When the logged model is served on Databricks, each secret will be resolved and set as the
corresponding environment variable. See https://docs.databricks.com/security/secrets/index.html
for how to set up secrets on Databricks.
"""
import itertools
import logging
import os
import warnings
from string import Formatter
from typing import Any, Dict, Optional, Set
import yaml
from packaging.version import Version
import mlflow
from mlflow import pyfunc
from mlflow.environment_variables import MLFLOW_OPENAI_SECRET_SCOPE
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelInputExample, ModelSignature
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import _infer_signature_from_input_example
from mlflow.models.utils import _save_example
from mlflow.openai._openai_autolog import patched_call
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import ColSpec, Schema, TensorSpec
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
from mlflow.utils.databricks_utils import (
check_databricks_secret_scope_access,
is_in_databricks_runtime,
)
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
_CONDA_ENV_FILE_NAME,
_CONSTRAINTS_FILE_NAME,
_PYTHON_ENV_FILE_NAME,
_REQUIREMENTS_FILE_NAME,
_mlflow_conda_env,
_process_conda_env,
_process_pip_requirements,
_PythonEnv,
_validate_env_arguments,
)
from mlflow.utils.file_utils import write_to
from mlflow.utils.model_utils import (
_add_code_from_conf_to_system_path,
_get_flavor_configuration,
_validate_and_copy_code_paths,
_validate_and_prepare_target_save_path,
)
from mlflow.utils.openai_utils import (
REQUEST_URL_CHAT,
REQUEST_URL_COMPLETIONS,
REQUEST_URL_EMBEDDINGS,
_OAITokenHolder,
_OpenAIApiConfig,
_OpenAIEnvVar,
_validate_model_params,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement
FLAVOR_NAME = "openai"
MODEL_FILENAME = "model.yaml"
_PYFUNC_SUPPORTED_TASKS = ("chat.completions", "embeddings", "completions")
_logger = logging.getLogger(__name__)
[docs]@experimental
def get_default_pip_requirements():
"""
Returns:
A list of default pip requirements for MLflow Models produced by this flavor.
Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
that, at minimum, contains these requirements.
"""
return list(map(_get_pinned_requirement, ["openai", "tiktoken", "tenacity"]))
[docs]@experimental
def get_default_conda_env():
"""
Returns:
The default Conda environment for MLflow Models produced by calls to
:func:`save_model()` and :func:`log_model()`.
"""
return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
def _get_obj_to_task_mapping():
if Version(_get_openai_package_version()).major < 1:
from openai import api_resources as ar
return {
ar.Audio: ar.Audio.OBJECT_NAME,
ar.ChatCompletion: ar.ChatCompletion.OBJECT_NAME,
ar.Completion: ar.Completion.OBJECT_NAME,
ar.Edit: ar.Edit.OBJECT_NAME,
ar.Deployment: ar.Deployment.OBJECT_NAME,
ar.Embedding: ar.Embedding.OBJECT_NAME,
ar.Engine: ar.Engine.OBJECT_NAME,
ar.File: ar.File.OBJECT_NAME,
ar.Image: ar.Image.OBJECT_NAME,
ar.FineTune: ar.FineTune.OBJECT_NAME,
ar.Model: ar.Model.OBJECT_NAME,
ar.Moderation: "moderations",
}
else:
from openai import resources as r
return {
r.Audio: "audio",
r.chat.Completions: "chat.completions",
r.Completions: "completions",
r.Images.edit: "images.edit",
r.Embeddings: "embeddings",
r.Files: "files",
r.Images: "images",
r.FineTuning: "fine_tuning",
r.Moderations: "moderations",
r.Models: "models",
}
def _get_model_name(model):
import openai
if isinstance(model, str):
return model
if Version(_get_openai_package_version()).major < 1 and isinstance(model, openai.Model):
return model.id
raise mlflow.MlflowException(
f"Unsupported model type: {type(model)}", error_code=INVALID_PARAMETER_VALUE
)
def _get_task_name(task):
mapping = _get_obj_to_task_mapping()
if isinstance(task, str):
if task not in mapping.values():
raise mlflow.MlflowException(
f"Unsupported task: {task}", error_code=INVALID_PARAMETER_VALUE
)
return task
else:
task_name = (
mapping.get(task)
or mapping.get(task.__class__)
or mapping.get(getattr(task, "__func__")) # if task is a method
)
if task_name is None:
raise mlflow.MlflowException(
f"Unsupported task object: {task}", error_code=INVALID_PARAMETER_VALUE
)
return task_name
def _get_api_config() -> _OpenAIApiConfig:
"""Gets the parameters and configuration of the OpenAI API connected to."""
import openai
api_type = os.getenv(_OpenAIEnvVar.OPENAI_API_TYPE.value, openai.api_type)
api_version = os.getenv(_OpenAIEnvVar.OPENAI_API_VERSION.value, openai.api_version)
api_base = os.getenv(_OpenAIEnvVar.OPENAI_API_BASE.value, None)
engine = os.getenv(_OpenAIEnvVar.OPENAI_ENGINE.value, None)
deployment_id = os.getenv(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value, None)
if api_type in ("azure", "azure_ad", "azuread"):
batch_size = 16
max_tokens_per_minute = 60_000
else:
# The maximum batch size is 2048:
# https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
# We use a smaller batch size to be safe.
batch_size = 1024
max_tokens_per_minute = 90_000
return _OpenAIApiConfig(
api_type=api_type,
batch_size=batch_size,
max_requests_per_minute=3_500,
max_tokens_per_minute=max_tokens_per_minute,
api_base=api_base,
api_version=api_version,
engine=engine,
deployment_id=deployment_id,
)
def _get_openai_package_version():
import openai
try:
return openai.__version__
except AttributeError:
# openai < 0.27.5 doesn't have a __version__ attribute
return openai.version.VERSION
def _log_secrets_yaml(local_model_dir, scope):
with open(os.path.join(local_model_dir, "openai.yaml"), "w") as f:
yaml.safe_dump({e.value: f"{scope}:{e.secret_key}" for e in _OpenAIEnvVar}, f)
def _parse_format_fields(s) -> Set[str]:
"""Parses format fields from a given string, e.g. "Hello {name}" -> ["name"]."""
return {fn for _, fn, _, _ in Formatter().parse(s) if fn is not None}
def _get_input_schema(task, content):
if content:
formatter = _ContentFormatter(task, content)
variables = formatter.variables
if len(variables) == 1:
return Schema([ColSpec(type="string")])
elif len(variables) > 1:
return Schema([ColSpec(name=v, type="string") for v in variables])
else:
return Schema([ColSpec(type="string")])
else:
return Schema([ColSpec(type="string")])
[docs]@experimental
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
model,
task,
path,
conda_env=None,
code_paths=None,
mlflow_model=None,
signature: ModelSignature = None,
input_example: ModelInputExample = None,
pip_requirements=None,
extra_pip_requirements=None,
metadata=None,
example_no_conversion=None,
**kwargs,
):
"""
Save an OpenAI model to a path on the local file system.
Args:
model: The OpenAI model name.
task: The task the model is performing, e.g., ``openai.chat.completions`` or
``'chat.completions'``.
path: Local path where the model is to be saved.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
from datasets with valid model input (e.g. the training dataset with target
column omitted) and valid model output (e.g. model predictions generated on
the training dataset), for example:
.. code-block:: python
from mlflow.models import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)
input_example: {{ input_example }}
pip_requirements: {{ pip_requirements }}
extra_pip_requirements: {{ extra_pip_requirements }}
metadata: {{ metadata }}
example_no_conversion: {{ example_no_conversion }}
kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
:ref:`mlflow.openai.messages` for more details on this parameter)
or ``top_p`` value to use for chat completion.
.. code-block:: python
import mlflow
import openai
# Chat
mlflow.openai.save_model(
model="gpt-4o-mini",
task=openai.chat.completions,
messages=[{"role": "user", "content": "Tell me a joke."}],
path="model",
)
# Completions
mlflow.openai.save_model(
model="text-davinci-002",
task=openai.completions,
prompt="{text}. The general sentiment of the text is",
path="model",
)
# Embeddings
mlflow.openai.save_model(
model="text-embedding-ada-002",
task=openai.embeddings,
path="model",
)
"""
import numpy as np
_validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
path = os.path.abspath(path)
_validate_and_prepare_target_save_path(path)
code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
task = _get_task_name(task)
if mlflow_model is None:
mlflow_model = Model()
if signature is not None:
if signature.params:
_validate_model_params(
task, kwargs, {p.name: p.default for p in signature.params.params}
)
elif task == "chat.completions":
messages = kwargs.get("messages", [])
if messages and not (
all(isinstance(m, dict) for m in messages) and all(map(_is_valid_message, messages))
):
raise mlflow.MlflowException.invalid_parameter_value(
"If `messages` is provided, it must be a list of dictionaries with keys "
"'role' and 'content'."
)
signature = ModelSignature(
inputs=_get_input_schema(task, messages),
outputs=Schema([ColSpec(type="string", name=None)]),
)
elif task == "completions":
prompt = kwargs.get("prompt")
signature = ModelSignature(
inputs=_get_input_schema(task, prompt),
outputs=Schema([ColSpec(type="string", name=None)]),
)
elif task == "embeddings":
signature = ModelSignature(
inputs=Schema([ColSpec(type="string", name=None)]),
outputs=Schema([TensorSpec(type=np.dtype("float64"), shape=(-1,))]),
)
saved_example = _save_example(mlflow_model, input_example, path, example_no_conversion)
if signature is None and saved_example is not None:
wrapped_model = _OpenAIWrapper(model)
signature = _infer_signature_from_input_example(saved_example, wrapped_model)
if signature is not None:
mlflow_model.signature = signature
if metadata is not None:
mlflow_model.metadata = metadata
model_data_path = os.path.join(path, MODEL_FILENAME)
model_dict = {
"model": _get_model_name(model),
"task": task,
**kwargs,
}
with open(model_data_path, "w") as f:
yaml.safe_dump(model_dict, f)
if task in _PYFUNC_SUPPORTED_TASKS:
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.openai",
data=MODEL_FILENAME,
conda_env=_CONDA_ENV_FILE_NAME,
python_env=_PYTHON_ENV_FILE_NAME,
code=code_dir_subpath,
)
mlflow_model.add_flavor(
FLAVOR_NAME,
openai_version=_get_openai_package_version(),
data=MODEL_FILENAME,
code=code_dir_subpath,
)
mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
if is_in_databricks_runtime():
if scope := MLFLOW_OPENAI_SECRET_SCOPE.get():
url = "https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html"
warnings.warn(
"Specifying secrets for model serving with `MLFLOW_OPENAI_SECRET_SCOPE` is "
f"deprecated. Use secrets-based environment variables ({url}) instead.",
FutureWarning,
)
check_databricks_secret_scope_access(scope)
_log_secrets_yaml(path, scope)
if conda_env is None:
if pip_requirements is None:
default_reqs = get_default_pip_requirements()
inferred_reqs = mlflow.models.infer_pip_requirements(
path, FLAVOR_NAME, fallback=default_reqs
)
default_reqs = sorted(set(inferred_reqs).union(default_reqs))
else:
default_reqs = None
conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
default_reqs,
pip_requirements,
extra_pip_requirements,
)
else:
conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
# Save `constraints.txt` if necessary
if pip_constraints:
write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
# Save `requirements.txt`
write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
_PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
[docs]@experimental
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
model,
task,
artifact_path,
conda_env=None,
code_paths=None,
registered_model_name=None,
signature: ModelSignature = None,
input_example: ModelInputExample = None,
await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
pip_requirements=None,
extra_pip_requirements=None,
metadata=None,
example_no_conversion=None,
**kwargs,
):
"""
Log an OpenAI model as an MLflow artifact for the current run.
Args:
model: The OpenAI model name or reference instance, e.g.,
``openai.Model.retrieve("gpt-4o-mini")``.
task: The task the model is performing, e.g., ``openai.chat.completions`` or
``'chat.completions'``.
artifact_path: Run-relative artifact path.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
registered_model_name: If given, create a model version under
``registered_model_name``, also creating a registered model if one
with the given name does not exist.
signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
from datasets with valid model input (e.g. the training dataset with target
column omitted) and valid model output (e.g. model predictions generated on
the training dataset), for example:
.. code-block:: python
from mlflow.models import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)
input_example: {{ input_example }}
await_registration_for: Number of seconds to wait for the model version to finish
being created and is in ``READY`` status. By default, the function
waits for five minutes. Specify 0 or None to skip waiting.
pip_requirements: {{ pip_requirements }}
extra_pip_requirements: {{ extra_pip_requirements }}
metadata: {{ metadata }}
example_no_conversion: {{ example_no_conversion }}
kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
:ref:`mlflow.openai.messages` for more details on this parameter)
or ``top_p`` value to use for chat completion.
Returns:
A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
import mlflow
import openai
# Chat
with mlflow.start_run():
info = mlflow.openai.log_model(
model="gpt-4o-mini",
task=openai.chat.completions,
messages=[{"role": "user", "content": "Tell me a joke about {animal}."}],
artifact_path="model",
)
model = mlflow.pyfunc.load_model(info.model_uri)
df = pd.DataFrame({"animal": ["cats", "dogs"]})
print(model.predict(df))
# Embeddings
with mlflow.start_run():
info = mlflow.openai.log_model(
model="text-embedding-ada-002",
task=openai.embeddings,
artifact_path="embeddings",
)
model = mlflow.pyfunc.load_model(info.model_uri)
print(model.predict(["hello", "world"]))
"""
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.openai,
registered_model_name=registered_model_name,
model=model,
task=task,
conda_env=conda_env,
code_paths=code_paths,
signature=signature,
input_example=input_example,
await_registration_for=await_registration_for,
pip_requirements=pip_requirements,
extra_pip_requirements=extra_pip_requirements,
metadata=metadata,
example_no_conversion=example_no_conversion,
**kwargs,
)
def _load_model(path):
with open(path) as f:
return yaml.safe_load(f)
def _is_valid_message(d):
return isinstance(d, dict) and "content" in d and "role" in d
class _ContentFormatter:
def __init__(self, task, template=None):
if task == "completions":
template = template or "{prompt}"
if not isinstance(template, str):
raise mlflow.MlflowException.invalid_parameter_value(
f"Template for task {task} expects type `str`, but got {type(template)}."
)
self.template = template
self.format_fn = self.format_prompt
self.variables = sorted(_parse_format_fields(self.template))
elif task == "chat.completions":
if not template:
template = [{"role": "user", "content": "{content}"}]
if not all(map(_is_valid_message, template)):
raise mlflow.MlflowException.invalid_parameter_value(
f"Template for task {task} expects type `dict` with keys 'content' "
f"and 'role', but got {type(template)}."
)
self.template = template.copy()
self.format_fn = self.format_chat
self.variables = sorted(
set(
itertools.chain.from_iterable(
_parse_format_fields(message.get("content"))
| _parse_format_fields(message.get("role"))
for message in self.template
)
)
)
if not self.variables:
self.template.append({"role": "user", "content": "{content}"})
self.variables.append("content")
else:
raise mlflow.MlflowException.invalid_parameter_value(
f"Task type ``{task}`` is not supported for formatting."
)
def format(self, **params):
if missing_params := set(self.variables) - set(params):
raise mlflow.MlflowException.invalid_parameter_value(
f"Expected parameters {self.variables} to be provided, "
f"only got {list(params)}, {list(missing_params)} are missing."
)
return self.format_fn(**params)
def format_prompt(self, **params):
return self.template.format(**{v: params[v] for v in self.variables})
def format_chat(self, **params):
format_args = {v: params[v] for v in self.variables}
return [
{
"role": message.get("role").format(**format_args),
"content": message.get("content").format(**format_args),
}
for message in self.template
]
def _first_string_column(pdf):
iter_str_cols = (c for c, v in pdf.iloc[0].items() if isinstance(v, str))
col = next(iter_str_cols, None)
if col is None:
raise mlflow.MlflowException.invalid_parameter_value(
f"Could not find a string column in the input data: {pdf.dtypes.to_dict()}"
)
return col
class _OpenAIWrapper:
def __init__(self, model):
task = model.pop("task")
if task not in _PYFUNC_SUPPORTED_TASKS:
raise mlflow.MlflowException.invalid_parameter_value(
f"Unsupported task: {task}. Supported tasks: {_PYFUNC_SUPPORTED_TASKS}."
)
self.model = model
self.task = task
self.api_config = _get_api_config()
self.api_token = _OAITokenHolder(self.api_config.api_type)
# If the same parameter exists in self.model & self.api_config,
# we use the parameter from self.model
self.request_configs = {}
for x in ["api_base", "api_version", "api_type", "engine", "deployment_id"]:
if x in self.model:
self.request_configs[x] = self.model.pop(x)
elif value := getattr(self.api_config, x):
self.request_configs[x] = value
if self.request_configs.get("api_type") in ("azure", "azure_ad", "azuread"):
deployment_id = self.request_configs.get("deployment_id")
if self.request_configs.get("engine"):
# Avoid using both parameters as they serve the same purpose
# Invalid inputs:
# - Wrong engine + correct/wrong deployment_id
# - No engine + wrong deployment_id
# Valid inputs:
# - Correct engine + correct/wrong deployment_id
# - No engine + correct deployment_id
if deployment_id is not None:
_logger.warning(
"Both engine and deployment_id are set. "
"Using engine as it takes precedence."
)
elif deployment_id is None:
raise MlflowException(
"Either engine or deployment_id must be set for Azure OpenAI API",
)
if self.task != "embeddings":
self._setup_completions()
def get_raw_model(self):
"""
Returns the underlying model.
"""
return self.model
def _setup_completions(self):
if self.task == "chat.completions":
self.template = self.model.get("messages", [])
else:
self.template = self.model.get("prompt")
self.formater = _ContentFormatter(self.task, self.template)
def format_completions(self, params_list):
return [self.formater.format(**params) for params in params_list]
def get_params_list(self, data):
if len(self.formater.variables) == 1:
variable = self.formater.variables[0]
if variable in data.columns:
return data[[variable]].to_dict(orient="records")
else:
first_string_column = _first_string_column(data)
return [{variable: s} for s in data[first_string_column]]
else:
return data[self.formater.variables].to_dict(orient="records")
def _construct_request_url(self, task_url, default_url):
api_type = self.request_configs.get("api_type")
api_base = self.request_configs.get("api_base")
if api_type in ("azure", "azure_ad", "azuread"):
api_version = self.request_configs.get("api_version")
deployment_id = self.request_configs.get("deployment_id")
return (
f"{api_base}/openai/deployments/{deployment_id}/"
f"{task_url}?api-version={api_version}"
)
return f"{api_base}/{task_url}" if api_base else default_url
def _predict_chat(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests
_validate_model_params(self.task, self.model, params)
messages_list = self.format_completions(self.get_params_list(data))
requests = [{**self.model, **params, "messages": messages} for messages in messages_list]
request_url = self._construct_request_url("chat/completions", REQUEST_URL_CHAT)
results = process_api_requests(
requests,
request_url,
api_token=self.api_token,
max_requests_per_minute=self.api_config.max_requests_per_minute,
max_tokens_per_minute=self.api_config.max_tokens_per_minute,
)
return [r["choices"][0]["message"]["content"] for r in results]
def _predict_completions(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests
_validate_model_params(self.task, self.model, params)
prompts_list = self.format_completions(self.get_params_list(data))
batch_size = params.pop("batch_size", self.api_config.batch_size)
_logger.debug(f"Requests are being batched by {batch_size} samples.")
requests = [
{
**self.model,
**params,
"prompt": prompts_list[i : i + batch_size],
}
for i in range(0, len(prompts_list), batch_size)
]
request_url = self._construct_request_url("completions", REQUEST_URL_COMPLETIONS)
results = process_api_requests(
requests,
request_url,
api_token=self.api_token,
max_requests_per_minute=self.api_config.max_requests_per_minute,
max_tokens_per_minute=self.api_config.max_tokens_per_minute,
)
return [row["text"] for batch in results for row in batch["choices"]]
def _predict_embeddings(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests
_validate_model_params(self.task, self.model, params)
batch_size = params.pop("batch_size", self.api_config.batch_size)
_logger.debug(f"Requests are being batched by {batch_size} samples.")
first_string_column = _first_string_column(data)
texts = data[first_string_column].tolist()
requests = [
{
**self.model,
**params,
"input": texts[i : i + batch_size],
}
for i in range(0, len(texts), batch_size)
]
request_url = self._construct_request_url("embeddings", REQUEST_URL_EMBEDDINGS)
results = process_api_requests(
requests,
request_url,
api_token=self.api_token,
max_requests_per_minute=self.api_config.max_requests_per_minute,
max_tokens_per_minute=self.api_config.max_tokens_per_minute,
)
return [row["embedding"] for batch in results for row in batch["data"]]
def predict(self, data, params: Optional[Dict[str, Any]] = None):
"""
Args:
data: Model input data.
params: Additional parameters to pass to the model for inference.
Returns:
Model predictions.
"""
self.api_token.refresh()
if self.task == "chat.completions":
return self._predict_chat(data, params or {})
elif self.task == "completions":
return self._predict_completions(data, params or {})
elif self.task == "embeddings":
return self._predict_embeddings(data, params or {})
def _load_pyfunc(path):
"""Loads PyFunc implementation. Called by ``pyfunc.load_model``.
Args:
path: Local filesystem path to the MLflow Model with the ``openai`` flavor.
"""
return _OpenAIWrapper(_load_model(path))
[docs]@experimental
def load_model(model_uri, dst_path=None):
"""
Load an OpenAI model from a local file or a run.
Args:
model_uri: The location, in URI format, of the MLflow model. For example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see
`Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
artifact-locations>`_.
dst_path: The local filesystem path to which to download the model artifact.
This directory must already exist. If unspecified, a local output
path will be created.
Returns:
A dictionary representing the OpenAI model.
"""
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
flavor_conf = _get_flavor_configuration(local_model_path, FLAVOR_NAME)
_add_code_from_conf_to_system_path(local_model_path, flavor_conf)
model_data_path = os.path.join(local_model_path, flavor_conf.get("data", MODEL_FILENAME))
return _load_model(model_data_path)
[docs]@experimental
@autologging_integration(FLAVOR_NAME)
def autolog(
log_input_examples=False,
log_model_signatures=False,
log_models=False,
log_datasets=False,
disable=False,
exclusive=False,
disable_for_unsupported_versions=False,
silent=False,
registered_model_name=None,
extra_tags=None,
log_traces=True,
):
"""
Enables (or disables) and configures autologging from OpenAI to MLflow.
Raises :py:class:`MlflowException <mlflow.exceptions.MlflowException>`
if the OpenAI version < 1.0.
Args:
log_input_examples: If ``True``, input examples from inference data are collected and
logged along with Langchain model artifacts during inference. If
``False``, input examples are not logged.
Note: Input examples are MLflow model attributes
and are only collected if ``log_models`` is also ``True``.
log_model_signatures: If ``True``,
:py:class:`ModelSignatures <mlflow.models.ModelSignature>`
describing model inputs and outputs are collected and logged along
with OpenAI model artifacts during inference. If ``False``,
signatures are not logged.
Note: Model signatures are MLflow model attributes
and are only collected if ``log_models`` is also ``True``.
log_models: If ``True``, OpenAI models are logged as MLflow model artifacts.
If ``False``, OpenAI models are not logged.
Input examples and model signatures, which are attributes of MLflow models,
are also omitted when ``log_models`` is ``False``.
log_datasets: If ``True``, dataset information is logged to MLflow Tracking
if applicable. If ``False``, dataset information is not logged.
disable: If ``True``, disables the OpenAI autologging integration. If ``False``,
enables the OpenAI autologging integration.
exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
If ``False``, autologged content is logged to the active fluent run,
which may be user-created.
disable_for_unsupported_versions: If ``True``, disable autologging for versions of
OpenAI that have not been tested against this version of the MLflow
client or are incompatible.
silent: If ``True``, suppress all event logs and warnings from MLflow during OpenAI
autologging. If ``False``, show all events and warnings during OpenAI
autologging.
registered_model_name: If given, each time a model is trained, it is registered as a
new model version of the registered model with this name.
The registered model is created if it does not already exist.
extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
log_traces: If ``True``, traces are logged for OpenAI models. If ``False``, no traces are
collected during inference. Default to ``True``.
"""
if Version(_get_openai_package_version()).major < 1:
raise MlflowException("OpenAI autologging is only supported for openai >= 1.0.0")
from openai.resources.chat.completions import Completions as ChatCompletions
from openai.resources.completions import Completions
from openai.resources.embeddings import Embeddings
for task in (ChatCompletions, Completions, Embeddings):
safe_patch(
FLAVOR_NAME,
task,
"create",
patched_call,
)