import logging
import os
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Generator, Union
from mlflow.environment_variables import MLFLOW_TRACKING_URI
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
from mlflow.store.tracking.file_store import FileStore
from mlflow.store.tracking.rest_store import RestStore
from mlflow.tracing.provider import reset_tracer_setup
from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry
from mlflow.utils.credentials import get_default_host_creds
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.file_utils import path_to_local_file_uri
from mlflow.utils.uri import _DATABRICKS_UNITY_CATALOG_SCHEME
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Tracking URI explicitly configured through set_tracking_uri(). While it is None,
# get_tracking_uri() falls back to the MLFLOW_TRACKING_URI environment variable and
# then to the default local path.
_tracking_uri = None
def is_tracking_uri_set():
    """Report whether a tracking URI has been configured.

    Returns:
        True when either the module-level tracking URI or the value read from
        ``MLFLOW_TRACKING_URI`` is truthy, False otherwise.
    """
    # The environment variable is only consulted when no truthy URI was set in-process.
    return bool(_tracking_uri or MLFLOW_TRACKING_URI.get())
def set_tracking_uri(uri: Union[str, Path]) -> None:
    """
    Set the tracking server URI. The currently active run (if there is one) is not
    affected; the new URI takes effect for successive runs.

    Args:
        uri:

            - An empty string, or a local file path, prefixed with ``file:/``. Data is stored
              locally at the provided file (or ``./mlruns`` if empty).
            - An HTTP URI like ``https://my-tracking-server:5000``.
            - A Databricks workspace, provided as the string "databricks" or, to use a Databricks
              CLI `profile <https://github.com/databricks/databricks-cli#installation>`_,
              "databricks://<profileName>".
            - A :py:class:`pathlib.Path` instance

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        mlflow.set_tracking_uri("file:///tmp/my_tracking")
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

    .. code-block:: text
        :caption: Output

        Current tracking uri: file:///tmp/my_tracking
    """
    global _tracking_uri

    if isinstance(uri, Path):
        # On Windows with Python 3.8 (https://bugs.python.org/issue38671), .resolve()
        # does not return an absolute path when the directory does not exist, so
        # .absolute() is applied first and .resolve() then cleans the path.
        absolute_path = uri.absolute()
        uri = absolute_path.resolve().as_uri()

    if _tracking_uri == uri:
        # Nothing changed: leave the tracer setup alone.
        return

    _tracking_uri = uri
    # The tracer provider uses the tracking URI to decide where traces are exported
    # and stores it as its own state, so it has to be reset explicitly whenever the
    # global tracking URI changes.
    reset_tracer_setup()
@contextmanager
def _use_tracking_uri(uri: str) -> Generator[None, None, None]:
    """Context manager that switches the tracking URI for the duration of a block.

    The URI that was in effect on entry is put back on exit, including when the
    body (or the switch itself) raises.

    Args:
        uri: The tracking URI to use inside the ``with`` block.
    """
    # Reading the module-level value needs no ``global`` declaration.
    previous_uri = _tracking_uri
    try:
        set_tracking_uri(uri)
        yield
    finally:
        set_tracking_uri(previous_uri)
def _resolve_tracking_uri(tracking_uri=None):
return tracking_uri or get_tracking_uri()
def get_tracking_uri() -> str:
    """Get the current tracking URI. This may differ from the tracking URI of the
    currently active run, because the tracking URI can be changed at any time via
    ``set_tracking_uri``.

    Returns:
        The tracking URI.

    .. code-block:: python

        import mlflow

        # Get the current tracking uri
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

    .. code-block:: text

        Current tracking uri: file:///.../mlruns
    """
    # 1. A URI set in-process wins, even when it is an empty string.
    if _tracking_uri is not None:
        return _tracking_uri

    # 2. Otherwise use a truthy value of the MLFLOW_TRACKING_URI environment variable.
    env_uri = MLFLOW_TRACKING_URI.get()
    if env_uri:
        return env_uri

    # 3. Fall back to the default local path, made absolute and converted to a file URI.
    default_path = os.path.abspath(DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH)
    return path_to_local_file_uri(default_path)
def _get_file_store(store_uri, **_):
    """Build a ``FileStore`` for ``store_uri``.

    The same URI is passed for both positional constructor arguments; any extra
    keyword arguments (such as an artifact URI) are accepted and ignored.
    """
    file_store = FileStore(store_uri, store_uri)
    return file_store
def _get_sqlalchemy_store(store_uri, artifact_uri):
    """Build a ``SqlAlchemyStore`` for ``store_uri``.

    Args:
        store_uri: URI of the backing database.
        artifact_uri: Artifact location; when ``None``, the default local
            file/artifact path is used instead.
    """
    # Imported at call time rather than at module import (presumably so the
    # SQLAlchemy dependency is only needed when this store is actually used).
    from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

    effective_artifact_uri = (
        DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH if artifact_uri is None else artifact_uri
    )
    return SqlAlchemyStore(store_uri, effective_artifact_uri)
def _get_rest_store(store_uri, **_):
    """Build a ``RestStore`` whose credentials callable is ``get_default_host_creds``
    bound to ``store_uri``. Extra keyword arguments are accepted and ignored.
    """
    # A zero-argument callable is handed to the store instead of resolved credentials.
    host_creds_provider = partial(get_default_host_creds, store_uri)
    return RestStore(host_creds_provider)
def _get_databricks_rest_store(store_uri, **_):
    """Build a ``RestStore`` whose credentials callable is ``get_databricks_host_creds``
    bound to ``store_uri``. Extra keyword arguments are accepted and ignored.
    """
    # A zero-argument callable is handed to the store instead of resolved credentials.
    host_creds_provider = partial(get_databricks_host_creds, store_uri)
    return RestStore(host_creds_provider)
def _get_databricks_uc_rest_store(store_uri, **_):
    """Reject Unity Catalog tracking URIs with an explanatory error.

    This builder never returns a store: it always raises, listing the schemes that
    are registered for tracking (all registered schemes except the Unity Catalog one).

    Args:
        store_uri: The Unity Catalog tracking URI that was requested.

    Raises:
        MlflowException: Always.
    """
    from mlflow.exceptions import MlflowException
    from mlflow.version import VERSION

    # Every scheme currently in the registry, minus the Unity Catalog scheme itself.
    usable_schemes = [
        registered
        for registered in _tracking_store_registry._registry
        if registered != _DATABRICKS_UNITY_CATALOG_SCHEME
    ]
    message = (
        f"Detected Unity Catalog tracking URI '{store_uri}'. "
        "Setting the tracking URI to a Unity Catalog backend is not supported in the current "
        f"version of the MLflow client ({VERSION}). "
        "Please specify a different tracking URI via mlflow.set_tracking_uri, with "
        "one of the supported schemes: "
        f"{usable_schemes}. If you're trying to access models in the Unity "
        "Catalog, please upgrade to the latest version of the MLflow Python "
        "client, then specify a Unity Catalog model registry URI via "
        f"mlflow.set_registry_uri('{_DATABRICKS_UNITY_CATALOG_SCHEME}') or "
        f"mlflow.set_registry_uri('{_DATABRICKS_UNITY_CATALOG_SCHEME}://profile_name'), where "
        "'profile_name' is the name of the Databricks CLI profile to use for "
        "authentication. Be sure to leave the tracking URI configured to use "
        "one of the supported schemes listed above."
    )
    raise MlflowException(message)
# Module-wide registry in which store builders are registered per URI scheme
# (see _register_tracking_stores and _register) and through which _get_store
# obtains stores.
_tracking_store_registry = TrackingStoreRegistry()
def _register_tracking_stores():
    """Register the built-in store builders, then the entrypoint-provided ones.

    Registration order: the empty scheme and ``file``, ``databricks``, the Unity
    Catalog scheme, ``http``/``https``, every scheme in ``DATABASE_ENGINES``, and
    finally whatever ``register_entrypoints()`` adds.
    """
    register = _tracking_store_registry.register

    fixed_builders = (
        ("", _get_file_store),
        ("file", _get_file_store),
        ("databricks", _get_databricks_rest_store),
        (_DATABRICKS_UNITY_CATALOG_SCHEME, _get_databricks_uc_rest_store),
        ("http", _get_rest_store),
        ("https", _get_rest_store),
    )
    for scheme, builder in fixed_builders:
        register(scheme, builder)

    for engine in DATABASE_ENGINES:
        register(engine, _get_sqlalchemy_store)

    # Entrypoints are registered last, after all built-in schemes.
    _tracking_store_registry.register_entrypoints()
def _register(scheme, builder):
    """Register ``builder`` for ``scheme`` in the module-wide tracking store registry.

    Args:
        scheme: URI scheme the builder is registered under.
        builder: Callable registered for that scheme.
    """
    _tracking_store_registry.register(scheme, builder)
# Populate the registry when this module is imported.
_register_tracking_stores()
def _get_store(store_uri=None, artifact_uri=None):
    """Obtain a tracking store from the module-wide registry.

    Args:
        store_uri: Optional store URI, forwarded unchanged to the registry.
        artifact_uri: Optional artifact URI, forwarded unchanged to the registry.

    Returns:
        Whatever the registry's ``get_store`` returns for these arguments.
    """
    tracking_store = _tracking_store_registry.get_store(store_uri, artifact_uri)
    return tracking_store
# Cache read by _get_artifact_repo, keyed by run ID. Nothing visible in this module
# writes to it -- presumably it is populated elsewhere; verify against callers.
_artifact_repos_cache = OrderedDict()
def _get_artifact_repo(run_id):
    """Look up the cached entry for ``run_id``.

    Returns:
        The value stored in ``_artifact_repos_cache`` under ``run_id``, or ``None``
        when there is no such entry.
    """
    cached_repo = _artifact_repos_cache.get(run_id, None)
    return cached_repo
# TODO(sueann): move to a projects utils module
def _get_git_url_if_present(uri):
"""Return the path git_uri#sub_directory if the URI passed is a local path that's part of
a Git repo, or returns the original URI otherwise.
Args:
uri: The expanded uri.
Returns:
The git_uri#sub_directory if the uri is part of a Git repo, otherwise return the original
uri.
"""
if "#" in uri:
# Already a URI in git repo format
return uri
try:
from git import GitCommandNotFound, InvalidGitRepositoryError, NoSuchPathError, Repo
except ImportError as e:
_logger.warning(
"Failed to import Git (the git executable is probably not on your PATH),"
" so Git SHA is not available. Error: %s",
e,
)
return uri
try:
# Check whether this is part of a git repo
repo = Repo(uri, search_parent_directories=True)
# Repo url
repo_url = f"file://{repo.working_tree_dir}"
# Sub directory
rlpath = uri.replace(repo.working_tree_dir, "")
if rlpath == "":
git_path = repo_url
elif rlpath[0] == "/":
git_path = repo_url + "#" + rlpath[1:]
else:
git_path = repo_url + "#" + rlpath
return git_path
except (InvalidGitRepositoryError, GitCommandNotFound, ValueError, NoSuchPathError):
return uri