-
-
Save enginefeeder101/d97bc14d5a931366f9576fc4091835b8 to your computer and use it in GitHub Desktop.
Custom Home Assistant auth provider
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import jwt | |
from urllib.parse import urlparse | |
import urllib.request | |
from collections import OrderedDict | |
from typing import Any, Dict, Optional, cast | |
import voluptuous as vol | |
from homeassistant.exceptions import HomeAssistantError | |
from homeassistant.core import callback | |
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow | |
from ..models import Credentials, UserMeta | |
_LOGGER = logging.getLogger(__name__)

# Configuration keys understood by the ``access_token`` auth provider.
CONF_PUBLIC_KEY = "public_key"  # key passed to jwt.decode to verify the token
CONF_ALGORITHM = "algorithm"  # the single algorithm jwt.decode will accept
CONF_COOKIE_NAME = "cookie_name"  # cookie from which the token is read
CONF_REALM = "realm"  # used as the provider's display title
CONF_USERNAME_KEY = "username_key"  # claim that holds the username

# Provider config: only the public key is mandatory; the rest have defaults.
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
    {
        vol.Required(CONF_PUBLIC_KEY): str,
        vol.Optional(CONF_ALGORITHM, default="ES256"): str,
        vol.Optional(CONF_REALM, default="Single Sign-On"): str,
        vol.Optional(CONF_COOKIE_NAME, default="AccessToken"): str,
        vol.Optional(CONF_USERNAME_KEY, default="username"): str,
    }
)
class InvalidAuthError(HomeAssistantError):
    """Signal that an access token could not be used to authenticate.

    Raised when no token is present, the token fails JWT verification,
    or the configured username claim is missing from it.
    """
@AUTH_PROVIDERS.register("access_token")
class AccessTokenAuthProvider(AuthProvider):
    """Log in users from an access token stored in a JWT cookie.

    The token is read from the cookie named by ``cookie_name``, verified with
    ``public_key`` using ``algorithm``, and the ``username_key`` claim is
    mapped to a credential for this provider.
    """

    DEFAULT_TITLE = "Access Token"

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the provider and use the configured realm as its title."""
        super().__init__(*args, **kwargs)
        self.DEFAULT_TITLE = self.config[CONF_REALM]

    async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
        """Return a login flow seeded with the access-token cookie, if any.

        ``context["cookies"]`` is expected to hold the request cookies; it is
        only present when the core login-flow view passes them through. A
        missing context or cookie mapping is treated as "no token" instead of
        crashing flow creation.
        """
        cookies = (context or {}).get("cookies") or {}
        access_token = cookies.get(self.config[CONF_COOKIE_NAME])
        return AccessTokenLoginFlow(self, access_token)

    async def async_validate_access(
        self, access_token: Optional[str]
    ) -> Dict[str, Any]:
        """Verify an access token and return its decoded claims.

        Raises:
            InvalidAuthError: if no token was provided, ``jwt.decode`` rejects
                it (bad signature, disallowed algorithm, malformed token or
                any other InvalidTokenError), or the configured username
                claim is absent.
        """
        if access_token is None:
            _LOGGER.info("Tried to authenticate when no access token was provided.")
            raise InvalidAuthError("No access token present")

        try:
            claim = jwt.decode(
                access_token,
                self.config[CONF_PUBLIC_KEY],
                # Pin the accepted algorithm so the token header cannot choose it.
                algorithms=[self.config[CONF_ALGORITHM]],
            )
        except jwt.exceptions.InvalidTokenError as err:
            _LOGGER.info("Rejected access token: %s", err)
            raise InvalidAuthError("Invalid access token") from err

        if self.config[CONF_USERNAME_KEY] not in claim:
            raise InvalidAuthError("Username key missing in token")
        return claim

    async def async_get_or_create_credentials(
        self, flow_result: Dict[str, Any]
    ) -> Credentials:
        """Get or create credentials for the username claim in the flow result.

        ``flow_result`` is presumably the claims dict that the login flow
        passed to ``async_finish`` — verify against the HA version in use.
        """
        username = flow_result[self.config[CONF_USERNAME_KEY]]
        for credential in await self.async_credentials():
            if credential.data["username"] == username:
                return credential
        # First login for this username: create new credentials.
        return self.async_create_credentials({"username": username})

    async def async_user_meta_for_credentials(
        self, credentials: Credentials
    ) -> UserMeta:
        """Return extra user metadata for credentials.

        Used to populate info when creating a new user. The username claim
        may be a non-string JSON value (e.g. a numeric id), so it is
        converted to ``str`` for the display name.
        """
        username = credentials.data["username"]
        return UserMeta(name=str(username), is_active=True)
class AccessTokenLoginFlow(LoginFlow):
    """Handler for the access-token login flow."""

    def __init__(
        self,
        auth_provider: AccessTokenAuthProvider,
        access_token: Optional[str],
    ) -> None:
        """Initialize the login flow.

        :param auth_provider: provider used to validate the token.
        :param access_token: raw cookie value, or ``None`` if none was sent.
        """
        super().__init__(auth_provider)
        self._access_token = access_token

    async def async_step_init(
        self, user_input: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """Show an empty form; on submit, validate the stored access token.

        On success the decoded claims are passed to ``async_finish``; on
        failure the form is shown again with ``invalid_auth`` as base error.
        """
        errors: Dict[str, str] = {}
        if user_input is not None:
            provider = cast(AccessTokenAuthProvider, self._auth_provider)
            try:
                result = await provider.async_validate_access(self._access_token)
            except InvalidAuthError:
                errors["base"] = "invalid_auth"
            else:
                return await self.async_finish(result)

        # The form has no fields: submitting it just confirms the login.
        return self.async_show_form(
            step_id="init", data_schema=vol.Schema(OrderedDict()), errors=errors
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
homeassistant:
  auth_providers:
    - type: access_token
      realm: Your Own Single Sign-On
      cookie_name: SSOsessionJWT
      username_key: userid
      # PEM public key of the token issuer. Everything inside the `|` block is
      # taken literally, so a `!secret` tag cannot go in there. To keep the key
      # out of this file, write `public_key: !secret <name>` on one line instead.
      public_key: |
        -----BEGIN PUBLIC KEY-----
        MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQAZ+NnmWUUBt21XUOztH2ey7xIAsNc
        LjAMmLw5yKjNrPOV/zm3poGFDd/xE8IldmWbkM5BSxUFOGp2I9/K1gFyQLEBfvGE
        Snti6CGKPdUIhFfkTtja9dtG2lnVJ5evgk88mWo4ESlS8zgymJTOy+kFgDzwkHPf
        DMo5baGSomE984VhzqM=
        -----END PUBLIC KEY-----
      algorithm: ES512
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ES512 (ECDSA over secp521r1)
# Private key: the token-signing key, so create it readable by the owner only.
(umask 077 && openssl ecparam -genkey -name secp521r1 -noout -out private.pem)
# Public key: the value for the provider's `public_key` option.
openssl ec -in private.pem -pubout -out public.pem
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py | |
index b907598..2ed0709 100644 | |
--- a/homeassistant/components/auth/login_flow.py | |
+++ b/homeassistant/components/auth/login_flow.py | |
@@ -189,6 +189,7 @@ class LoginFlowIndexView(LoginFlowBaseView): | |
handler, # type: ignore[arg-type] | |
context={ | |
"ip_address": ip_address(request.remote), # type: ignore[arg-type] | |
+ "cookies": request.cookies, | |
"credential_only": data.get("type") == "link_user", | |
"redirect_uri": redirect_uri, | |
}, |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment