Skip to content

Instantly share code, notes, and snippets.

@dmontagu
Created September 27, 2019 20:18
Show Gist options
  • Save dmontagu/56923dafdb3fe1798bb1c2ff8b0ab8d8 to your computer and use it in GitHub Desktop.
Save dmontagu/56923dafdb3fe1798bb1c2ff8b0ab8d8 to your computer and use it in GitHub Desktop.
# Type parameters of BaseAuthRouterBuilder. Each is bounded by the base model/ORM
# class that a concrete builder's corresponding ``*_type`` attribute may specialize.
# Model built in create_user (with the hashed password) and unpacked into the ORM type.
UserCreateT = TypeVar("UserCreateT", bound=UserCreate)
# Request body used to create a user (its ``password`` is hashed in create_user).
UserCreateRequestT = TypeVar("UserCreateRequestT", bound=UserCreateRequest)
# Model returned by the read/authenticate/create/update helpers.
UserInDBT = TypeVar("UserInDBT", bound=UserInDB)
# Request body accepted by the user-update routes.
UserUpdateT = TypeVar("UserUpdateT", bound=UserUpdate)
# Response model of the routes that return a user.
UserApiT = TypeVar("UserApiT", bound=UserBaseInDB)
# ORM user model queried through ``db.query``.
UserOrmT = TypeVar("UserOrmT", bound=BaseUser)
logger = logging.getLogger(__name__)
class EndpointNames(StrEnum):
    """
    Base enum listing route names a router is expected to define.

    Subclasses declare one member per expected route; ``validate_router`` checks a
    built router against that list.
    """

    @classmethod
    def validate_router(cls, router: APIRouter) -> None:
        """
        Check that every member of this enum names a route registered on ``router``.

        Only entries of ``router.routes`` that are ``Route`` instances are considered.

        Raises:
            AssertionError: if one or more members have no matching route; all missing
                member names are listed in the message. The error is raised explicitly
                (not via ``assert``) so the check still runs under ``python -O``.
        """
        route_names = {route.name for route in router.routes if isinstance(route, Route)}
        missing = [endpoint_name.name for endpoint_name in cls if endpoint_name not in route_names]
        if missing:
            raise AssertionError(f"Missing route: {', '.join(missing)}")
class AuthEndpointName(EndpointNames):
    """
    Names of the routes that get_router registers on the non-admin auth router;
    checked with ``validate_router`` when that router is built.
    """

    login = auto()
    refresh = auto()
    validate_token = auto()
    logout = auto()
    logout_all = auto()
    register = auto()
    read_self = auto()
    update_self = auto()
class AdminAuthEndpointName(EndpointNames):
    """
    Names of the routes that get_router registers on the admin router (user
    management); checked with ``validate_router`` when that router is built.
    """

    read_user = auto()
    read_users = auto()
    create_user = auto()
    update_user = auto()
class BaseAuthRouterBuilder(Generic[UserCreateT, UserCreateRequestT, UserInDBT, UserUpdateT, UserApiT, UserOrmT]):
    """
    Builds auth routes and user CRUD helpers for a concrete set of user model types.

    Subclasses set the ``*_type`` class attributes to concrete classes; the type
    parameters describe those classes for static type checking. The classmethods
    provide a FastAPI dependency (``get_user``), database helpers (``read``,
    ``authenticate``, ``create_user``, ``update_user``), the token flows, and
    ``get_router``, which assembles the routes.
    """

    # Model built (with the hashed password) in create_user and unpacked into orm_type
    create_type: Type[UserCreateT]
    # Request body for creating a user; its ``password`` is hashed in create_user
    create_request_type: Type[UserCreateRequestT]
    # Model returned by the read/authenticate/create/update helpers
    in_db_type: Type[UserInDBT]
    # Request body of the user-update (PATCH) routes
    update_type: Type[UserUpdateT]
    # ``response_model`` of the routes that return a user
    api_type: Type[UserApiT]
    # ORM user model queried via ``db.query`` (uses user_id, username, hashed_password)
    orm_type: Type[UserOrmT]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Chain to the base implementation: typing.Generic relies on its own
        # __init_subclass__ to compute ``__parameters__`` for each subclass, so skipping
        # it would leave subclasses reporting this class's type parameters.
        super().__init_subclass__(**kwargs)
        # TODO: Validate relationships as appropriate, at least the columns and in_db type

    # ########################
    # ##### Dependencies #####
    # ########################

    @classmethod
    def get_user(cls, db: Session = Depends(get_db), jwt_user: JWTUser = Depends(get_jwt_user)) -> UserInDBT:
        """
        Dependency that loads a more complete user model, via a database lookup, for the
        user identified by the request's JWT (see ``load_jwt_user``).
        """
        return cls.load_jwt_user(db, jwt_user)

    # ###################################
    # ##### Crud and helper methods #####
    # ###################################

    @classmethod
    def load_jwt_user(cls, db: Session, jwt_user: JWTUser) -> UserInDBT:
        """
        Return the stored user whose id is ``jwt_user.user_id``; calls raise_auth_error
        ("User not found") if there is none.
        """
        user = cls.read(db=db, user_id=jwt_user.user_id)
        if user is None:
            raise_auth_error(detail="User not found")
        return user

    @classmethod
    def read(cls, db: Session, user_id: UserID) -> Optional[UserInDBT]:
        """
        Return the user with ``user_id`` converted to ``in_db_type``, or None if no row matches.
        """
        db_user = db.query(cls.orm_type).filter(cls.orm_type.user_id == user_id).first()
        user = cls.in_db_type(**db_user.dict()) if db_user is not None else None
        return user

    @classmethod
    def authenticate(cls, db: Session, username: str, password: RawPassword) -> UserInDBT:
        """
        Return the user named ``username`` if ``password`` matches its stored hash.

        Calls raise_auth_error if no user has that username or the password check fails.
        """
        # NOTE(review): the two failure cases use different details ("User not found" vs
        # "Incorrect password"), which lets a client probe which usernames exist.
        # Consider a single generic message if that matters for this deployment.
        user: Optional[UserOrmT] = db.query(cls.orm_type).filter(cls.orm_type.username == username).first()
        if user is None:
            raise_auth_error(detail="User not found")
        password_checker = get_auth_settings().password_checker
        result = password_checker.check_sync(password, HashedPassword(user.hashed_password))
        if not result.success:
            raise_auth_error(detail="Incorrect password")
        # TODO: handle result.requires_update (using background task, presumably?)
        return cls.in_db_type(**user.dict())

    @classmethod
    def create_user(cls, db: Session, user_create_request: UserCreateRequestT, is_superuser: bool) -> UserInDBT:
        """
        Hash the request's password, insert a new user row via ``add_base``, and return
        it as ``in_db_type``.

        ``is_superuser`` is always taken from the argument, never from the request body.
        Database errors (e.g. an IntegrityError for a duplicate username) propagate to
        the caller.
        """
        password_checker = get_auth_settings().password_checker
        hashed_password = password_checker.make_sync(user_create_request.password)
        # The explicit keyword arguments below are authoritative. Drop same-named request
        # fields first: splatting them alongside the explicit kwargs would raise a
        # duplicate-keyword TypeError (e.g. if create_request_type declares is_superuser,
        # which the register route populates).
        request_fields = user_create_request.dict(exclude={"hashed_password", "is_superuser"})
        user_create = cls.create_type(hashed_password=hashed_password, is_superuser=is_superuser, **request_fields)
        user = cls.orm_type(**user_create.dict())
        user = add_base(db, user)
        return cls.in_db_type(**user.dict())

    @classmethod
    def update_user(cls, db: Session, user_id: UserID, update_request: UserUpdateT) -> UserInDBT:
        """
        Apply the fields explicitly set on ``update_request`` to the user ``user_id``.

        A provided (truthy) ``password`` is hashed and stored as ``hashed_password``; the
        raw password itself is never assigned to the ORM object.

        Raises:
            APIResponseError: "Nothing to update" if no field was set, or a 404
                "User not found" if no user has ``user_id``. Integrity errors on save are
                wrapped by ``expected_integrity_error``.
        """
        user_update = cls.update_type(**update_request.dict(skip_defaults=True))
        update_dict: Dict[str, Any] = user_update.dict(skip_defaults=True)
        # Only the hash is persisted; don't setattr the plaintext onto the ORM instance.
        update_dict.pop("password", None)
        if update_request.password:
            password_checker = get_auth_settings().password_checker
            hashed_password = password_checker.make_sync(update_request.password)
            update_dict["hashed_password"] = hashed_password
        if not update_dict:
            raise APIResponseError("Nothing to update")
        db_user: Optional[UserOrmT] = db.query(cls.orm_type).filter(cls.orm_type.user_id == user_id).first()
        if not db_user:
            raise APIResponseError("User not found", error_code=HTTP_404_NOT_FOUND)
        for attribute, value in update_dict.items():
            setattr(db_user, attribute, value)
        with expected_integrity_error(db, detail="There was a conflict with an existing user"):
            add_base(db, db_user)
        return cls.in_db_type.from_orm(db_user)

    # ##################
    # ##### Routes #####
    # ##################

    @classmethod
    def get_router(cls, include_admin_routes: bool) -> APIRouter:
        """
        Build the router for this builder's user types.

        The result always contains the routes named in ``AuthEndpointName`` (tagged
        "auth"). When ``include_admin_routes`` is true it also contains the
        ``AdminAuthEndpointName`` routes under the "/admin" prefix, with the
        ``require_superuser()`` dependencies and the "admin-auth" tag.
        """
        auth_router = APIRouter()
        admin_router = APIRouter()
        api_type = cls.api_type
        token_url = get_auth_settings().token_url
        refresh_url = get_auth_settings().refresh_url

        # At runtime, rebind the TypeVar names to this builder's concrete types. The
        # route functions below are defined inside this method, so their annotations
        # (evaluated when each ``def`` executes, assuming no
        # ``from __future__ import annotations``) resolve to these locals, and FastAPI
        # sees concrete request/response models instead of TypeVars. Type checkers skip
        # this block and keep treating the names as the generic TypeVars.
        if not TYPE_CHECKING:  # pragma: no cover
            UserCreateRequestT = cls.create_request_type
            UserApiT = cls.api_type
            UserInDBT = cls.in_db_type
            UserUpdateT = cls.update_type

        # NOTE: route docstrings are left verbatim; they may be surfaced as endpoint
        # descriptions in the generated API docs.

        @auth_router.post(
            token_url, response_model=AuthTokens, response_model_skip_defaults=True, response_class=UncachedJSONResponse
        )
        def login(db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()) -> AuthTokens:
            """
            OAuth2 compatible token login, get an access token for future requests
            """
            user: UserInDBT = cls.authenticate(
                db=db, username=form_data.username, password=RawPassword(form_data.password)
            )
            tokens = cls.login_flow(db=db, user=user, scopes=form_data.scopes)
            response = tokens.to_response()
            return response

        @auth_router.get(
            refresh_url,
            response_model=AuthTokens,
            response_model_skip_defaults=True,
            response_class=UncachedJSONResponse,
        )
        def refresh(db: Session = Depends(get_db), token: Token = Depends(get_headers_token_openapi)) -> AuthTokens:
            """
            Consume a refresh token to request a new access token
            """
            tokens = cls.refresh_token_flow(db=db, token=token)
            response = tokens.to_response()
            return response

        # Succeeds only if the cls.get_user dependency resolves (token valid, user exists).
        @auth_router.get(token_url + "/validate", response_model=UtilMessage, dependencies=[Depends(cls.get_user)])
        def validate_token() -> UtilMessage:
            return UtilMessage(detail="Token is valid for user")

        @auth_router.get(token_url + "/logout", response_model=UtilMessage, response_class=UncachedJSONResponse)
        def logout(db: Session = Depends(get_db), token: Token = Depends(get_headers_token_openapi)) -> UtilMessage:
            """
            Invalidate the provided refresh token
            """
            logout_flow(db=db, token=token)
            return UtilMessage(detail="Logged out successfully")

        @auth_router.get(token_url + "/logout/all", response_model=UtilMessage, response_class=UncachedJSONResponse)
        def logout_all(db: Session = Depends(get_db), token: Token = Depends(get_headers_token_openapi)) -> UtilMessage:
            """
            Invalidate all outstanding refresh tokens for the user
            """
            logout_all_flow(db=db, token=token)
            return UtilMessage(detail="Logged out all devices successfully")

        @auth_router.post("/register", response_model=api_type)
        def register(*, db: Session = Depends(get_db), request: AuthRegistrationRequest) -> UserInDBT:
            """
            Create new user without the need to be logged in.
            """
            if not get_auth_settings().users_open_registration:
                raise_permissions_error(detail="User registration is not yet open")
            user_create = cls.create_request_type(
                username=request.username, password=request.password, is_superuser=False
            )
            with expected_integrity_error(db, detail="This username is already in use"):
                user: UserInDBT = cls.create_user(db=db, user_create_request=user_create, is_superuser=False)
            return user

        # Returns the caller's own record as loaded by the cls.get_user dependency.
        @auth_router.get("/self", response_model=api_type)
        def read_self(user: UserInDBT = Depends(cls.get_user)) -> UserInDBT:
            return user

        # NOTE(review): this uses the same update_type as the admin PATCH route; confirm
        # it exposes no privileged fields (e.g. is_superuser) a user could set on themselves.
        @auth_router.patch("/self", response_model=api_type)
        def update_self(
            *, db: Session = Depends(get_db), jwt_user: JWTUser = Depends(get_jwt_user), update_request: UserUpdateT
        ) -> UserInDBT:
            """
            Update a user.
            """
            user: UserInDBT = cls.update_user(db=db, user_id=jwt_user.user_id, update_request=update_request)
            return user

        @admin_router.get("/users/{user_id}", response_model=api_type)
        def read_user(*, db: Session = Depends(get_db), user_id: UserID) -> UserApiT:
            """
            Get a specific user by id.
            """
            user = cls.read(db=db, user_id=user_id)
            if user is None:
                raise APIResponseError("User not found", error_code=HTTP_404_NOT_FOUND)
            return user

        @admin_router.get("/users", response_model=List[api_type])  # type: ignore
        def read_users(db: Session = Depends(get_db), skip: int = 0, limit: int = 100) -> List[UserOrmT]:
            """
            Retrieve users.
            """
            result = db.query(cls.orm_type).offset(skip).limit(limit).all()
            return result

        @admin_router.post("/users", response_model=api_type)
        def create_user(
            *, db: Session = Depends(get_db), user_create_request: UserCreateRequestT, is_superuser: bool = False
        ) -> UserInDBT:
            """
            Create new user.
            """
            with expected_integrity_error(db, detail="This username is already in use"):
                user: UserInDBT = cls.create_user(
                    db=db, user_create_request=user_create_request, is_superuser=is_superuser
                )
            return user

        @admin_router.patch("/users/{user_id}", response_model=api_type)
        def update_user(*, db: Session = Depends(get_db), user_id: UserID, update_request: UserUpdateT) -> UserInDBT:
            """
            Update a user.
            """
            user: UserInDBT = cls.update_user(db=db, user_id=user_id, update_request=update_request)
            return user

        # Fail at build time if any route listed in the endpoint-name enums is missing.
        AuthEndpointName.validate_router(auth_router)
        AdminAuthEndpointName.validate_router(admin_router)

        router = APIRouter()
        router.include_router(auth_router, prefix="", tags=["auth"])
        if include_admin_routes:
            router.include_router(admin_router, prefix="/admin", dependencies=require_superuser(), tags=["admin-auth"])
        return router

    @classmethod
    def login_flow(cls, db: Session, user: UserInDBT, scopes: List[str]) -> "TokenPair":
        """
        Issue (and store the refresh half of) a new token pair for ``user`` with ``scopes``.
        """
        token_pair = _create_tokens(db=db, user=user, scopes=scopes)
        return token_pair

    @classmethod
    def refresh_token_flow(cls, db: Session, token: "Token") -> "TokenPair":
        """
        Implements the refresh token flow by performing the following steps:
        * Delete the provided refresh token
        * Generate and return a new refresh (and access) token
        """
        user: Optional[UserInDBT] = cls.read(db=db, user_id=token.payload.user_id)
        if user is None:
            raise_auth_error(detail="User not found; try logging in again")
        _consume_refresh_token(db=db, token=token)
        token_pair = _create_tokens(db=db, user=user, scopes=token.payload.scopes)
        return token_pair

    @classmethod
    def setup_first_superuser(cls, engine: sa.engine.Engine) -> None:
        """
        Create the superuser described by the ``first_superuser`` and
        ``first_superuser_password`` auth settings, using a session bound to ``engine``.

        An IntegrityError during creation is logged and treated as "already exists".
        The session is always closed.

        Raises:
            AssertionError: if either setting is empty/unset (raised explicitly rather than
                via ``assert`` so the check also runs under ``python -O``).
        """
        settings = get_auth_settings()
        username = settings.first_superuser
        password = RawPassword(settings.first_superuser_password) if settings.first_superuser_password else None
        if not username:
            raise AssertionError(f"Invalid superuser username: {username}")
        if not password:
            raise AssertionError(f"Invalid superuser password: {password}")
        # Use this builder's request type (as the routes do) rather than the base
        # UserCreateRequest, so subclasses with a customised request model are honoured.
        user_create = cls.create_request_type(username=username, password=password)
        session = get_sessionmaker_for_engine(engine)()
        try:
            cls.create_user(db=session, user_create_request=user_create, is_superuser=True)
            logger.info("First superuser created.")
        except IntegrityError:
            logger.info("First superuser already exists.")
        finally:
            session.close()
def remove_expired_tokens(db: Session) -> int:
    """
    Delete every stored refresh token whose ``exp`` is earlier than the current epoch
    time (``get_epoch()``), then commit.

    Returns the number of removed expired tokens (e.g., for logging).
    """
    now = get_epoch()
    # Query.delete() returns the number of rows it matched, so use that directly instead
    # of issuing a separate COUNT first: one round trip fewer, and the reported number
    # can't disagree with what was actually deleted if rows change in between.
    n_expired_tokens = db.query(RefreshToken).filter(RefreshToken.exp < now).delete()
    db.commit()
    return n_expired_tokens
def get_epoch() -> int:
    """
    Returns the number of whole seconds since the Unix epoch (UTC).
    """
    from datetime import timezone

    # Use an aware UTC datetime; datetime.utcnow() is deprecated as of Python 3.12.
    # utctimetuple() of an aware datetime is in UTC, which is what timegm expects.
    return timegm(datetime.now(timezone.utc).utctimetuple())
def logout_flow(db: Session, token: "Token") -> None:
    """
    Log out one session: consume (validate, then delete) the given refresh token.
    """
    _consume_refresh_token(db, token)
def logout_all_flow(db: Session, token: "Token") -> None:
    """
    Log the token's user out everywhere.

    The provided refresh token is consumed first (which validates it), then every
    remaining refresh token stored for the same user is bulk-deleted and committed.
    """
    _consume_refresh_token(db=db, token=token)
    owner_id = token.payload.user_id
    remaining_tokens = db.query(RefreshToken).filter(RefreshToken.user_id == owner_id)
    remaining_tokens.delete(synchronize_session=False)
    db.commit()
def _create_tokens(db: Session, user: UserInDB, scopes: List[str]) -> "TokenPair":
    """
    Generate a token pair for ``user`` with ``scopes``, record the refresh token in the
    RefreshToken table via ``add_base``, and return the pair.
    """
    pair = generate_tokens(user=user, scopes=scopes)
    refresh = pair.refresh
    stored_refresh = RefreshToken(token=refresh.encoded, user_id=refresh.payload.user_id, exp=refresh.payload.exp)
    add_base(db, stored_refresh)
    return pair
def _consume_refresh_token(db: Session, token: "Token") -> None:
    """
    Single-use check for a refresh token: validate that it is a refresh token and is
    still stored in the database, then delete it and commit.

    Calls raise_auth_error if the token's subject is not the refresh-token subject, or if
    no stored token matches (e.g. it was already used or logged out).
    """
    if token.payload.sub != refresh_token_jwt_subject:
        raise_auth_error(detail="Provided token was not a refresh token")
    # Validate that the token is in the database, even though it passed decoding.
    # This is critical to ensure previously used / logged out tokens don't work.
    # Check and delete in ONE statement and inspect the matched row count: with a
    # separate SELECT followed by an ORM delete, two concurrent requests can both see the
    # row and both "consume" the same token (the ORM only warns on a 0-row delete).
    n_deleted = db.query(RefreshToken).filter(RefreshToken.token == token.encoded).delete(synchronize_session=False)
    if n_deleted == 0:
        raise_auth_error(detail="Provided refresh token was invalid")
    db.commit()
# Type of any BaseAuthRouterBuilder subclass, whatever its type parameters;
# accepted by get_auth_app.
RouterBuilderType = Type[BaseAuthRouterBuilder[Any, Any, Any, Any, Any, Any]]
def get_auth_app(
    router_builder: RouterBuilderType,
    include_admin_routes: bool,
    openapi_url: Optional[str] = None,
    docs_url: Optional[str] = None,
    redoc_url: Optional[str] = None,
    swagger_ui_oauth2_redirect_url: Optional[str] = None,
    **fastapi_kwargs: Any,
) -> FastAPI:
    """
    Build a standalone FastAPI application serving the router from ``router_builder``.

    The OpenAPI/docs URL arguments default to None and are passed straight to FastAPI.
    ``debug`` falls back to ``get_api_settings().debug`` unless given in
    ``fastapi_kwargs``; all other keyword arguments are forwarded unchanged.
    """
    router = router_builder.get_router(include_admin_routes)
    # The settings value is read unconditionally, then only used if no debug was passed.
    settings_debug = get_api_settings().debug
    fastapi_kwargs["debug"] = fastapi_kwargs.get("debug", settings_debug)
    app = FastAPI(
        openapi_url=openapi_url,
        docs_url=docs_url,
        redoc_url=redoc_url,
        swagger_ui_oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
        **fastapi_kwargs,
    )
    app.include_router(router)
    return app
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment