-
-
Save kellerza/8aad3952086b827a9f32516373df1623 to your computer and use it in GitHub Desktop.
"""AsyncIO based OAuth Authorization Code Flow using the Microsoft MSAL Python library. | |
The AsyncMSAL class contains more info to perform OAuth & get the required tokens. | |
Once you have the OAuth tokens stored in the session, you are free to make requests
(typically from inside an aiohttp server's request handler).
For more info on Authorization Code flow, refer to https://auth0.com/docs/flows/authorization-code-flow | |
""" | |
import asyncio | |
import json | |
from functools import partial, wraps | |
from aiohttp import web | |
from aiohttp.client import ClientSession, _RequestContextManager | |
from msal import ConfidentialClientApplication, SerializableTokenCache | |
# Optional settings object; when set, AsyncMSAL reads SP_APP_ID, SP_APP_PW
# and SP_AUTHORITY from it at class-definition time.
ENV = None

# HTTP verbs understood by AsyncMSAL.request()
HTTP_GET, HTTP_POST, HTTP_PUT, HTTP_PATCH, HTTP_DELETE = (
    "get",
    "post",
    "put",
    "patch",
    "delete",
)
HTTP_ALLOWED = [HTTP_GET, HTTP_POST, HTTP_PUT, HTTP_PATCH, HTTP_DELETE]

# Scopes requested during the auth code flow and silent token refresh
MY_SCOPE = ["User.Read", "User.Read.All"]
def async_wrap(func):
    """Turn a blocking callable into a coroutine function that runs in an executor.

    The returned coroutine function accepts the same positional and keyword
    arguments as ``func`` plus two keyword-only extras: ``loop`` (defaults to
    the current event loop) and ``executor`` (``None`` selects the loop's
    default executor).
    """

    @wraps(func)
    async def _runner(*args, loop=None, executor=None, **kwargs):
        target_loop = asyncio.get_event_loop() if loop is None else loop
        bound_call = partial(func, *args, **kwargs)
        return await target_loop.run_in_executor(executor, bound_call)

    return _runner
# Keys under which AsyncMSAL keeps its state in the aiohttp session:
# the serialized MSAL token cache, the pending auth-code-flow dict, and the
# signed-in user's "preferred_username" claim.
TOKEN_CACHE, FLOW_CACHE, USER_EMAIL = "token_cache", "flow_cache", "mail"
class AsyncMSAL:
    """
    AsyncIO based OAuth using the Microsoft Authentication Library (MSAL) for Python.

    Blocking MSAL functions are executed in the executor thread.
    Use until such time as MSAL Python gets a true async version...
    Tested with MSAL Python 1.13.0
    https://github.com/AzureAD/microsoft-authentication-library-for-python

    AsyncMSAL is based on the following example app
    https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76

    Use as follows:

    Get the tokens via oauth

    1. initiate_auth_code_flow
       https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow

       The caller is expected to:
         1. somehow store this content, typically inside the current session of the server,
         2. guide the end user (i.e. resource owner) to visit that auth_uri,
            typically with a redirect
         3. and then relay this dict and subsequent auth response to
            acquire_token_by_auth_code_flow().

       [1. and part of 3.] is stored by this class in the aiohttp_session

    2. acquire_token_by_auth_code_flow
       https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.acquire_token_by_auth_code_flow

    Now you are free to make requests (typically from an aiohttp server)

        session = await get_session(request)
        aiomsal = AsyncMSAL(session)
        async with aiomsal.get("https://graph.microsoft.com/v1.0/me") as res:
            res = await res.json()
    """

    # HTTP client shared by all instances; created lazily by request().
    aiohttp_session: ClientSession = None

    # Application credentials, read once from ENV when the class is defined.
    client_id = ENV.SP_APP_ID if ENV else None
    client_credential = ENV.SP_APP_PW if ENV else None
    authority = ENV.SP_AUTHORITY if ENV else None

    def __init__(self, session):
        """Create the application using the cache.

        Based on: https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76

        session: an aiohttp_session.Session object
        """
        self.session = session
        self._token_cache = SerializableTokenCache()
        # Restore a previously serialized token cache from the session, if any.
        if session and session.get(TOKEN_CACHE):
            self._token_cache.deserialize(session[TOKEN_CACHE])
        self.app = ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_credential,
            authority=self.authority,  # common/oauth2/v2.0/token'
            validate_authority=False,
            token_cache=self._token_cache,
        )

    def _save_token_cache(self):
        """Write the token cache back to the session if MSAL modified it."""
        if self._token_cache.has_state_changed:
            self.session[TOKEN_CACHE] = self._token_cache.serialize()

    def build_auth_code_flow(self, redirect_uri):
        """First step - start the flow.

        Clears any previous token cache and user from the session, stores the
        new flow dict under FLOW_CACHE and returns the ``auth_uri`` the user
        must be redirected to.

        Raises an exception if the instance was created without a session.
        """
        # Test for None explicitly: a freshly created mapping-style session is
        # empty and therefore falsy, but it is still a perfectly valid session.
        if self.session is None:
            raise ValueError("session required")
        self.session[TOKEN_CACHE] = None
        self.session[USER_EMAIL] = None
        # https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
        self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
            MY_SCOPE,
            redirect_uri=redirect_uri,
        )
        return res["auth_uri"]

    @async_wrap
    def async_acquire_token_by_auth_code_flow(self, auth_response):
        """Second step - acquire the token (runs in an executor thread).

        ``auth_response`` is the dict of query parameters received on the
        redirect URI. On success the token cache is saved to the session and
        the ``preferred_username`` claim is stored under USER_EMAIL.

        Raises KeyError if no flow was stored in the session (i.e.
        build_auth_code_flow was not called first) and web.HTTPUnauthorized
        (a web.HTTPException) if MSAL reports an error.
        """
        # Assume the flow is in the session (added by /login);
        # pop() raises KeyError if it is not.
        auth_code_flow = self.session.pop(FLOW_CACHE)
        result = self.app.acquire_token_by_auth_code_flow(auth_code_flow, auth_response)
        if "error" in result or "id_token_claims" not in result:
            # ``text`` must be a str, and the HTTPException base class has no
            # real status code, so raise a concrete subclass with the MSAL
            # result serialized.
            raise web.HTTPUnauthorized(text=json.dumps(result, default=str))
        self._save_token_cache()
        self.session[USER_EMAIL] = result.get("id_token_claims").get(
            "preferred_username"
        )

    @async_wrap
    def async_get_token(self):
        """Silently acquire a token for the first cached account.

        Runs in an executor thread. Returns whatever acquire_token_silent
        returns, or None when the token cache holds no account.
        """
        accounts = self.app.get_accounts()
        if accounts:
            result = self.app.acquire_token_silent(scopes=MY_SCOPE, account=accounts[0])
            self._save_token_cache()
            return result
        return None

    async def request(self, method, url, **kwargs):
        """Make a request to url using an oauth session

        :param str url: url to send request to
        :param str method: type of request (get/put/post/patch/delete)
        :param kwargs: extra params to send to the request api
        :return: Response of the request
        :rtype: aiohttp.Response
        :raises ValueError: if method is not one of HTTP_ALLOWED
        :raises web.HTTPUnauthorized: if no access token could be acquired
        """
        # Validate before doing any I/O; an assert would vanish under -O.
        if method not in HTTP_ALLOWED:
            raise ValueError("Method must be one of the allowed ones")

        http = self.aiohttp_session
        if http is None or http.closed:
            # (Re)create the shared client; a closed session cannot be reused.
            http = AsyncMSAL.aiohttp_session = ClientSession(trust_env=True)

        token = await self.async_get_token()
        if not token or "access_token" not in token:
            # No cached account, or the silent refresh failed: login is needed.
            raise web.HTTPUnauthorized(text="No access token available, login required")

        kwargs = kwargs.copy()
        # Ensure headers exist & make a copy so the caller's dict is untouched
        kwargs["headers"] = headers = dict(kwargs.get("headers", {}))
        headers["Authorization"] = "Bearer " + token["access_token"]

        if method == HTTP_GET:
            kwargs.setdefault("allow_redirects", True)
        elif method in [HTTP_POST, HTTP_PUT, HTTP_PATCH]:
            headers["Content-type"] = "application/json"
            if "data" in kwargs:
                kwargs["data"] = json.dumps(kwargs["data"])  # auto convert to json

        response = await http.request(method, url, **kwargs)
        return response

    def get(self, url, **kwargs):
        """GET request; the result is usable with ``async with``."""
        return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))

    def post(self, url, **kwargs):
        """POST request; the result is usable with ``async with``."""
        return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
"""async_msal example server.""" | |
from aiohttp import web | |
from aiohttp_session import get_session, new_session, setup | |
from aiohttp_session.cookie_storage import EncryptedCookieStorage | |
from .msal_async import AsyncMSAL | |
# Session key holding the local path to return to after a successful login
SESSION_REDIRECT = "session_redirect"

# Route table collecting the handlers below; registered on the app in main()
ROUTES = web.RouteTableDef()
@ROUTES.get("/user/info")
async def user_info(request):
    """Example route: return the MS Graph ``/me`` document as JSON."""
    msal_client = AsyncMSAL(await get_session(request))
    async with msal_client.get("https://graph.microsoft.com/v1.0/me") as response:
        profile = await response.json()
    return web.json_response(profile)
@ROUTES.get("/user/login/{redirect:.+$}")
async def user_login(request):
    """Start the user login.

    Creates a fresh session, remembers the path captured by the route's
    ``{redirect}`` segment under SESSION_REDIRECT and redirects the user to
    the Microsoft sign-in page.
    """
    session = await new_session(request)
    # The route variable is named "redirect"; looking it up under the session
    # key name would never match and the return path would be lost.
    session[SESSION_REDIRECT] = request.match_info.get("redirect", "")
    aiomsal = AsyncMSAL(session)
    redir = aiomsal.build_auth_code_flow(
        redirect_uri="https://mysite.com/user/authorized"
    )
    # Redirect user to sign in
    return web.HTTPFound(redir)
@ROUTES.get("/user/authorized")
async def user_authorized(request: web.Request):
    """Process return flow after login.

    Exchanges the auth response (the request's query parameters) for tokens,
    then redirects to the path stored under SESSION_REDIRECT, or to
    ``/user/info`` when none was stored. Any exception raised while
    acquiring the token is printed and re-raised.
    """
    session = await get_session(request)
    # build a plain dict from the aiohttp server request's url parameters
    auth_response = dict(request.rel_url.query.items())
    aiomsal = AsyncMSAL(session)
    try:
        await aiomsal.async_acquire_token_by_auth_code_flow(auth_response)
    except Exception as err:  # pylint: disable=broad-except
        print("<b>Could not get token</b> - async_acquire_token_by_auth_code_flow", err)
        raise
    # Redirect user to local site
    redirect = session.pop(SESSION_REDIRECT, "") or "/user/info"
    # Strip leading slashes/backslashes before prefixing a single "/":
    # "//host/..." is a protocol-relative URL and would send the browser to
    # another host (and turned the default into "//user/info").
    return web.HTTPFound("/" + redirect.lstrip("/\\"))
def main():
    """Main web server.

    Builds the aiohttp application with encrypted cookie sessions, registers
    ROUTES and runs the server until stopped.
    """
    import secrets

    app = web.Application()
    # EncryptedCookieStorage needs exactly 32 key bytes. Generate them at
    # start-up instead of hard-coding a secret in the source. Existing session
    # cookies therefore become unreadable after a restart; load a persistent
    # 32-byte key from your own configuration if sessions must survive.
    setup(app, EncryptedCookieStorage(secrets.token_bytes(32)))
    app.add_routes(ROUTES)
    web.run_app(app)


if __name__ == "__main__":
    main()
Use with aiohttp 3.6.x (not 3.7.x, as this breaks aiohttp_session)
Hi @kellerza What do you mean by it breaking aiohttp_session ? I see the "token_cache" disappears from the session right after the redirect from user_authorized() - but can't figure out why that happens
There is an issue with 3.7 where a new session is returned for every request. You will see it as token_cache
disappearing, but as a matter of fact, the session is completely unique on every request. (It's clear from the value of the cookie AIOHTTP_SESSION. You can use "preserve log" in most web browsers' debug tools.)
Just realized I use aiohttp 3.8 at the moment! see here why not aiohttp 3.7
I'm also on 3.8.1 but found that the cookie size exceeded 4k, which seemed to have caused the misbehaviour. Changing the aiohttp-session cookie storage to a key-based cookie, e.g. https://github.com/zhangkaizhao/aiohttp-session-file, fixed this
Ok, that is an entirely different issue then. Glad you found a solution. The standard aiohttp_session also has redis & memcached Storage options that work in a similar fashion (storing only a key in the cookie and the data on the server)
You can also use this library https://github.com/kellerza/aiohttp_msal
Use with aiohttp 3.6.x (not 3.7.x, as this breaks aiohttp-session)
EDIT 2022: You can use aiohttp >=3.8,<3.9