This gist sets out proposals for modifying the interfaces to the MLflow stores and artifact repositories. Currently, different implementations of these require different arguments, and the proposals below attempt to make them more consistent.
These changes are motivated in particular by a proposed new plugin system for third party packages to provide implementations of tracking stores and artifact repositories.
There are currently three implementations of tracking store, with the following interfaces:
- `FileStore(root_directory, artifact_root_uri)`
- `SqlAlchemyStore(db_uri, default_artifact_root)`
- `RestStore(get_host_creds)`
In each case, the first argument specifies where tracking information is stored. The second argument to `FileStore` and `SqlAlchemyStore` sets the root directory within which artifacts will be stored; this can be overridden when creating an experiment.
Currently, `mlflow.tracking.utils._get_store()` is responsible for mapping a store URI to an instance of one of the above stores. This includes mapping the scheme from the URI to the relevant store class, plus generating the correct inputs to be passed when instantiating the class.
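As a rough illustration of the resulting asymmetry, the dispatch has to construct different arguments for each backend. The following is a simplified sketch, not the actual implementation; `_DATABASE_SCHEMES` and `_get_store_sketch` are placeholder names:

from six.moves import urllib

from mlflow.store.file_store import FileStore
from mlflow.store.rest_store import RestStore
from mlflow.store.sqlalchemy_store import SqlAlchemyStore
from mlflow.utils import rest_utils

# Placeholder list of SQLAlchemy-backed schemes, for illustration only
_DATABASE_SCHEMES = ('postgresql', 'mysql', 'sqlite', 'mssql')


def _get_store_sketch(store_uri, artifact_root=None):
    scheme = urllib.parse.urlparse(store_uri).scheme
    if scheme in ('', 'file'):
        # FileStore takes a local directory plus an artifact root
        return FileStore(store_uri, artifact_root)
    if scheme in _DATABASE_SCHEMES:
        # SqlAlchemyStore takes a database URI plus a default artifact root
        return SqlAlchemyStore(store_uri, artifact_root)
    # RestStore takes a zero-argument callable returning host credentials
    return RestStore(lambda: rest_utils.MlflowHostCreds(host=store_uri))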
There are currently seven implementations of artifact repository, with the following interfaces:
- `LocalArtifactRepository(artifact_uri)`
- `S3ArtifactRepository(artifact_uri)`
- `GCSArtifactRepository(artifact_uri)`
- `AzureBlobArtifactRepository(artifact_uri)`
- `FTPArtifactRepository(artifact_uri)`
- `SFTPArtifactRepository(artifact_uri)`
- `DbfsArtifactRepository(artifact_uri, get_host_creds)`
All artifact repositories take a URI to identify the location of stored artifacts, with `DbfsArtifactRepository` requiring an additional `get_host_creds`, which is taken from an associated `RestStore` tracking store.
Currently, `mlflow.store.artifact_repo.ArtifactRepository.from_artifact_uri` is responsible for mapping an artifact URI to one of the above repositories. In order to support building `DbfsArtifactRepository`, this method requires the caller to pass a tracking store instance, which must be a `RestStore` when building a `DbfsArtifactRepository` and is ignored otherwise.
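Schematically, the coupling looks like the following simplified sketch; this is not the actual code (the real method handles all of the schemes listed above), but only the DBFS branch ever uses the store argument:

from mlflow.exceptions import MlflowException
from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
from mlflow.store.local_artifact_repo import LocalArtifactRepository
from mlflow.store.rest_store import RestStore


def from_artifact_uri_sketch(artifact_uri, store):
    if artifact_uri.startswith("dbfs:/"):
        if not isinstance(store, RestStore):
            raise MlflowException("DBFS artifact locations require a RestStore tracking store")
        # The repository borrows its credential loader from the tracking store
        return DbfsArtifactRepository(artifact_uri, store.get_host_creds)
    # For every other scheme the store argument is ignored
    return LocalArtifactRepository(artifact_uri)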
We are proposing to make the following changes:
1. Have the `RestStore` and `DbfsArtifactRepository` take a URI as their sole argument, and do the conversion of URI to host credentials internally, using a common helper function.
2. Remove the `store` argument from `mlflow.store.artifact_repo.ArtifactRepository.from_artifact_uri`.
3. Require all implementations of tracking store and artifact repository to accept arbitrary keyword arguments (using `**kwargs`), which would be passed through by `mlflow.tracking.utils._get_store`.
4. Make the `default_artifact_root` argument to the `SqlAlchemyStore` optional, and rename the optional `artifact_root_uri` argument to `FileStore` to `default_artifact_root`.
This would provide the following advantages:
- The interface for all tracking stores and artifact repositories will be the same: `StoreOrRepo(uri, **kwargs)`.
- As a result, `mlflow.tracking.utils._get_store` will not need to know different logic for building the arguments to different tracking store and artifact repository implementations (see the sketch after this list).
- The implementation of artifact repositories will be decoupled from tracking stores.
- Forwards compatibility will be provided for store implementations that require inputs other than the URI.
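With every implementation exposing the same `(uri, **kwargs)` constructor, the dispatch in `mlflow.tracking.utils._get_store` reduces to a scheme-to-class lookup. The sketch below is illustrative only; `_STORE_REGISTRY` and `_get_store_sketch` are not existing MLflow names, and a registry of this shape is exactly what a plugin system could extend with third-party classes:

from six.moves import urllib

from mlflow.store.file_store import FileStore
from mlflow.store.rest_store import RestStore
from mlflow.store.sqlalchemy_store import SqlAlchemyStore

# Illustrative scheme-to-class mapping; a plugin system could register
# additional entries for third-party store implementations.
_STORE_REGISTRY = {
    '': FileStore,
    'file': FileStore,
    'sqlite': SqlAlchemyStore,
    'postgresql': SqlAlchemyStore,
    'mysql': SqlAlchemyStore,
    'http': RestStore,
    'https': RestStore,
    # 'databricks': DatabricksRestStore,  # as proposed below
}


def _get_store_sketch(store_uri, **kwargs):
    scheme = urllib.parse.urlparse(store_uri).scheme
    store_cls = _STORE_REGISTRY[scheme]
    # Every store accepts (uri, **kwargs), so no per-backend argument
    # construction is needed here.
    return store_cls(store_uri, **kwargs)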
The proposals above require the tracking store implementations to be responsible for loading their own authentication credentials. To enable this, we suggest separating the `RestStore` into two classes, one corresponding to the `http`/`https` schemes and one corresponding to the `databricks` scheme. These would likely be subclasses of a common parent which contains the REST logic, with the subclasses only implementing the loading of host credentials. Below is an illustrative example of how this could be done:
import os
from abc import abstractmethod

from six.moves import urllib

from mlflow.store.abstract_store import AbstractStore
from mlflow.utils import rest_utils
from mlflow.utils.databricks_utils import get_databricks_host_creds
# _TRACKING_*_ENV_VAR constants as currently defined in mlflow.tracking.utils
from mlflow.tracking.utils import (
    _TRACKING_USERNAME_ENV_VAR,
    _TRACKING_PASSWORD_ENV_VAR,
    _TRACKING_TOKEN_ENV_VAR,
    _TRACKING_INSECURE_TLS_ENV_VAR,
)


class AbstractRestStore(AbstractStore):
    def __init__(self, store_uri):
        super(AbstractRestStore, self).__init__()
        self.store_uri = store_uri

    @abstractmethod
    def _get_host_creds(self):
        pass

    # Current RestStore functionality goes here


class RestStore(AbstractRestStore):
    def _get_host_creds(self):
        # Currently in mlflow.tracking.utils._get_rest_store
        return rest_utils.MlflowHostCreds(
            host=self.store_uri,
            username=os.environ.get(_TRACKING_USERNAME_ENV_VAR),
            password=os.environ.get(_TRACKING_PASSWORD_ENV_VAR),
            token=os.environ.get(_TRACKING_TOKEN_ENV_VAR),
            ignore_tls_verification=os.environ.get(
                _TRACKING_INSECURE_TLS_ENV_VAR
            ) == 'true',
        )


class DatabricksRestStore(AbstractRestStore):
    def _get_host_creds(self):
        # Get the Databricks profile specified by the tracking URI
        parsed_uri = urllib.parse.urlparse(self.store_uri)
        profile = parsed_uri.netloc
        return get_databricks_host_creds(profile)
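Usage would then look like this, with the credential-loading strategy determined entirely by the class and hence by the URI scheme (the host and profile names below are placeholders):

# HTTP(S) URIs: credentials come from the MLFLOW_TRACKING_* environment variables
rest_store = RestStore("https://my-tracking-server:5000")

# databricks:// URIs: credentials come from the named Databricks CLI profile
databricks_store = DatabricksRestStore("databricks://my-profile")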
By making `DbfsArtifactRepository` responsible for looking up host credentials itself, we can make its interface consistent with the other artifact repository implementations:
import os

from mlflow.exceptions import MlflowException
from mlflow.store.artifact_repo import ArtifactRepository
from mlflow.utils import rest_utils
from mlflow.utils.databricks_utils import get_databricks_host_creds

# Copied from mlflow.tracking.utils for demonstration:
_TRACKING_URI_ENV_VAR = "MLFLOW_TRACKING_URI"
_TRACKING_USERNAME_ENV_VAR = "MLFLOW_TRACKING_USERNAME"
_TRACKING_PASSWORD_ENV_VAR = "MLFLOW_TRACKING_PASSWORD"
_TRACKING_TOKEN_ENV_VAR = "MLFLOW_TRACKING_TOKEN"
_TRACKING_INSECURE_TLS_ENV_VAR = "MLFLOW_TRACKING_INSECURE_TLS"


class DbfsArtifactRepository(ArtifactRepository):
    def __init__(self, artifact_uri):
        # get_host_creds argument removed
        cleaned_artifact_uri = artifact_uri.rstrip('/')
        super(DbfsArtifactRepository, self).__init__(cleaned_artifact_uri)
        if not cleaned_artifact_uri.startswith('dbfs:/'):
            raise MlflowException(
                'DbfsArtifactRepository URI must start with dbfs:/'
            )

    def _get_host_creds(self):
        tracking_uri_from_environment = os.environ.get(_TRACKING_URI_ENV_VAR)
        if tracking_uri_from_environment is not None:
            # This enables interoperability with the Java client, which passes
            # this information in the environment
            host_creds = rest_utils.MlflowHostCreds(
                host=tracking_uri_from_environment,
                username=os.environ.get(_TRACKING_USERNAME_ENV_VAR),
                password=os.environ.get(_TRACKING_PASSWORD_ENV_VAR),
                token=os.environ.get(_TRACKING_TOKEN_ENV_VAR),
                ignore_tls_verification=os.environ.get(
                    _TRACKING_INSECURE_TLS_ENV_VAR
                ) == 'true',
            )
        else:
            # DBFS URIs do not currently contain a Databricks profile, so load
            # the default profile
            host_creds = get_databricks_host_creds()
        return host_creds

    # Current DbfsArtifactRepository functionality goes here, with calls to
    # self.get_host_creds replaced with calls to self._get_host_creds
Note: The above implementation reimplements logic currently in `mlflow.tracking.utils._get_rest_store`, but it's intended that some refactoring would place repeated logic in a common location.

This will then permit the removal of the `store` argument from `mlflow.store.artifact_repo.ArtifactRepository.from_artifact_uri`:
class ArtifactRepository:
    # ...

    @staticmethod
    def from_artifact_uri(artifact_uri):
        if artifact_uri.startswith("s3:/"):
            # Import these locally to avoid creating a circular import loop
            from mlflow.store.s3_artifact_repo import S3ArtifactRepository
            return S3ArtifactRepository(artifact_uri)
        elif artifact_uri.startswith("gs:/"):
            from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
            return GCSArtifactRepository(artifact_uri)
        elif artifact_uri.startswith("wasbs:/"):
            from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
            return AzureBlobArtifactRepository(artifact_uri)
        elif artifact_uri.startswith("ftp:/"):
            from mlflow.store.ftp_artifact_repo import FTPArtifactRepository
            return FTPArtifactRepository(artifact_uri)
        elif artifact_uri.startswith("sftp:/"):
            from mlflow.store.sftp_artifact_repo import SFTPArtifactRepository
            return SFTPArtifactRepository(artifact_uri)
        elif artifact_uri.startswith("dbfs:/"):
            from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
            return DbfsArtifactRepository(artifact_uri)
        else:
            from mlflow.store.local_artifact_repo import LocalArtifactRepository
            return LocalArtifactRepository(artifact_uri)
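Callers would then build artifact repositories from the URI alone, even for DBFS locations (the paths below are placeholders):

# No tracking store instance needs to be passed, even for dbfs:/ locations
repo = ArtifactRepository.from_artifact_uri("dbfs:/my/artifact/root")
repo.log_artifact("/tmp/model.pkl")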
Proposal 3 above provides a mechanism for tracking store and artifact repository implementations to allow additional customisation options that do not need to be supported by all implementations, and proposal 4 uses this new mechanism to allow callers to customise the default artifact location in the `FileStore` and `SqlAlchemyStore` implementations.
The abstract base class for artifact repositories would have its `__init__` updated accordingly:
from abc import ABCMeta


class ArtifactRepository:
    __metaclass__ = ABCMeta

    def __init__(self, artifact_uri, **_):
        # **_ used instead of **kwargs to keep linters happy with these not
        # being used
        self.artifact_uri = artifact_uri
And all current implementations of both tracking stores and artifact repositories would be updated to accept (and ignore) any passed extra keyword arguments. For example, in the FTP artifact repository:
from six.moves import urllib

from mlflow.store.artifact_repo import ArtifactRepository


class FTPArtifactRepository(ArtifactRepository):
    def __init__(self, artifact_uri, **_):
        self.uri = artifact_uri
        parsed = urllib.parse.urlparse(artifact_uri)
        self.config = {
            'host': parsed.hostname,
            'port': 21 if parsed.port is None else parsed.port,
            'username': parsed.username,
            'password': parsed.password
        }
        self.path = parsed.path
        if self.config['host'] is None:
            self.config['host'] = 'localhost'
        super(FTPArtifactRepository, self).__init__(artifact_uri)
The interface for building a `FileStore` would look like:
class FileStore(AbstractStore):
    def __init__(self, store_uri, default_artifact_root=None, **_):
        # Existing implementation with variable names updated, and existing
        # logic for the case that default_artifact_root is None retained
And the `SqlAlchemyStore` would look like:
from mlflow.store import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH


class SqlAlchemyStore(AbstractStore):
    def __init__(self, store_uri, default_artifact_root=None, **_):
        # Implement the same default as in mlflow.tracking.utils._get_store
        # currently:
        if default_artifact_root is None:
            default_artifact_root = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
        # Existing implementation with variable names updated goes here
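Both stores could then be constructed in the same way, with the artifact root supplied only when the caller wants to override the default (the URIs and paths below are placeholders):

# Falls back to DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
sql_store = SqlAlchemyStore("sqlite:///mlflow.db")

# Overrides the default artifact location via the now-consistent keyword argument
file_store = FileStore("./mlruns", default_artifact_root="s3://my-bucket/artifacts")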
Note: In this implementation, both `mlflow.tracking.utils._get_store` and `mlflow.server.handlers._get_store` remain responsible for mapping environment variables to default artifact locations. They will read the same environment variables as they do at the moment, and pass the keyword argument `default_artifact_root` to all store implementations on construction, leaving each store responsible for using that information or not.
The main issue with this proposal is that it does not support DBFS artifact locations that require credentials other than those in the user's default Databricks profile. This is because DBFS URIs are context-dependent and do not contain information about which profile or credentials to use.
Ultimately, if we want to remove the existing coupling between the artifact repositories and tracking stores, the `DbfsArtifactRepository` will need to be responsible for mapping its passed DBFS URI to the relevant host credentials by itself. This decoupling brings strong advantages, as developers can reason about these parts of the system independently and avoid confusing situations such as not being able to use a `DbfsArtifactRepository` without having a `RestStore`-compatible tracking URI.
For the moment, our proposed change may be acceptable as-is if the context in which Databricks users are using MLflow configures the default profile and tracking URI for them. In this case, the default profile will likely be correct for both the REST tracking store and the DBFS artifact locations.
The 'best' long-term solution would be for the DBFS URIs returned by the tracking store to have enough information to make them universally unique, for instance by including the Databricks profile name or some hostname in the 'netloc' part of the URI, as is currently done in `databricks` scheme tracking URIs.
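For illustration, such a self-describing URI could carry the profile in its netloc, letting the artifact repository resolve credentials without any ambient configuration. The `dbfs://<profile>/<path>` form below is hypothetical and is not something MLflow produces today:

from six.moves import urllib

from mlflow.utils.databricks_utils import get_databricks_host_creds

# Hypothetical self-describing DBFS URI: dbfs://<profile>/<path>
parsed = urllib.parse.urlparse("dbfs://my-profile/databricks/mlflow/artifacts")
host_creds = get_databricks_host_creds(parsed.netloc)  # loads the named profile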
Is there a plan to update `mlflow.server.handlers._get_store` to call `mlflow.tracking.utils._get_store`, and optionally move the environment variables into this utils function? Currently the handlers call doesn't support plugins.