import json
import posixpath
from typing import Any, Dict, Iterator, Optional
from mlflow.deployments import BaseDeploymentClient
from mlflow.deployments.constants import (
MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES,
)
from mlflow.environment_variables import (
MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT,
MLFLOW_HTTP_REQUEST_TIMEOUT,
)
from mlflow.exceptions import MlflowException
from mlflow.utils import AttrDict
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.rest_utils import augmented_raise_for_status, http_request
class DatabricksEndpoint(AttrDict):
    """
    Dictionary-like container describing a Databricks serving endpoint.

    Keys of the underlying mapping can also be read as attributes, e.g. ``endpoint.name``.

    .. code-block:: python

        endpoint = DatabricksEndpoint(
            {
                "name": "chat",
                "creator": "alice@company.com",
                "creation_timestamp": 0,
                "last_updated_timestamp": 0,
                "state": {...},
                "config": {...},
                "tags": [...],
                "id": "88fd3f75a0d24b0380ddc40484d7a31b",
            }
        )
        assert endpoint.name == "chat"
    """
@experimental
class DatabricksDeploymentClient(BaseDeploymentClient):
    """
    Client for interacting with Databricks serving endpoints.

    Example:

    First, set up credentials for authentication:

    .. code-block:: bash

        export DATABRICKS_HOST=...
        export DATABRICKS_TOKEN=...

    .. seealso::

        See https://docs.databricks.com/en/dev-tools/auth.html for other authentication methods.

    Then, create a deployment client and use it to interact with Databricks serving endpoints:

    .. code-block:: python

        from mlflow.deployments import get_deploy_client

        client = get_deploy_client("databricks")
        endpoints = client.list_endpoints()
        assert endpoints == [
            {
                "name": "chat",
                "creator": "alice@company.com",
                "creation_timestamp": 0,
                "last_updated_timestamp": 0,
                "state": {...},
                "config": {...},
                "tags": [...],
                "id": "88fd3f75a0d24b0380ddc40484d7a31b",
            },
        ]
    """

    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `DatabricksDeploymentClient`.
        """
        raise NotImplementedError

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `DatabricksDeploymentClient`.
        """
        raise NotImplementedError

    def delete_deployment(self, name, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `DatabricksDeploymentClient`.
        """
        raise NotImplementedError

    def list_deployments(self, endpoint=None):
        """
        .. warning::

            This method is not implemented for `DatabricksDeploymentClient`.
        """
        raise NotImplementedError

    def get_deployment(self, name, endpoint=None):
        """
        .. warning::

            This method is not implemented for `DatabricksDeploymentClient`.
        """
        raise NotImplementedError

    def _send_request(
        self,
        *,
        method: str,
        prefix: str,
        route: Optional[str],
        json_body: Optional[Dict[str, Any]],
        timeout: Optional[int],
        stream: bool = False,
    ):
        """
        Send a request to ``<prefix>/serving-endpoints/<route>`` and raise on an error status.

        Args:
            method: HTTP method name. For ``GET``, ``json_body`` is sent as query parameters;
                for any other method it is sent as the JSON request body.
            prefix: URL path prefix placed before ``serving-endpoints``.
            route: Optional sub-path appended after ``serving-endpoints``.
            json_body: Request payload (query parameters or JSON body, see ``method``).
            timeout: Request timeout. ``None`` means ``MLFLOW_HTTP_REQUEST_TIMEOUT``.
            stream: If True, ask the HTTP layer not to read the response body eagerly.

        Returns:
            The response object returned by ``http_request``, after
            ``augmented_raise_for_status`` has accepted it.
        """
        call_kwargs = {}
        if method.lower() == "get":
            call_kwargs["params"] = json_body
        else:
            call_kwargs["json"] = json_body
        if stream:
            # Receive response content in a streaming way.
            call_kwargs["stream"] = True
        response = http_request(
            host_creds=get_databricks_host_creds(self.target_uri),
            endpoint=posixpath.join(prefix, "serving-endpoints", route or ""),
            method=method,
            timeout=MLFLOW_HTTP_REQUEST_TIMEOUT.get() if timeout is None else timeout,
            raise_on_status=False,
            retry_codes=MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES,
            extra_headers={"X-Databricks-Endpoints-API-Client": "Databricks Deployment Client"},
            **call_kwargs,
        )
        try:
            augmented_raise_for_status(response)
        except Exception:
            # Release the underlying connection before propagating the error; this matters
            # for streamed responses whose body would otherwise never be consumed.
            response.close()
            raise
        return response

    def _call_endpoint(
        self,
        *,
        method: str,
        prefix: str = "/api/2.0",
        route: Optional[str] = None,
        json_body: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ):
        """
        Call the serving-endpoints API and return the JSON response as a ``DatabricksEndpoint``.
        """
        response = self._send_request(
            method=method,
            prefix=prefix,
            route=route,
            json_body=json_body,
            timeout=timeout,
        )
        return DatabricksEndpoint(response.json())

    @staticmethod
    def _iter_nonempty_lines(response) -> Iterator[str]:
        """
        Yield stripped, non-empty lines from a streamed response, closing it when done.

        The ``finally`` clause also runs when the consumer stops early (the generator is
        closed or garbage-collected), so the connection is always released.
        """
        try:
            for line in response.iter_lines(decode_unicode=True):
                # Blank lines are keep-alives / event separators; skip them.
                if stripped := line.strip():
                    yield stripped
        finally:
            response.close()

    def _call_endpoint_stream(
        self,
        *,
        method: str,
        prefix: str = "/api/2.0",
        route: Optional[str] = None,
        json_body: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Iterator[str]:
        """
        Call the serving-endpoints API with a streamed response.

        The request is sent immediately. The returned iterator yields each non-empty line
        of the response body, stripped of surrounding whitespace. Each line's format depends
        on the specific endpoint.
        """
        response = self._send_request(
            method=method,
            prefix=prefix,
            route=route,
            json_body=json_body,
            timeout=timeout,
            stream=True,
        )
        # Server-sent event streams are UTF-8 by specification. Without this, requests falls
        # back to ISO-8859-1 for text/* responses that declare no charset, which garbles any
        # non-ASCII model output.
        response.encoding = "utf-8"
        return self._iter_nonempty_lines(response)

    @experimental
    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        """
        Query a serving endpoint with the provided model inputs.

        See https://docs.databricks.com/api/workspace/servingendpoints/query for request/response
        schema.

        Args:
            deployment_name: Unused.
            inputs: A dictionary containing the model inputs to query.
            endpoint: The name of the serving endpoint to query.

        Returns:
            A :py:class:`DatabricksEndpoint` object containing the query response.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            response = client.predict(
                endpoint="chat",
                inputs={
                    "messages": [
                        {"role": "user", "content": "Hello!"},
                    ],
                },
            )
            assert response == {
                "id": "chatcmpl-8OLm5kfqBAJD8CpsMANESWKpLSLXY",
                "object": "chat.completion",
                "created": 1700814265,
                "model": "gpt-4-0613",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello! How can I assist you today?",
                        },
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": 9,
                    "completion_tokens": 9,
                    "total_tokens": 18,
                },
            }
        """
        return self._call_endpoint(
            method="POST",
            prefix="/",
            route=posixpath.join(endpoint, "invocations"),
            json_body=inputs,
            timeout=MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT.get(),
        )

    @experimental
    def predict_stream(
        self, deployment_name=None, inputs=None, endpoint=None
    ) -> Iterator[Dict[str, Any]]:
        """
        Submit a query to a configured provider endpoint and get a streaming response.

        Args:
            deployment_name: Unused.
            inputs: The inputs to the query, as a dictionary.
            endpoint: The name of the endpoint to query.

        Returns:
            An iterator of dictionaries, each parsed from one ``data: <json>`` line of the
            response. Iteration stops at the ``data: [DONE]`` line or when the stream ends.

        Raises:
            MlflowException: If a non-comment line is not of the form ``data: <value>``.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            chunk_iter = client.predict_stream(
                endpoint="databricks-llama-2-70b-chat",
                inputs={
                    "messages": [{"role": "user", "content": "Hello!"}],
                    "temperature": 0.0,
                    "n": 1,
                    "max_tokens": 500,
                },
            )
            for chunk in chunk_iter:
                print(chunk)
                # Example:
                # {
                #     "id": "82a834f5-089d-4fc0-ad6c-db5c7d6a6129",
                #     "object": "chat.completion.chunk",
                #     "created": 1712133837,
                #     "model": "llama-2-70b-chat-030424",
                #     "choices": [
                #         {
                #             "index": 0, "delta": {"role": "assistant", "content": "Hello"},
                #             "finish_reason": None,
                #         }
                #     ],
                #     "usage": {"prompt_tokens": 11, "completion_tokens": 1, "total_tokens": 12},
                # }
        """
        inputs = inputs or {}
        # Add stream=True param in request body to get streaming response
        # See https://docs.databricks.com/api/workspace/servingendpoints/query#stream
        chunk_line_iter = self._call_endpoint_stream(
            method="POST",
            prefix="/",
            route=posixpath.join(endpoint, "invocations"),
            json_body={**inputs, "stream": True},
            timeout=MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT.get(),
        )
        try:
            for line in chunk_line_iter:
                if line.startswith(":"):
                    # Lines starting with ":" are SSE comments (e.g. keep-alive pings).
                    continue
                key, sep, value = line.partition(":")
                if not sep:
                    raise MlflowException(
                        f"Unknown response format: '{line}', "
                        "expected 'data: <value>' for streaming response."
                    )
                if key != "data":
                    raise MlflowException(
                        f"Unknown response format with key '{key}'. "
                        f"Expected 'data: <value>' for streaming response, got '{line}'."
                    )
                value = value.strip()
                if value == "[DONE]":
                    # Databricks endpoint streaming response ends with
                    # a line of "data: [DONE]"
                    return
                yield json.loads(value)
        finally:
            # Release the HTTP connection promptly on [DONE], on errors, or when the
            # caller stops iterating early.
            chunk_line_iter.close()

    @experimental
    def create_endpoint(self, name, config=None):
        """
        Create a new serving endpoint with the provided name and configuration.

        See https://docs.databricks.com/api/workspace/servingendpoints/create for request/response
        schema.

        Args:
            name: The name of the serving endpoint to create.
            config: A dictionary containing the configuration of the serving endpoint to create.
                Non-empty ``tags`` and ``rate_limits`` entries are moved out of ``config`` to the
                top level of the request payload. The caller's dictionary is not modified.

        Returns:
            A :py:class:`DatabricksEndpoint` object containing the request response.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            endpoint = client.create_endpoint(
                name="chat",
                config={
                    "served_entities": [
                        {
                            "name": "test",
                            "external_model": {
                                "name": "gpt-4",
                                "provider": "openai",
                                "task": "llm/v1/chat",
                                "openai_config": {
                                    "openai_api_key": "{{secrets/scope/key}}",
                                },
                            },
                        }
                    ],
                },
            )
            assert endpoint == {
                "name": "chat",
                "creator": "alice@company.com",
                "creation_timestamp": 0,
                "last_updated_timestamp": 0,
                "state": {...},
                "config": {...},
                "tags": [...],
                "id": "88fd3f75a0d24b0380ddc40484d7a31b",
            }
        """
        config = config.copy() if config else {}  # avoid mutating config
        extras = {}
        for key in ("tags", "rate_limits"):
            if value := config.pop(key, None):
                extras[key] = value
        payload = {"name": name, "config": config, **extras}
        return self._call_endpoint(method="POST", json_body=payload)

    @experimental
    def update_endpoint(self, endpoint, config=None):
        """
        Update a specified serving endpoint with the provided configuration.

        See https://docs.databricks.com/api/workspace/servingendpoints/updateconfig for
        request/response schema.

        If ``config`` contains only the ``rate_limits`` key, the request is sent to the
        endpoint's ``rate-limits`` route. Otherwise it is sent to the ``config`` route.
        ``config=None`` is treated as an empty dictionary.

        Args:
            endpoint: The name of the serving endpoint to update.
            config: A dictionary containing the configuration of the serving endpoint to update.

        Returns:
            A :py:class:`DatabricksEndpoint` object containing the request response.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            endpoint = client.update_endpoint(
                endpoint="chat",
                config={
                    "served_entities": [
                        {
                            "name": "test",
                            "external_model": {
                                "name": "gpt-4",
                                "provider": "openai",
                                "task": "llm/v1/chat",
                                "openai_config": {
                                    "openai_api_key": "{{secrets/scope/key}}",
                                },
                            },
                        }
                    ],
                },
            )
            assert endpoint == {
                "name": "chat",
                "creator": "alice@company.com",
                "creation_timestamp": 0,
                "last_updated_timestamp": 0,
                "state": {...},
                "config": {...},
                "tags": [...],
                "id": "88fd3f75a0d24b0380ddc40484d7a31b",
            }

            rate_limits = client.update_endpoint(
                endpoint="chat",
                config={
                    "rate_limits": [
                        {
                            "key": "user",
                            "renewal_period": "minute",
                            "calls": 10,
                        }
                    ],
                },
            )
            assert rate_limits == {
                "rate_limits": [
                    {
                        "key": "user",
                        "renewal_period": "minute",
                        "calls": 10,
                    }
                ],
            }
        """
        # Treat a missing config as empty (consistent with create_endpoint); previously
        # list(None) raised a TypeError.
        config = config or {}
        sub_route = "rate-limits" if list(config) == ["rate_limits"] else "config"
        return self._call_endpoint(
            method="PUT", route=posixpath.join(endpoint, sub_route), json_body=config
        )

    @experimental
    def delete_endpoint(self, endpoint):
        """
        Delete a specified serving endpoint.

        See https://docs.databricks.com/api/workspace/servingendpoints/delete for request/response
        schema.

        Args:
            endpoint: The name of the serving endpoint to delete.

        Returns:
            A :py:class:`DatabricksEndpoint` object containing the request response.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            client.delete_endpoint(endpoint="chat")
        """
        return self._call_endpoint(method="DELETE", route=endpoint)

    @experimental
    def list_endpoints(self):
        """
        Retrieve all serving endpoints.

        See https://docs.databricks.com/api/workspace/servingendpoints/list for request/response
        schema.

        Returns:
            A list of :py:class:`DatabricksEndpoint` objects containing the request response.
            The list is empty if the response contains no ``endpoints`` key.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            endpoints = client.list_endpoints()
            assert endpoints == [
                {
                    "name": "chat",
                    "creator": "alice@company.com",
                    "creation_timestamp": 0,
                    "last_updated_timestamp": 0,
                    "state": {...},
                    "config": {...},
                    "tags": [...],
                    "id": "88fd3f75a0d24b0380ddc40484d7a31b",
                },
            ]
        """
        response = self._call_endpoint(method="GET")
        # The list response may omit the "endpoints" key (e.g. when the workspace has no
        # endpoints); treat that as an empty list instead of failing on attribute access.
        return [DatabricksEndpoint(item) for item in response.get("endpoints", [])]

    @experimental
    def get_endpoint(self, endpoint):
        """
        Get a specified serving endpoint.

        See https://docs.databricks.com/api/workspace/servingendpoints/get for request/response
        schema.

        Args:
            endpoint: The name of the serving endpoint to get.

        Returns:
            A :py:class:`DatabricksEndpoint` object containing the request response.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("databricks")
            endpoint = client.get_endpoint(endpoint="chat")
            assert endpoint == {
                "name": "chat",
                "creator": "alice@company.com",
                "creation_timestamp": 0,
                "last_updated_timestamp": 0,
                "state": {...},
                "config": {...},
                "tags": [...],
                "id": "88fd3f75a0d24b0380ddc40484d7a31b",
            }
        """
        return self._call_endpoint(method="GET", route=endpoint)
def run_local(name, model_uri, flavor=None, config=None):
    """
    Placeholder ``run_local`` entry point for the Databricks deployment target.

    All arguments are accepted and ignored; the function performs no work and returns ``None``.
    """
def target_help():
    """
    Placeholder ``target_help`` entry point for the Databricks deployment target.

    The function performs no work and returns ``None``.
    """