Last active
March 27, 2024 16:32
-
-
Save kellerza/8aad3952086b827a9f32516373df1623 to your computer and use it in GitHub Desktop.
AsyncIO based OAuth Authorization Code Flow using the Microsoft MSAL Python library. Includes an aiohttp server example.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""AsyncIO based OAuth Authorization Code Flow using the Microsoft MSAL Python library. | |
The AsyncMSAL class contains more info on how to perform OAuth and get the required tokens.
Once the OAuth tokens are stored in the session, you are free to make requests
(typically from inside an aiohttp server's request handler).
For more info on the Authorization Code flow, refer to https://auth0.com/docs/flows/authorization-code-flow
""" | |
import asyncio | |
import json | |
from functools import partial, wraps | |
from aiohttp import web | |
from aiohttp.client import ClientSession, _RequestContextManager | |
from msal import ConfidentialClientApplication, SerializableTokenCache | |
# Optional settings object. When set, AsyncMSAL reads its SP_APP_ID,
# SP_APP_PW and SP_AUTHORITY attributes as the MSAL client configuration.
ENV = None

# HTTP verbs understood by AsyncMSAL.request().
HTTP_GET = "get"
HTTP_POST = "post"
HTTP_PUT = "put"
HTTP_PATCH = "patch"
HTTP_DELETE = "delete"
HTTP_ALLOWED = [
    HTTP_GET,
    HTTP_POST,
    HTTP_PUT,
    HTTP_PATCH,
    HTTP_DELETE,
]

# Permission scopes requested when starting the auth code flow and when
# acquiring tokens silently.
MY_SCOPE = ["User.Read", "User.Read.All"]
def async_wrap(func):
    """Wrap a blocking function so that it runs in an executor thread.

    The returned coroutine function accepts the same arguments as *func*,
    plus two optional keyword-only arguments that are consumed by the
    wrapper and NOT forwarded to *func*:

    loop: event loop to schedule on (defaults to the currently running loop).
    executor: executor to run *func* in (None uses the loop's default executor).

    Returns whatever *func* returns; exceptions raised by *func* propagate
    to the awaiting caller.
    """

    @wraps(func)
    async def run(*args, loop=None, executor=None, **kwargs):
        if loop is None:
            # A coroutine always executes inside a running loop; the asyncio
            # docs prefer get_running_loop() over get_event_loop() here.
            loop = asyncio.get_running_loop()
        pfunc = partial(func, *args, **kwargs)
        return await loop.run_in_executor(executor, pfunc)

    return run
# Keys under which AsyncMSAL keeps its state in the aiohttp_session session.
TOKEN_CACHE = "token_cache"  # serialized MSAL SerializableTokenCache
FLOW_CACHE = "flow_cache"  # flow dict returned by initiate_auth_code_flow()
USER_EMAIL = "mail"  # "preferred_username" claim from the ID token
class AsyncMSAL:
    """AsyncIO based OAuth using the Microsoft Authentication Library (MSAL) for Python.

    Blocking MSAL functions are executed in an executor thread.
    Use until such time as MSAL Python gets a true async version...
    Tested with MSAL Python 1.13.0
    https://github.com/AzureAD/microsoft-authentication-library-for-python

    AsyncMSAL is based on the following example app
    https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76

    Use as follows:

    Get the tokens via OAuth

    1. build_auth_code_flow() -- wraps initiate_auth_code_flow
       https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow

       The caller is expected to:
       1. somehow store this content, typically inside the current session of the server,
       2. guide the end user (i.e. resource owner) to visit that auth_uri,
          typically with a redirect,
       3. and then relay this dict and subsequent auth response to
          acquire_token_by_auth_code_flow().

       [1. and part of 3.] is stored by this class in the aiohttp_session.

    2. async_acquire_token_by_auth_code_flow() -- wraps acquire_token_by_auth_code_flow
       https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.acquire_token_by_auth_code_flow

    Now you are free to make requests (typically from an aiohttp server)

        session = await get_session(request)
        aiomsal = AsyncMSAL(session)
        async with aiomsal.get("https://graph.microsoft.com/v1.0/me") as res:
            res = await res.json()
    """

    # HTTP client shared by all instances; created lazily by request().
    aiohttp_session: ClientSession = None
    # MSAL client configuration, taken from the optional module-level ENV.
    client_id = ENV.SP_APP_ID if ENV else None
    client_credential = ENV.SP_APP_PW if ENV else None
    authority = ENV.SP_AUTHORITY if ENV else None

    def __init__(self, session):
        """Create the MSAL application, seeded from the token cache in *session*.

        Based on: https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76

        session: an aiohttp_session.Session object; needed for the auth code
            flow and for persisting the token cache.
        """
        self.session = session
        self._token_cache = SerializableTokenCache()
        # Load a previously serialized cache. Compare with None explicitly:
        # an empty aiohttp_session.Session is falsy.
        if session is not None and session.get(TOKEN_CACHE):
            self._token_cache.deserialize(session[TOKEN_CACHE])
        # NOTE(review): constructing ConfidentialClientApplication may perform
        # blocking network I/O (authority discovery) on the event loop.
        self.app = ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_credential,
            authority=self.authority,  # common/oauth2/v2.0/token'
            validate_authority=False,
            token_cache=self._token_cache,
        )

    def _save_token_cache(self):
        """Write the serialized token cache to the session if it changed."""
        if self._token_cache.has_state_changed:
            self.session[TOKEN_CACHE] = self._token_cache.serialize()

    def build_auth_code_flow(self, redirect_uri):
        """First step - start the auth code flow.

        Clears any token cache and user e-mail held in the session, stores the
        MSAL flow dict under FLOW_CACHE and returns the URL to send the user to.

        :param str redirect_uri: where the identity provider sends the user back
        :return: the authorization URL ("auth_uri")
        :raises ValueError: if no session was supplied
        """
        # A freshly created aiohttp_session.Session is empty and therefore
        # falsy, so only reject a genuinely missing session.
        if self.session is None:
            raise ValueError("session required")
        self.session[TOKEN_CACHE] = None
        self.session[USER_EMAIL] = None
        self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
            MY_SCOPE,
            redirect_uri=redirect_uri,
        )  # https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
        return res["auth_uri"]

    @async_wrap
    def async_acquire_token_by_auth_code_flow(self, auth_response):
        """Second step - exchange the auth response for tokens.

        :param dict auth_response: query parameters of the redirect back to
            redirect_uri, as a plain dict
        :raises KeyError: if build_auth_code_flow() was not called on this
            session first
        :raises web.HTTPBadRequest: if MSAL reports an error
        """
        # The flow dict was stored by build_auth_code_flow(); pop it so a
        # stale flow is not reused. Raises KeyError if it is missing.
        auth_code_flow = self.session.pop(FLOW_CACHE)
        result = self.app.acquire_token_by_auth_code_flow(auth_code_flow, auth_response)
        if "error" in result or "id_token_claims" not in result:
            # HTTPException's text must be a str (result is a dict). Only the
            # error description is echoed so no token material can leak.
            message = (
                result.get("error_description")
                or result.get("error")
                or "Could not acquire token"
            )
            raise web.HTTPBadRequest(text=str(message))
        self._save_token_cache()
        self.session[USER_EMAIL] = result["id_token_claims"].get("preferred_username")

    @async_wrap
    def async_get_token(self):
        """Silently acquire a token for the first account in the cache.

        :return: the MSAL result dict, or None when the cache holds no account
            or MSAL could not produce a token
        """
        accounts = self.app.get_accounts()
        if accounts:
            result = self.app.acquire_token_silent(scopes=MY_SCOPE, account=accounts[0])
            self._save_token_cache()
            return result
        return None

    async def request(self, method, url, **kwargs):
        """Make a request to url using an oauth session.

        :param str method: type of request (get/put/post/patch/delete)
        :param str url: url to send request to
        :param kwargs: extra params to send to the request api
        :return: Response of the request
        :rtype: aiohttp.ClientResponse
        :raises ValueError: if method is not in HTTP_ALLOWED
        :raises web.HTTPUnauthorized: if no access token is available
        """
        # Validate first (and without assert, which is stripped under -O).
        if method not in HTTP_ALLOWED:
            raise ValueError("Method must be one of the allowed ones")
        # Lazily (re)create the shared client; a closed session is unusable.
        if self.aiohttp_session is None or self.aiohttp_session.closed:
            AsyncMSAL.aiohttp_session = ClientSession(trust_env=True)

        token = await self.async_get_token()
        if not token or "access_token" not in token:
            raise web.HTTPUnauthorized(text="No access token available, please log in")

        # Copy the headers so the caller's mapping is never mutated; also
        # tolerate an explicit headers=None.
        kwargs["headers"] = headers = dict(kwargs.get("headers") or {})
        headers["Authorization"] = "Bearer " + token["access_token"]

        if method == HTTP_GET:
            kwargs.setdefault("allow_redirects", True)
        elif method in (HTTP_POST, HTTP_PUT, HTTP_PATCH):
            headers["Content-type"] = "application/json"
            if "data" in kwargs:
                kwargs["data"] = json.dumps(kwargs["data"])  # auto convert to json

        return await self.aiohttp_session.request(method, url, **kwargs)

    def get(self, url, **kwargs):
        """GET request; use as ``async with aiomsal.get(url) as res``."""
        return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))

    def post(self, url, **kwargs):
        """POST request; ``data`` is JSON-encoded automatically."""
        return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))

    def put(self, url, **kwargs):
        """PUT request; ``data`` is JSON-encoded automatically."""
        return _RequestContextManager(self.request(HTTP_PUT, url, **kwargs))

    def patch(self, url, **kwargs):
        """PATCH request; ``data`` is JSON-encoded automatically."""
        return _RequestContextManager(self.request(HTTP_PATCH, url, **kwargs))

    def delete(self, url, **kwargs):
        """DELETE request."""
        return _RequestContextManager(self.request(HTTP_DELETE, url, **kwargs))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""async_msal example server.""" | |
from aiohttp import web | |
from aiohttp_session import get_session, new_session, setup | |
from aiohttp_session.cookie_storage import EncryptedCookieStorage | |
from .msal_async import AsyncMSAL | |
# Route table populated by the handlers below and mounted in main().
ROUTES = web.RouteTableDef()

# Session key holding the local path to return to once login completes.
SESSION_REDIRECT = "session_redirect"
@ROUTES.get("/user/info")
async def user_info(request):
    """Return the signed-in user's Microsoft Graph profile (/me) as JSON."""
    aiomsal = AsyncMSAL(await get_session(request))
    async with aiomsal.get("https://graph.microsoft.com/v1.0/me") as response:
        profile = await response.json()
    return web.json_response(profile)
@ROUTES.get("/user/login/{redirect:.+$}")
async def user_login(request):
    """Start the user login.

    Remembers the local path captured by the ``redirect`` route variable in a
    new session, then redirects the browser to the Microsoft sign-in page.
    """
    session = await new_session(request)
    # The route variable is named "redirect"; it is always present because
    # the pattern requires at least one character. Strip leading slashes and
    # backslashes so the later f"/{redirect}" can never turn into a
    # protocol-relative URL such as //evil.example (open redirect).
    session[SESSION_REDIRECT] = request.match_info.get("redirect", "").lstrip("/\\")
    aiomsal = AsyncMSAL(session)
    redir = aiomsal.build_auth_code_flow(
        redirect_uri="https://mysite.com/user/authorized"
    )
    # Redirect user to sign in. aiohttp expects HTTP exceptions to be raised;
    # returning them is deprecated.
    raise web.HTTPFound(redir)
@ROUTES.get("/user/authorized")
async def user_authorized(request: web.Request):
    """Process the return flow after login.

    Exchanges the auth response for tokens (stored in the session by
    AsyncMSAL), then redirects to the path remembered by user_login,
    defaulting to /user/info.
    """
    session = await get_session(request)
    # build a plain dict from the aiohttp server request's url parameters
    auth_response = dict(request.rel_url.query.items())
    aiomsal = AsyncMSAL(session)
    try:
        await aiomsal.async_acquire_token_by_auth_code_flow(auth_response)
    except Exception as err:  # pylint: disable=broad-except
        print("<b>Could not get token</b> - async_acquire_token_by_auth_code_flow", err)
        raise
    # Redirect user to local site. The default must not start with "/":
    # f"/{...}" adds one, and "//user/info" would be a protocol-relative URL
    # to host "user". Leading slashes/backslashes are stripped for the same
    # reason (prevents an open redirect via the stored value).
    redirect = (session.pop(SESSION_REDIRECT, "") or "user/info").lstrip("/\\")
    raise web.HTTPFound(f"/{redirect}")
def main():
    """Run the example web server with encrypted cookie sessions."""
    import secrets

    app = web.Application()
    # EncryptedCookieStorage needs exactly 32 key bytes. A fresh random key
    # per process avoids shipping a hard-coded secret, but sessions will not
    # survive a restart -- load a persistent key from configuration in a
    # real deployment.
    setup(app, EncryptedCookieStorage(secrets.token_bytes(32)))
    app.add_routes(ROUTES)
    web.run_app(app)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Ok, that is an entirely different issue then. Glad you found a solution. The standard aiohttp_session also has Redis & Memcached storage options that work in a similar fashion (storing only a key in the cookie and the data on the server).