Last active
November 30, 2022 16:32
A simple authentication proxy for mlflow implemented in Starlette
"""
****
CAUTION: This is the wrong approach and will not work for all requests!
Instead of parsing out the JSON data, the request body should simply be passed
through to mlflow unchanged. I will get a fix in soon.
****
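
As a rough sketch of the pass-through idea in the caution above (not the promised fix):
assuming the raw bytes come from Starlette's `request.body()` and are handed to httpx
through its `content=` argument, the forwarding arguments could be assembled as follows.
`build_forward_kwargs` is a hypothetical helper and is not part of this gist.

```python
def build_forward_kwargs(headers, body, drop=("content-length",)):
    # Build keyword arguments for forwarding a request body byte-for-byte.
    # Drop headers that the HTTP client should recompute for the upstream request.
    kept = {k: v for k, v in headers.items() if k.lower() not in drop}
    kwargs = {"headers": kept}
    # Only attach a body when there is one; the bytes are never parsed or re-encoded.
    if body:
        kwargs["content"] = body
    return kwargs


# Hypothetical use inside the proxy handler (not executed here):
#     raw = await request.body()
#     kwargs = build_forward_kwargs(request.headers, raw)
#     resp = await dispatch(url, request.method, **kwargs)

example = build_forward_kwargs(
    {"Content-Type": "text/plain", "Content-Length": "5"}, b"hello"
)
print(example)
```

Because the body is forwarded as opaque bytes, the same path would serve JSON and
non-JSON requests alike, which is the point of the caution.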

A simple proxy to put an authentication layer in front of mlflow.

Requires [Starlette](https://www.starlette.io/) and [httpx](https://www.python-httpx.org/).

mlflow does not provide authentication, and securing mlflow deployments is a common
need. There are some mlflow-specific things here, but this gist could
easily be adapted to other similar use cases.

I wrote most of this before I ran across the mlflow-easyauth project. That may be a better
choice, although I have not tried it. Glancing at that code, it is not clear to me how
mlflow-easyauth handles the UI, which does not have a mechanism for passing authentication
in requests, as far as I know. The approach here is to introduce login and logout
endpoints in the proxy along with a middleware. A session is created for the user
after checking the Basic authentication credentials. It is destroyed by visiting /logout.
Unfortunately, I do not see a way to add login and logout to the UI navigation.
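
To illustrate the credential check described above, here is a minimal sketch of parsing
a Basic `Authorization` header using only the standard library. `parse_basic_auth` is a
hypothetical name; returning `None` for malformed input is an assumption about how you
might want to treat bad headers, not something this gist prescribes.

```python
from base64 import b64decode, b64encode


def parse_basic_auth(header):
    # Return (username, password), or None when the header is not valid Basic auth.
    try:
        scheme, encoded = header.split(None, 1)
        if scheme.lower() != "basic":
            return None
        username, password = b64decode(encoded).decode().split(":", 1)
    except ValueError:
        # Covers a missing credential part, bad base64 padding, undecodable
        # bytes, and a missing colon (all of these are ValueError subclasses).
        return None
    return username, password


header = "Basic " + b64encode(b"user1:password1").decode()
print(parse_basic_auth(header))        # ('user1', 'password1')
print(parse_basic_auth("Bearer abc"))  # None
```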

If you are specifically using OAuth, you might consider something like
[OAuth2 Proxy](https://oauth2-proxy.github.io/oauth2-proxy/) for mlflow, although I
have not tried it myself, and as far as I can tell the mlflow clients do not have a
mechanism for passing credentials other than Basic auth headers, so you would probably
have to implement your own client.

The general assumption here is that SITE_ROOT is only internally accessible. The
mechanism for protecting that endpoint will be specific to your setup, but will
typically be something network-y like firewall rules, IP allow-lists, etc.

## Env variables

The following need to be set in the environment of your client for programmatic API access:

MLFLOW_TRACKING_USERNAME=username
MLFLOW_TRACKING_PASSWORD=password
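
If the client is a Python script, the same variables can be set from the script itself
before any mlflow calls are made. The sketch below assumes placeholder credentials and
assumes the proxy listens at the PROXY_ROOT address used in this file;
`expected_authorization_header` is a hypothetical helper that only shows which Basic
header those credentials correspond to.

```python
import os
from base64 import b64encode

# Placeholder credentials; substitute the values your auth backend accepts.
os.environ["MLFLOW_TRACKING_USERNAME"] = "user1"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "password1"
# Assumption: point the client at the proxy, not at the protected mlflow server.
os.environ["MLFLOW_TRACKING_URI"] = "http://localhost:5002"


def expected_authorization_header():
    # The Basic auth header that corresponds to the credentials set above.
    username = os.environ["MLFLOW_TRACKING_USERNAME"]
    password = os.environ["MLFLOW_TRACKING_PASSWORD"]
    pair = username + ":" + password
    return "Basic " + b64encode(pair.encode()).decode()


print(expected_authorization_header())
```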

Execute with:

```
uvicorn mlflowproxy:app --host 0.0.0.0 --port 5002
```

## Related resources

- [mlflow](https://mlflow.org/)
- [Starlette](https://www.starlette.io/)
- [mlflow-easyauth](https://github.com/soundsensing/mlflow-easyauth). Basic auth for mlflow
"""
import json
from base64 import b64decode

import httpx
from starlette.applications import Starlette
from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    SimpleUser,
    requires,
)
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import HTMLResponse, RedirectResponse, Response
from starlette.routing import Route

DEBUG = True
PROXY_ROOT = "http://localhost:5002"  # Set these for your proxy and site.
SITE_ROOT = "http://localhost:5000"  # See notes above about protecting your site.
NO_DATA_METHODS = ["GET", "HEAD"]
DATA_METHODS = [
    "POST",
    "PUT",
    "PATCH",
    "DELETE",
]  # mlflow sends data for delete requests
ALL_METHODS = NO_DATA_METHODS + DATA_METHODS
REMOVE_REQUEST_HEADERS = ["content-length"]
SECRET_KEY = "supersecretchangeme"  # used by session middleware
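

def decode_password(header):
    """Split a Basic auth header into [username, password]."""
    basic, cred = header.split()
    if basic.lower() != "basic":
        raise ValueError("Expected a Basic authorization header")
    return b64decode(cred).decode().split(":", 1)


def clean_request_headers(headers):
    """Drop headers that must be recomputed for the upstream request."""
    return {
        k: v for k, v in headers.items() if k.lower() not in REMOVE_REQUEST_HEADERS
    }


async def json_data(request):
    """Return the parsed JSON body, or None if the body is empty or not JSON.

    Non-JSON bodies are not forwarded; see the caution at the top of this file.
    """
    try:
        return await request.json()
    except (AttributeError, json.JSONDecodeError):
        return None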


# Probably don't do this in production
USERS = {
    "user1": "password1",
    "user2": "password2",
}


class DevAuthBackend(AuthenticationBackend):
    """For dev and testing: checks credentials against the hard-coded USERS dict."""

    async def authenticate(self, conn):
        """Note that Starlette's sessions send session data to the client, so the
        username is readable but not modifiable. You may want to use an alternative
        session middleware that stores session data on the server or otherwise obfuscate
        the user's identity in the session data.
        """
        if "username" in conn.session and conn.session["username"] in USERS:
            return AuthCredentials(["api_auth", "app_auth"]), SimpleUser(
                conn.session["username"]
            )
        if "authorization" not in conn.headers:
            return None
        try:
            username, password = decode_password(conn.headers["authorization"])
        except ValueError:
            # Malformed header: treat the request as unauthenticated.
            return None
        if password == USERS.get(username):
            conn.session["username"] = username
            return AuthCredentials(["api_auth", "app_auth"]), SimpleUser(username)
        return None


class AuthBackend(AuthenticationBackend):
    async def authenticate(self, request):
        """Implement this with your authentication requirements. Presumably, you would
        do something a bit more sophisticated here than a hard-coded dictionary.
        See: https://www.starlette.io/authentication/
        """
        raise NotImplementedError


# AUTH_BACKEND = AuthBackend  # Needs to be implemented
AUTH_BACKEND = DevAuthBackend  # For dev purposes only!


async def dispatch(url, method, **kwargs):
    async with httpx.AsyncClient() as client:
        return await client.request(url=url, method=method, **kwargs)


async def main(request):
    if not request.user.is_authenticated:
        return RedirectResponse(request.url_for("login"))
    if "app_auth" not in request.auth.scopes:
        return RedirectResponse(request.url_for("logout"))
    url = SITE_ROOT + str(request.url)[len(PROXY_ROOT) :]
    req_headers = clean_request_headers(request.headers)
    if request.method in NO_DATA_METHODS:
        resp = await dispatch(url, request.method, headers=req_headers)
    else:
        resp = await dispatch(
            url, request.method, headers=req_headers, json=await json_data(request)
        )
    # httpx has already decoded the body, so drop the headers that described the
    # upstream encoding and length; Starlette recomputes content-length.
    stale = ("content-encoding", "content-length", "transfer-encoding")
    resp_headers = {k: v for k, v in resp.headers.items() if k.lower() not in stale}
    return Response(
        resp.content,
        headers=resp_headers,
        media_type=resp.headers.get("content-type"),
        status_code=resp.status_code,
    )


@requires("api_auth")
async def api(request):
    return await main(request)


async def logout(request):
    request.session.clear()
    return RedirectResponse(request.url_for("login"))


async def login(request):
    if request.method == "POST":
        # The auth middleware has already checked the Basic credentials by now.
        if request.user.is_authenticated:
            return Response(status_code=200)
        else:
            return Response(status_code=403)
    if request.user.is_authenticated:
        return RedirectResponse(request.url_for("home"))
    return HTMLResponse(
        content="""<html>
<head>
</head>
<body>
<form id="loginForm">
  <input type="text" name="username" placeholder="username">
  <input type="password" name="password" placeholder="password">
  <input type="submit" value="Login">
</form>
<div id="message"></div>
<script>
loginForm.onsubmit = async (e) => {
  e.preventDefault();
  let data = new FormData(loginForm);
  let response = await fetch('/login', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Authorization': 'Basic ' + btoa(data.get("username") + ':' + data.get("password"))
    }
  });
  if (response.status == 200) {
    window.location.replace("/");
  } else if (response.status == 403) {
    message.innerText = "Incorrect username or password";
  } else {
    message.innerText = "Unknown error. Unable to log in.";
  }
};
</script>
</body>
</html>
""",
        status_code=200,
    )


"""
Routing here is specific to mlflow and is designed to allow for api-key access to the
API with a proper auth backend implementation above. For use cases other than mlflow,
if there is no API, remove the ajax-api and api routes and all requests will go to main.
mlflow appears to have multiple internal API roots, which are handled here.
"""
routes = [
    Route("/login", login, methods=["GET", "POST"], name="login"),
    Route("/logout", logout, methods=["GET"], name="logout"),
    Route("/ajax-api/{path:path}", api, methods=ALL_METHODS),
    Route("/api/{path:path}", api, methods=ALL_METHODS),
    Route("/{path:path}", main, methods=ALL_METHODS),
    Route("/", main, methods=ALL_METHODS, name="home"),
]

app = Starlette(debug=DEBUG, routes=routes)
app.add_middleware(AuthenticationMiddleware, backend=AUTH_BACKEND())
# See the Starlette SessionMiddleware docs: https://www.starlette.io/middleware/#sessionmiddleware
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, https_only=False)