Skip to content

Instantly share code, notes, and snippets.

@juztas
Created March 16, 2026 22:29
Show Gist options
  • Select an option

  • Save juztas/e6ec4e75890b341fba75bc2be6aed991 to your computer and use it in GitHub Desktop.

Select an option

Save juztas/e6ec4e75890b341fba75bc2be6aed991 to your computer and use it in GitHub Desktop.
facility_auth.py
class GlobusValidator:
    """Globus access token validator.

    Tokens are checked through Globus Auth token introspection using this
    service's confidential client credentials. A token is accepted only when
    one of the identities in its session appears in the ``globus_map``
    whitelist; the mapped value becomes the local username.
    """

    def __init__(self):
        cfg = getUserConfig()
        self.allowed_usernames = cfg.get("allowed_usernames", {})
        self.globus_id = cfg.get("globus_id")
        self.globus_secret = cfg.get("globus_secret")
        # Whitelist: Globus identity ID (session "authentications" key) -> local username.
        self.globus_map = cfg.get("globus_map", {})
        self.globus_iss = cfg.get("globus_iss", "https://auth.globus.org")
        self.globus_client = None
        if not self.globus_id or not self.globus_secret:
            print("Globus credentials not configured, Globus token validation will fail")
        else:
            self.globus_client = globus_sdk.ConfidentialAppAuthClient(self.globus_id, self.globus_secret)

    def validate(self, token: str) -> dict:
        """Validate a Globus access token and return normalized claims.

        Args:
            token: Raw Globus access token.

        Returns:
            Dict with ``sub`` (the whitelisted Globus identity ID),
            ``preferred_username`` (its ``globus_map`` value) and ``name``
            (from the identity details when present, else the username).

        Raises:
            HTTPException: 403 when Globus credentials are not configured, the
                token is inactive / expired / not yet valid / from another
                issuer, no session identity is whitelisted, or introspection
                itself fails.
        """
        if not self.globus_id or not self.globus_secret or not self.globus_client:
            raise HTTPException(status_code=403, detail="Invalid Globus token")
        try:
            introspect = self.globus_client.oauth2_token_introspect(
                token, include="session_info,identity_set_detail"
            )
            if not introspect.get("active"):
                print("Globus token is not active")
                raise HTTPException(403, "Invalid Globus token")
            if introspect.get("exp") and utc_timestamp() > introspect["exp"]:
                print("Globus token has expired")
                raise HTTPException(403, "Token expired")
            if introspect.get("nbf") and utc_timestamp() < introspect["nbf"]:
                print("Globus token not valid yet")
                raise HTTPException(403, "Token not valid yet")
            if introspect.get("iss") != self.globus_iss:
                print(f"Globus token has invalid issuer: {introspect.get('iss')}")
                raise HTTPException(403, "Invalid token issuer")

            # Identities used in this session, keyed by Globus identity ID.
            # ``or {}`` also guards against explicit nulls in the response.
            session_info = introspect.get("session_info") or {}
            authentications = session_info.get("authentications") or {}
            accepted_sub = next(
                (sub for sub in authentications if sub in self.globus_map), None
            )
            if not accepted_sub:
                raise HTTPException(
                    status_code=403,
                    detail="User not authorized (not in whitelist)",
                )
            username = self.globus_map[accepted_sub]
            print(f"Accepted Globus identity: {accepted_sub} mapped to {username}")

            # Prefer the human-readable name from the identity details; fall
            # back to the mapped username when that metadata is missing.
            name = username
            for identity in introspect.get("identity_set_detail") or []:
                if identity.get("sub") == accepted_sub:
                    name = identity.get("name", username)
                    break
            return {"sub": accepted_sub, "preferred_username": username, "name": name}
        except HTTPException:
            # Deliberate rejections above already carry a specific reason;
            # do not mask them with the generic error below.
            raise
        except Exception as ex:  # pylint: disable=broad-except
            print(f"Exception during Globus token validation: {ex}")
            traceback.print_exc()
        raise HTTPException(status_code=403, detail="Invalid Globus token")
class KeycloakValidator:
    """Keycloak JWT token validator.

    Verifies RS256-signed tokens against the realm's JWKS (fetched from
    ``<issuer>/protocol/openid-connect/certs``), checking audience
    (``client_id``) and issuer, then optionally restricts access to
    ``allowed_usernames``.
    """

    def __init__(self):
        cfg = getUserConfig()
        self.issuer = cfg["issuer"]
        self.client_id = cfg["client_id"]
        # When empty, every authenticated username is accepted (see validate()).
        self.allowed_usernames = cfg.get("allowed_usernames", {})
        self.jwks_url = self.issuer.rstrip("/") + "/protocol/openid-connect/certs"
        self._jwks = None
        self.last_jwks_load = None
        # Compared against differences of utc_timestamp() values
        # (presumably seconds - TODO confirm).
        self.refresh_jwks_interval = cfg.get("refresh_jwks_interval", 3600)
        # Fail fast at construction if the key set cannot be fetched.
        self._load_jwks()

    def _load_jwks(self):
        """Return the JWKS, fetching it when none is cached or the cache is stale.

        If a refresh fails while previously loaded keys exist, those keys keep
        being served so a transient outage of the JWKS endpoint does not reject
        every request; the next call retries the refresh.

        Raises:
            RuntimeError: if the key set cannot be fetched and no cached copy
                is available.
        """
        if self._jwks and utc_timestamp() - self.last_jwks_load <= self.refresh_jwks_interval:
            return self._jwks
        try:
            resp = httpget(self.jwks_url)
            jwks = resp.json()
        except Exception as ex:  # pylint: disable=broad-except
            if self._jwks:
                print(f"Failed to refresh JWKS from {self.jwks_url}, keeping cached keys: {ex}")
                return self._jwks
            self._jwks = None
            self.last_jwks_load = None
            raise RuntimeError(
                f"Failed to load JWKS from {self.jwks_url}: {ex}"
            ) from ex
        self._jwks = jwks
        self.last_jwks_load = utc_timestamp()
        return self._jwks

    def validate(self, token: str) -> dict:
        """Validate a Keycloak JWT and return its decoded claims.

        Raises:
            HTTPException: 401 if the token has expired; 403 if it is otherwise
                invalid, lacks ``preferred_username``, or the username is not in
                a non-empty ``allowed_usernames``.
            RuntimeError: propagated from ``_load_jwks`` when no key set is
                available.
        """
        try:
            claims = jwt.decode(
                token,
                self._load_jwks(),
                algorithms=["RS256"],
                audience=self.client_id,
                issuer=self.issuer,
            )
        # ExpiredSignatureError must be caught before its JWTError base class.
        except ExpiredSignatureError as ex:
            raise HTTPException(status_code=401, detail="Token expired") from ex
        except JWTError as exc:
            raise HTTPException(
                status_code=403, detail=f"Invalid token: {exc}"
            ) from exc
        username = claims.get("preferred_username")
        if not username:
            raise HTTPException(
                status_code=403, detail="preferred_username missing in token"
            )
        if self.allowed_usernames and username not in self.allowed_usernames:
            raise HTTPException(
                status_code=403, detail=f"User {username} is not authorized"
            )
        return claims
class FakeUserDatabase:
    """Static, config-driven user store for development and tests only.

    Users come from the ``users`` section of the user config; each one
    authenticates with its configured API key, with no identity provider
    involved.
    """

    def __init__(self):
        cfg = getUserConfig()
        self.users = {
            uid: account_models.User(
                id=uid,
                name=entry["name"],
                api_key=entry["api_key"],
                client_ip=entry["client_ip"],
            )
            for uid, entry in cfg.get("users", {}).items()
        }

    def validate(self, token: str) -> dict:
        """Return claims for the user whose API key equals *token*.

        Raises:
            HTTPException: 403 when no configured user has that key.
        """
        matched = next(
            (candidate for candidate in self.users.values() if candidate.api_key == token),
            None,
        )
        if matched is None:
            raise HTTPException(status_code=403, detail="Invalid token")
        return {
            "sub": matched.id,
            "preferred_username": matched.id,
            "name": matched.name,
        }
# This will need to fully support AmSC users in the future and parse JWTs.
# For now we are keeping it simple.
# pylint: disable=unused-argument
class UserDatabase:
    """User database backed by token validation.

    In dev mode (``IRI_DEV_MODE=true``) tokens are checked against the static
    ``FakeUserDatabase``; otherwise a ``FederatedValidator`` over Keycloak and
    Globus validators is used.
    """

    def __init__(self):
        if os.environ.get("IRI_DEV_MODE", "false").lower() == "true":
            self.validator = FakeUserDatabase()
        else:
            self.validator = FederatedValidator([KeycloakValidator(), GlobusValidator()])

    # NOTE(review): validator.validate() may do blocking network I/O
    # (JWKS fetch, Globus introspection) inside these async methods; consider
    # offloading it if event-loop latency matters.
    async def get_current_user(self, api_key: str, client_ip: str | None) -> str:
        """Return the username (``preferred_username``) authenticated by *api_key*."""
        token = extract_api_key(api_key)
        claims = self.validator.validate(token)
        return claims["preferred_username"]

    async def get_user(
        self, user_id: str, api_key: str, client_ip: str | None
    ) -> account_models.User:
        """Return the user for *user_id*, verifying that *api_key* belongs to it.

        Raises:
            HTTPException: 403 if the token's ``preferred_username`` differs
                from *user_id* (plus any error raised by the validator).
        """
        token = extract_api_key(api_key)
        claims = self.validator.validate(token)
        username = claims["preferred_username"]
        if username != user_id:
            raise HTTPException(status_code=403, detail="User mismatch")
        return account_models.User(
            id=username,
            # Raw Keycloak claims need not include "name"; fall back to the
            # username instead of failing with a KeyError.
            name=claims.get("name") or username,
            api_key="",
            client_ip=client_ip,
        )
class AuthMixin(AuthenticatedAdapter, UserDatabase):
    """Mixin supplying the authentication methods required by IRI adapters.

    ``AuthenticatedAdapter`` comes first in the bases, so same-named attributes
    it may declare would shadow ``UserDatabase``'s. The implementations are
    therefore rebound here explicitly.
    """

    get_current_user, get_user = UserDatabase.get_current_user, UserDatabase.get_user
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment