Created
March 16, 2026 22:29
-
-
Save juztas/e6ec4e75890b341fba75bc2be6aed991 to your computer and use it in GitHub Desktop.
facility_auth.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GlobusValidator:
    """Globus access token validator using OAuth2 token introspection.

    Configuration is read from ``getUserConfig()``:

    * ``globus_id`` / ``globus_secret`` -- credentials for the
      ``globus_sdk.ConfidentialAppAuthClient`` that introspects tokens. When
      either is missing, no client is built and every ``validate`` call is
      rejected with 403.
    * ``globus_map`` -- mapping of Globus identity ID (``sub``) to the local
      username that identity is accepted as.
    * ``globus_iss`` -- expected token issuer (default
      ``"https://auth.globus.org"``).
    * ``allowed_usernames`` -- read and stored, but not consulted by this class.
    """

    def __init__(self):
        cfg = getUserConfig()
        self.allowed_usernames = cfg.get("allowed_usernames", {})
        self.globus_id = cfg.get("globus_id")
        self.globus_secret = cfg.get("globus_secret")
        self.globus_map = cfg.get("globus_map", {})
        self.globus_iss = cfg.get("globus_iss", "https://auth.globus.org")
        self.globus_client = None
        if not self.globus_id or not self.globus_secret:
            print("Globus credentials not configured, Globus token validation will fail")
        else:
            self.globus_client = globus_sdk.ConfidentialAppAuthClient(
                self.globus_id, self.globus_secret
            )

    def validate(self, token: str) -> dict:
        """Validate token and return claims.

        The token is introspected with Globus Auth. It is accepted when it is
        active, within its ``exp``/``nbf`` window, issued by ``globus_iss``, and
        one of the identities in its session appears in ``globus_map``.

        Args:
            token: Bearer access token presented by the client.

        Returns:
            ``{"sub", "preferred_username", "name"}`` for the first whitelisted
            session identity. ``name`` comes from ``identity_set_detail`` when
            available, otherwise the mapped username is used.

        Raises:
            HTTPException: 403 when the validator is not configured, the token
                is rejected by one of the checks above (with a specific
                ``detail``), or introspection fails unexpectedly.
        """
        if not self.globus_id or not self.globus_secret or not self.globus_client:
            raise HTTPException(status_code=403, detail="Invalid Globus token")
        try:
            introspect = self.globus_client.oauth2_token_introspect(
                token, include="session_info,identity_set_detail"
            )
            if not introspect.get("active"):
                print("Globus token is not active")
                raise HTTPException(403, "Invalid Globus token")
            exp = introspect.get("exp")
            if exp and utc_timestamp() > exp:
                print("Globus token has expired")
                raise HTTPException(403, "Token expired")
            nbf = introspect.get("nbf")
            if nbf and utc_timestamp() < nbf:
                print("Globus token not valid yet")
                raise HTTPException(403, "Token not valid yet")
            issuer = introspect.get("iss")
            if issuer != self.globus_iss:
                print(f"Globus token has invalid issuer: {issuer}")
                raise HTTPException(403, "Invalid token issuer")

            # Session authentications are keyed by identity ID, the same keys
            # used in globus_map. ``or {}`` also tolerates explicit nulls.
            session_info = introspect.get("session_info") or {}
            authentications = session_info.get("authentications") or {}
            # Find any identity in the session that is whitelisted.
            accepted_sub = next(
                (sub for sub in authentications if sub in self.globus_map), None
            )
            if not accepted_sub:
                raise HTTPException(
                    status_code=403,
                    detail="User not authorized (not in whitelist)",
                )
            username = self.globus_map[accepted_sub]
            print(f"Accepted Globus identity: {accepted_sub} mapped to {username}")

            # Prefer the human-readable name from the identity metadata, and
            # fall back to the mapped username when it is missing.
            name = username
            for identity in introspect.get("identity_set_detail") or []:
                if identity.get("sub") == accepted_sub:
                    name = identity.get("name", username)
                    break
            return {
                "sub": accepted_sub,
                "preferred_username": username,
                "name": name,
            }
        except HTTPException:
            # Deliberate rejections above already carry a specific 403 detail.
            # Do not let the generic handler below overwrite them.
            raise
        except Exception as ex:  # pylint: disable=broad-except
            # Unexpected failures (network errors, malformed responses) are
            # logged and reported to the caller as a generic invalid token.
            print(f"Exception during Globus token validation: {ex}")
            traceback.print_exc()
        raise HTTPException(status_code=403, detail="Invalid Globus token")
class KeycloakValidator:
    """Keycloak JWT token validator.

    Tokens are verified locally with ``jwt.decode`` against the issuer's JWKS.
    The JWKS is cached and re-fetched after ``refresh_jwks_interval``.

    Configuration keys read from ``getUserConfig()``:

    * ``issuer`` (required) -- expected ``iss`` claim. The JWKS URL is derived
      from it.
    * ``client_id`` (required) -- expected audience passed to ``jwt.decode``.
    * ``allowed_usernames`` -- optional allow-list of ``preferred_username``
      values. When empty, any validly signed token is accepted.
    * ``refresh_jwks_interval`` -- JWKS cache lifetime, in the units returned
      by ``utc_timestamp()`` (presumably seconds; default 3600).

    The constructor fetches the JWKS eagerly, so a broken endpoint fails
    construction with ``RuntimeError``.
    """

    def __init__(self):
        cfg = getUserConfig()
        self.issuer = cfg["issuer"]
        self.client_id = cfg["client_id"]
        self.allowed_usernames = cfg.get("allowed_usernames", {})
        # Keycloak realm certs endpoint, derived from the issuer URL.
        self.jwks_url = self.issuer.rstrip("/") + "/protocol/openid-connect/certs"
        self._jwks = None
        self.last_jwks_load = None
        self.refresh_jwks_interval = cfg.get("refresh_jwks_interval", 3600)
        self._load_jwks()

    def _load_jwks(self):
        """Return the cached JWKS document, fetching it first if needed.

        A fetch happens when nothing is cached yet or the cache is older than
        ``refresh_jwks_interval``.

        Raises:
            RuntimeError: the fetch or JSON decoding failed. The cache is
                cleared, so the next call retries the fetch.
        """
        try:
            # Short-circuit: the age check only runs once something is cached.
            stale = (
                not self._jwks
                or utc_timestamp() - self.last_jwks_load > self.refresh_jwks_interval
            )
            if stale:
                resp = httpget(self.jwks_url)
                self._jwks = resp.json()
                self.last_jwks_load = utc_timestamp()
        except Exception as ex:
            self._jwks = None
            self.last_jwks_load = None
            raise RuntimeError(
                f"Failed to load JWKS from {self.jwks_url}: {ex}"
            ) from ex
        return self._jwks

    def validate(self, token: str) -> dict:
        """Validate token and return claims.

        Args:
            token: Encoded RS256 JWT.

        Returns:
            The decoded claims dict, unchanged.

        Raises:
            HTTPException: 401 if the token has expired. 403 if decoding or
                validation fails, ``preferred_username`` is missing, or the
                user is not in a non-empty ``allowed_usernames``.
            RuntimeError: propagated from ``_load_jwks`` when the JWKS cannot
                be fetched. NOTE(review): this is not converted to an
                HTTPException; confirm how FederatedValidator handles it.
        """
        # NOTE(review): the audience check requires ``aud`` to contain
        # client_id. Confirm that the Keycloak client is configured to emit it.
        try:
            claims = jwt.decode(
                token,
                self._load_jwks(),
                algorithms=["RS256"],
                audience=self.client_id,
                issuer=self.issuer,
            )
        except ExpiredSignatureError as ex:
            # Caught before JWTError, of which it is a subclass in python-jose.
            raise HTTPException(status_code=401, detail="Token expired") from ex
        except JWTError as exc:
            raise HTTPException(
                status_code=403, detail=f"Invalid token: {exc}"
            ) from exc
        username = claims.get("preferred_username")
        if not username:
            raise HTTPException(
                status_code=403, detail="preferred_username missing in token"
            )
        if self.allowed_usernames and username not in self.allowed_usernames:
            raise HTTPException(
                status_code=403, detail=f"User {username} is not authorized"
            )
        return claims
class FakeUserDatabase:
    """In-memory user database. Used for development and testing purposes only.

    Users come from the ``users`` mapping returned by ``getUserConfig()``.
    Each entry is keyed by user ID and must provide ``name``, ``api_key`` and
    ``client_ip``.
    """

    def __init__(self):
        cfg = getUserConfig()
        self.users = {
            user: account_models.User(
                id=user,
                name=userdict["name"],
                api_key=userdict["api_key"],
                client_ip=userdict["client_ip"],
            )
            for user, userdict in cfg.get("users", {}).items()
        }

    def validate(self, token: str) -> dict:
        """Validate token and return claims.

        The token is matched against each configured user's API key. The first
        matching user is returned as ``{"sub", "preferred_username", "name"}``.

        Raises:
            HTTPException: 403 when no configured API key matches.
        """
        import hmac  # local import: only this dev/test path needs it

        # Compare as bytes so hmac.compare_digest accepts non-ASCII keys.
        # surrogatepass keeps encoding total for any str.
        presented = (
            token.encode("utf-8", "surrogatepass") if isinstance(token, str) else None
        )
        for user in self.users.values():
            stored = user.api_key
            if presented is not None and isinstance(stored, str):
                # Constant-time comparison, so response timing does not reveal
                # how much of a configured key a guess got right.
                matched = hmac.compare_digest(
                    stored.encode("utf-8", "surrogatepass"), presented
                )
            else:
                # Non-str values keep the original equality semantics.
                matched = stored == token
            if matched:
                return {
                    "sub": user.id,
                    "preferred_username": user.id,
                    "name": user.name,
                }
        raise HTTPException(status_code=403, detail="Invalid token")
# This will need to fully support AmSC users in the future and parse JWTs.
# For now we are keeping it simple.
# pylint: disable=unused-argument
class UserDatabase:
    """Resolve API keys to users through the configured token validator.

    When the ``IRI_DEV_MODE`` environment variable equals ``"true"``
    (case-insensitive), tokens are checked against the in-memory
    ``FakeUserDatabase``. Otherwise a ``FederatedValidator`` is built over a
    ``KeycloakValidator`` and a ``GlobusValidator``.
    """

    def __init__(self):
        if os.environ.get("IRI_DEV_MODE", "false").lower() == "true":
            self.validator = FakeUserDatabase()
        else:
            self.validator = FederatedValidator([KeycloakValidator(), GlobusValidator()])

    async def get_current_user(self, api_key: str, client_ip: str | None) -> str:
        """Return current user ID based on API key.

        ``client_ip`` is accepted for interface compatibility and is not used.
        Exceptions raised by the validator propagate unchanged.
        """
        token = extract_api_key(api_key)
        claims = self.validator.validate(token)
        return claims["preferred_username"]

    async def get_user(
        self, user_id: str, api_key: str, client_ip: str | None
    ) -> account_models.User:
        """Return user object for given user ID and API key.

        Raises:
            HTTPException: 403 when the token's ``preferred_username`` differs
                from ``user_id``. Validator exceptions also propagate.
        """
        token = extract_api_key(api_key)
        claims = self.validator.validate(token)
        username = claims["preferred_username"]
        if username != user_id:
            raise HTTPException(status_code=403, detail="User mismatch")
        return account_models.User(
            id=username,
            # KeycloakValidator returns raw JWT claims, which need not include
            # ``name``. Fall back to the username, as GlobusValidator does.
            name=claims.get("name") or username,
            # Never echo the caller's credential back in the user object.
            api_key="",
            client_ip=client_ip,
        )
class AuthMixin(AuthenticatedAdapter, UserDatabase):
    """Mixin class to provide authentication methods required by IRI adapters.

    ``AuthenticatedAdapter`` comes first in the bases, so it would win attribute
    lookup for any method both classes define. The class-level aliases below
    pin the two auth entry points to the ``UserDatabase`` implementations
    regardless of that ordering. Presumably ``AuthenticatedAdapter`` declares
    them (possibly as abstract methods); confirm.

    NOTE(review): ``UserDatabase.__init__`` (which sets ``self.validator``) only
    runs if ``AuthenticatedAdapter.__init__`` cooperatively calls
    ``super().__init__()``. Confirm.
    """

    get_current_user, get_user = (
        UserDatabase.get_current_user,
        UserDatabase.get_user,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment