Created
September 27, 2019 20:18
-
-
Save dmontagu/56923dafdb3fe1798bb1c2ff8b0ab8d8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Type variables used as the generic parameters of BaseAuthRouterBuilder.
# Each one is bounded by the base model named in its `bound=` argument.
UserCreateT = TypeVar("UserCreateT", bound=UserCreate)
UserCreateRequestT = TypeVar("UserCreateRequestT", bound=UserCreateRequest)
UserInDBT = TypeVar("UserInDBT", bound=UserInDB)
UserUpdateT = TypeVar("UserUpdateT", bound=UserUpdate)
UserApiT = TypeVar("UserApiT", bound=UserBaseInDB)
UserOrmT = TypeVar("UserOrmT", bound=BaseUser)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class EndpointNames(StrEnum):
    """
    Base enum whose members name the routes a router is expected to expose
    """

    @classmethod
    def validate_router(cls, router: APIRouter) -> None:
        """
        Asserts that every member of this enum matches the name of a route registered on `router`

        Only entries of `router.routes` that are `Route` instances are taken into account.
        """
        registered_names = set()
        for candidate in router.routes:
            if isinstance(candidate, Route):
                registered_names.add(candidate.name)

        for expected in cls:
            assert expected in registered_names, f"Missing route: {expected.name}"
class AuthEndpointName(EndpointNames):
    """
    Names of the routes that must be present on the auth router
    """

    # Each member is spelled like the handler function declared in
    # BaseAuthRouterBuilder.get_router. The concrete value produced by auto()
    # depends on StrEnum -- presumably the member name; confirm against StrEnum.

    # Token lifecycle
    login = auto()
    refresh = auto()
    validate_token = auto()
    logout = auto()
    logout_all = auto()

    # Self-service account management
    register = auto()
    read_self = auto()
    update_self = auto()
class AdminAuthEndpointName(EndpointNames):
    """
    Names of the routes that must be present on the admin router
    """

    # Each member is spelled like the admin handler function declared in
    # BaseAuthRouterBuilder.get_router.
    read_user = auto()
    read_users = auto()
    create_user = auto()
    update_user = auto()
class BaseAuthRouterBuilder(Generic[UserCreateT, UserCreateRequestT, UserInDBT, UserUpdateT, UserApiT, UserOrmT]):
    """
    Builds the auth (and optionally admin) API router for a configurable family of user models

    The class attributes below are only annotated here; the methods of this class read them
    from `cls`, so a usable builder has to provide concrete classes for each of them.
    """

    # Built in `create_user` from the create request plus the hashed password and superuser flag
    create_type: Type[UserCreateT]
    # Request model for user creation; instantiated by the `register` route
    create_request_type: Type[UserCreateRequestT]
    # Model returned by the crud helpers (`read`, `authenticate`, `create_user`, `update_user`)
    in_db_type: Type[UserInDBT]
    # Model used by `update_user` to determine which fields to apply
    update_type: Type[UserUpdateT]
    # Used as the `response_model` for the user-returning routes
    api_type: Type[UserApiT]
    # ORM class that is queried and persisted
    orm_type: Type[UserOrmT]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Hook run for each subclass; currently only delegates to the parent hooks

        Accepting and forwarding `**kwargs` keeps class-keyword arguments working, and calling
        `super()` ensures hooks further up the MRO (including the one `typing.Generic` uses for
        its per-subclass bookkeeping) are not silently skipped.
        """
        super().__init_subclass__(**kwargs)
        # TODO: Validate relationships as appropriate, at least the columns and in_db type
# ######################## | |
# ##### Dependencies ##### | |
# ######################## | |
@classmethod
def get_user(cls, db: Session = Depends(get_db), jwt_user: JWTUser = Depends(get_jwt_user)) -> UserInDBT:
    """
    Dependency that turns the JWT-identified caller into a full user model via a database lookup
    """
    full_user = cls.load_jwt_user(db, jwt_user)
    return full_user
# ################################### | |
# ##### Crud and helper methods ##### | |
# ################################### | |
@classmethod
def load_jwt_user(cls, db: Session, jwt_user: JWTUser) -> UserInDBT:
    """
    Reads the user referenced by `jwt_user.user_id`

    Calls `raise_auth_error` with "User not found" when `read` returns None.
    """
    loaded = cls.read(db=db, user_id=jwt_user.user_id)
    if loaded is None:
        raise_auth_error(detail="User not found")
    return loaded
@classmethod
def read(cls, db: Session, user_id: UserID) -> Optional[UserInDBT]:
    """
    Fetches the first ORM row whose `user_id` matches and converts it to `in_db_type`

    Returns None when no row matches.
    """
    row = db.query(cls.orm_type).filter(cls.orm_type.user_id == user_id).first()
    if row is None:
        return None
    return cls.in_db_type(**row.dict())
@classmethod
def authenticate(cls, db: Session, username: str, password: RawPassword) -> UserInDBT:
    """
    Looks up the user by `username` and checks `password` against the stored hash

    Calls `raise_auth_error` when no user matches ("User not found") or when the
    password check does not succeed ("Incorrect password"). On success the ORM row
    is converted to `in_db_type`.
    """
    by_username = db.query(cls.orm_type).filter(cls.orm_type.username == username)
    orm_user: Optional[UserOrmT] = by_username.first()
    if orm_user is None:
        raise_auth_error(detail="User not found")

    checker = get_auth_settings().password_checker
    check = checker.check_sync(password, HashedPassword(orm_user.hashed_password))
    if not check.success:
        raise_auth_error(detail="Incorrect password")

    # TODO: handle result.requires_update (using background task, presumably?)
    return cls.in_db_type(**orm_user.dict())
@classmethod
def create_user(cls, db: Session, user_create_request: UserCreateRequestT, is_superuser: bool) -> UserInDBT:
    """
    Hashes the requested password, persists a new ORM user, and returns it as `in_db_type`

    The `create_type` model is built from the request's fields plus the hashed password
    and the `is_superuser` flag; the ORM object is then stored via `add_base`.
    """
    checker = get_auth_settings().password_checker
    hashed = checker.make_sync(user_create_request.password)
    to_create = cls.create_type(
        hashed_password=hashed,
        is_superuser=is_superuser,
        **user_create_request.dict(),
    )
    persisted = add_base(db, cls.orm_type(**to_create.dict()))
    return cls.in_db_type(**persisted.dict())
@classmethod
def update_user(cls, db: Session, user_id: UserID, update_request: UserUpdateT) -> UserInDBT:
    """
    Applies the explicitly-set fields of `update_request` to the user with id `user_id`

    A truthy `update_request.password` is hashed and stored under "hashed_password".
    Raises APIResponseError when there is nothing to update, or (with a 404 error code)
    when the user does not exist. The save runs inside `expected_integrity_error`.
    """
    requested = cls.update_type(**update_request.dict(skip_defaults=True))
    changes: Dict[str, Any] = requested.dict(skip_defaults=True)

    if update_request.password:
        checker = get_auth_settings().password_checker
        changes["hashed_password"] = checker.make_sync(update_request.password)

    if not changes:
        raise APIResponseError("Nothing to update")

    matching = db.query(cls.orm_type).filter(cls.orm_type.user_id == user_id)
    target: Optional[UserOrmT] = matching.first()
    if not target:
        raise APIResponseError("User not found", error_code=HTTP_404_NOT_FOUND)

    for field_name, new_value in changes.items():
        setattr(target, field_name, new_value)

    with expected_integrity_error(db, detail="There was a conflict with an existing user"):
        add_base(db, target)
    return cls.in_db_type.from_orm(target)
# ################## | |
# ##### Routes ##### | |
# ################## | |
@classmethod
def get_router(cls, include_admin_routes: bool) -> APIRouter:
    """
    Builds the combined router: the auth routes always, plus the admin routes under "/admin"
    (with `require_superuser()` as dependencies) when `include_admin_routes` is true

    Both sub-routers are declared and checked against `AuthEndpointName` /
    `AdminAuthEndpointName` regardless of whether the admin routes end up being included.
    """
    api_type = cls.api_type
    auth_router = APIRouter()
    admin_router = APIRouter()

    token_url = get_auth_settings().token_url
    refresh_url = get_auth_settings().refresh_url

    if not TYPE_CHECKING:  # pragma: no cover
        # At runtime, rebind the TypeVar names to this builder's concrete classes so that the
        # handler annotations below refer to real types rather than TypeVars.
        UserCreateRequestT = cls.create_request_type
        UserApiT = cls.api_type
        UserInDBT = cls.in_db_type
        UserUpdateT = cls.update_type

    # NOTE(review): handler names are validated against the endpoint-name enums below, and
    # handler docstrings are kept verbatim since the framework presumably surfaces them as
    # endpoint descriptions -- confirm before renaming or rewording any of them.

    # ----- auth routes -----

    @auth_router.post(
        token_url,
        response_model=AuthTokens,
        response_model_skip_defaults=True,
        response_class=UncachedJSONResponse,
    )
    def login(db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()) -> AuthTokens:
        """
        OAuth2 compatible token login, get an access token for future requests
        """
        raw_password = RawPassword(form_data.password)
        user: UserInDBT = cls.authenticate(db=db, username=form_data.username, password=raw_password)
        tokens = cls.login_flow(db=db, user=user, scopes=form_data.scopes)
        return tokens.to_response()

    @auth_router.get(
        refresh_url, response_model=AuthTokens, response_model_skip_defaults=True, response_class=UncachedJSONResponse
    )
    def refresh(db: Session = Depends(get_db), token: Token = Depends(get_headers_token_openapi)) -> AuthTokens:
        """
        Consume a refresh token to request a new access token
        """
        tokens = cls.refresh_token_flow(db=db, token=token)
        return tokens.to_response()

    # Resolving the `cls.get_user` dependency is the whole check; the body only reports success.
    @auth_router.get(token_url + "/validate", response_model=UtilMessage, dependencies=[Depends(cls.get_user)])
    def validate_token() -> UtilMessage:
        return UtilMessage(detail="Token is valid for user")

    @auth_router.get(token_url + "/logout", response_model=UtilMessage, response_class=UncachedJSONResponse)
    def logout(db: Session = Depends(get_db), token: Token = Depends(get_headers_token_openapi)) -> UtilMessage:
        """
        Invalidate the provided refresh token
        """
        logout_flow(db=db, token=token)
        return UtilMessage(detail="Logged out successfully")

    @auth_router.get(token_url + "/logout/all", response_model=UtilMessage, response_class=UncachedJSONResponse)
    def logout_all(db: Session = Depends(get_db), token: Token = Depends(get_headers_token_openapi)) -> UtilMessage:
        """
        Invalidate all outstanding refresh tokens for the user
        """
        logout_all_flow(db=db, token=token)
        return UtilMessage(detail="Logged out all devices successfully")

    @auth_router.post("/register", response_model=api_type)
    def register(*, db: Session = Depends(get_db), request: AuthRegistrationRequest) -> UserInDBT:
        """
        Create new user without the need to be logged in.
        """
        if not get_auth_settings().users_open_registration:
            raise_permissions_error(detail="User registration is not yet open")
        # Self-registered accounts are never superusers.
        user_create = cls.create_request_type(
            username=request.username,
            password=request.password,
            is_superuser=False,
        )
        with expected_integrity_error(db, detail="This username is already in use"):
            user: UserInDBT = cls.create_user(db=db, user_create_request=user_create, is_superuser=False)
        return user

    @auth_router.get("/self", response_model=api_type)
    def read_self(user: UserInDBT = Depends(cls.get_user)) -> UserInDBT:
        return user

    @auth_router.patch("/self", response_model=api_type)
    def update_self(
        *,
        db: Session = Depends(get_db),
        jwt_user: JWTUser = Depends(get_jwt_user),
        update_request: UserUpdateT,
    ) -> UserInDBT:
        """
        Update a user.
        """
        user: UserInDBT = cls.update_user(db=db, user_id=jwt_user.user_id, update_request=update_request)
        return user

    # ----- admin routes -----

    @admin_router.get("/users/{user_id}", response_model=api_type)
    def read_user(*, db: Session = Depends(get_db), user_id: UserID) -> UserApiT:
        """
        Get a specific user by id.
        """
        user = cls.read(db=db, user_id=user_id)
        if user is None:
            raise APIResponseError("User not found", error_code=HTTP_404_NOT_FOUND)
        return user

    @admin_router.get("/users", response_model=List[api_type])  # type: ignore
    def read_users(db: Session = Depends(get_db), skip: int = 0, limit: int = 100) -> List[UserOrmT]:
        """
        Retrieve users.
        """
        page = db.query(cls.orm_type).offset(skip).limit(limit)
        return page.all()

    @admin_router.post("/users", response_model=api_type)
    def create_user(
        *,
        db: Session = Depends(get_db),
        user_create_request: UserCreateRequestT,
        is_superuser: bool = False,
    ) -> UserInDBT:
        """
        Create new user.
        """
        with expected_integrity_error(db, detail="This username is already in use"):
            user: UserInDBT = cls.create_user(
                db=db,
                user_create_request=user_create_request,
                is_superuser=is_superuser,
            )
        return user

    @admin_router.patch("/users/{user_id}", response_model=api_type)
    def update_user(*, db: Session = Depends(get_db), user_id: UserID, update_request: UserUpdateT) -> UserInDBT:
        """
        Update a user.
        """
        user: UserInDBT = cls.update_user(db=db, user_id=user_id, update_request=update_request)
        return user

    # Make sure every expected endpoint name was actually registered above.
    AuthEndpointName.validate_router(auth_router)
    AdminAuthEndpointName.validate_router(admin_router)

    combined = APIRouter()
    combined.include_router(auth_router, prefix="", tags=["auth"])
    if include_admin_routes:
        combined.include_router(
            admin_router,
            prefix="/admin",
            dependencies=require_superuser(),
            tags=["admin-auth"],
        )
    return combined
@classmethod
def login_flow(cls, db: Session, user: UserInDBT, scopes: List[str]) -> "TokenPair":
    """
    Issues a token pair for `user` with the given `scopes` by delegating to `_create_tokens`
    """
    return _create_tokens(db=db, user=user, scopes=scopes)
@classmethod
def refresh_token_flow(cls, db: Session, token: "Token") -> "TokenPair":
    """
    Exchanges a refresh token for a fresh token pair:

    * Look up the user named in the token payload (auth error if missing)
    * Consume (validate and delete) the provided refresh token
    * Create and return a new token pair carrying the token's scopes
    """
    refreshing_user: Optional[UserInDBT] = cls.read(db=db, user_id=token.payload.user_id)
    if refreshing_user is None:
        raise_auth_error(detail="User not found; try logging in again")

    _consume_refresh_token(db=db, token=token)
    return _create_tokens(db=db, user=refreshing_user, scopes=token.payload.scopes)
@classmethod
def setup_first_superuser(cls, engine: sa.engine.Engine) -> None:
    """
    Creates the superuser described by the auth settings, using a session bound to `engine`

    Asserts that both a username and a password are configured. An IntegrityError from
    the creation attempt is logged as the superuser already existing. The session is
    always closed.
    """
    settings = get_auth_settings()
    username = settings.first_superuser
    if settings.first_superuser_password:
        password = RawPassword(settings.first_superuser_password)
    else:
        password = None
    assert username, f"Invalid superuser username: {username}"
    assert password, f"Invalid superuser password: {password}"

    superuser_request = UserCreateRequest(username=username, password=password)
    session = get_sessionmaker_for_engine(engine)()
    try:
        cls.create_user(db=session, user_create_request=superuser_request, is_superuser=True)
    except IntegrityError:
        logger.info("First superuser already exists.")
    else:
        logger.info("First superuser created.")
    finally:
        session.close()
def remove_expired_tokens(db: Session) -> int:
    """
    Deletes every refresh token whose `exp` is earlier than the current epoch time and commits

    Returns the number of removed expired tokens (e.g., for logging)
    """
    now = get_epoch()
    expired = db.query(RefreshToken).filter(RefreshToken.exp < now)
    # Use the row count reported by the DELETE itself rather than a separate COUNT query:
    # it saves a round trip and cannot disagree with what was actually deleted.
    n_expired_tokens = expired.delete()
    db.commit()
    return n_expired_tokens
def get_epoch() -> int:
    """
    Returns the number of whole seconds since the epoch
    """
    # Imported locally so the function does not rely on the module's import block.
    import time

    # time.time() is already epoch-based, so there is no need to round-trip through a naive
    # UTC datetime (datetime.utcnow() is deprecated as of Python 3.12).
    return int(time.time())
def logout_flow(db: Session, token: "Token") -> None:
    """
    Logs out a single session by consuming (validating and deleting) the provided refresh token
    """
    _consume_refresh_token(token=token, db=db)
def logout_all_flow(db: Session, token: "Token") -> None:
    """
    Consumes the provided refresh token, then deletes every remaining refresh token
    belonging to the user named in its payload
    """
    # The provided token must itself be a valid, unused refresh token.
    _consume_refresh_token(db=db, token=token)

    owned_by_user = db.query(RefreshToken).filter(RefreshToken.user_id == token.payload.user_id)
    owned_by_user.delete(synchronize_session=False)
    db.commit()
def _create_tokens(db: Session, user: UserInDB, scopes: List[str]) -> "TokenPair":
    """
    Generates a token pair for `user` and records its refresh token in the database

    The stored RefreshToken row holds the encoded refresh token together with the
    `user_id` and `exp` taken from its payload.
    """
    pair = generate_tokens(user=user, scopes=scopes)
    refresh = pair.refresh
    stored = RefreshToken(
        token=refresh.encoded,
        user_id=refresh.payload.user_id,
        exp=refresh.payload.exp,
    )
    add_base(db, stored)
    return pair
def _consume_refresh_token(db: Session, token: "Token") -> None:
    """
    Checks that `token` is a refresh token still present in the database, then deletes it and commits

    Calls `raise_auth_error` when the payload subject is not the refresh-token subject,
    or when no matching RefreshToken row exists.
    """
    if token.payload.sub != refresh_token_jwt_subject:
        raise_auth_error(detail="Provided token was not a refresh token")

    # Decoding alone is not enough: the token must still be stored. This is what keeps
    # previously used / logged out tokens from working.
    stored = db.query(RefreshToken).filter(RefreshToken.token == token.encoded).first()
    if stored is None:
        raise_auth_error(detail="Provided refresh token was invalid")

    db.delete(stored)
    db.commit()
# Any concrete BaseAuthRouterBuilder subclass, regardless of its type parameters.
RouterBuilderType = Type[BaseAuthRouterBuilder[Any, Any, Any, Any, Any, Any]]


def get_auth_app(
    router_builder: RouterBuilderType,
    include_admin_routes: bool,
    openapi_url: Optional[str] = None,
    docs_url: Optional[str] = None,
    redoc_url: Optional[str] = None,
    swagger_ui_oauth2_redirect_url: Optional[str] = None,
    **fastapi_kwargs: Any,
) -> FastAPI:
    """
    Creates a FastAPI app that serves the router produced by `router_builder`

    The four documentation-related URLs default to None here. `debug` falls back to
    `get_api_settings().debug` unless it is supplied in `fastapi_kwargs`; all remaining
    keyword arguments are passed through to FastAPI unchanged.
    """
    api_router = router_builder.get_router(include_admin_routes)
    fastapi_kwargs.setdefault("debug", get_api_settings().debug)

    docs_config = {
        "openapi_url": openapi_url,
        "docs_url": docs_url,
        "redoc_url": redoc_url,
        "swagger_ui_oauth2_redirect_url": swagger_ui_oauth2_redirect_url,
    }
    auth_app = FastAPI(**docs_config, **fastapi_kwargs)
    auth_app.include_router(api_router)
    return auth_app
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment