Last active
April 16, 2022 14:22
-
-
Save tropicbliss/6af3d5a1aee86947df13199ec3d20ce4 to your computer and use it in GitHub Desktop.
Using JWT in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use axum::{ | |
async_trait, | |
extract::{FromRequest, RequestParts, TypedHeader}, | |
handler::Handler, | |
headers::{authorization::Bearer, Authorization}, | |
http::StatusCode, | |
response::{IntoResponse, Response}, | |
routing::{get, post}, | |
Extension, Json, Router, | |
}; | |
use bcrypt::DEFAULT_COST; | |
use dotenv::dotenv; | |
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; | |
use once_cell::sync::Lazy; | |
use serde::{Deserialize, Serialize}; | |
use serde_json::json; | |
use std::{collections::HashMap, net::SocketAddr, sync::Arc}; | |
use tokio::sync::RwLock; | |
use validator::Validate; | |
static KEYS: Lazy<Keys> = Lazy::new(|| { | |
// let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); | |
let secret = "lorem ipsum dolor sit amet"; | |
Keys::new(secret.as_bytes()) | |
}); | |
#[tokio::main] | |
async fn main() { | |
dotenv().ok(); | |
let accounts: HashMap<String, User> = HashMap::new(); | |
let accounts = Arc::new(RwLock::new(accounts)); | |
let app = Router::new() | |
.route("/register", post(register)) | |
.route("/login", post(login)) | |
.route("/protected", get(protected)) | |
.layer(Extension(accounts)) | |
.fallback(handler_404.into_service()); | |
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); | |
tracing::debug!("listening on {}", addr); | |
axum::Server::bind(&addr) | |
.serve(app.into_make_service()) | |
.await | |
.unwrap(); | |
} | |
#[derive(Serialize, Clone)] | |
struct User { | |
name: String, | |
hashed_password: String, | |
hint: Option<String>, | |
} | |
type SharedAccounts = Arc<RwLock<HashMap<String, User>>>; | |
async fn register( | |
Json(req): Json<RegisterRequest>, | |
Extension(accounts): Extension<SharedAccounts>, | |
) -> Result<Json<RegisterResponse>, AuthError> { | |
req.validate().map_err(|_| AuthError::ValidationError)?; | |
{ | |
let accounts_data = &accounts.read().await; | |
if accounts_data.contains_key(&req.email) { | |
return Err(AuthError::UserAlreadyExists); | |
} | |
} | |
let phash_handle = tokio::task::spawn_blocking(|| bcrypt::hash(req.password, DEFAULT_COST)); | |
let hashed_password = phash_handle | |
.await | |
.map_err(|_| AuthError::InternalServerError)? | |
.map_err(|_| AuthError::PasswordHashError)?; | |
let account = User { | |
name: req.name, | |
hashed_password, | |
hint: req.hint, | |
}; | |
{ | |
let accounts_data = &mut accounts.write().await; | |
accounts_data.insert(req.email, account); | |
} | |
Ok(Json(RegisterResponse { success: true })) | |
} | |
async fn protected( | |
claims: Claims, | |
Extension(accounts): Extension<SharedAccounts>, | |
) -> Result<Json<User>, AuthError> { | |
let email = claims.sub; | |
let accounts_data = &accounts.read().await; | |
let result = accounts_data.get(&email); | |
if let Some(user) = result { | |
Ok(Json(user.clone())) | |
} else { | |
Err(AuthError::UserDoesNotExist) | |
} | |
} | |
async fn login( | |
Json(req): Json<LoginRequest>, | |
Extension(accounts): Extension<SharedAccounts>, | |
) -> Result<Json<AuthBody>, AuthError> { | |
req.validate().map_err(|_| AuthError::ValidationError)?; | |
let result = { | |
let accounts_data = &accounts.read().await; | |
accounts_data.get(&req.email).cloned() | |
}; | |
if let Some(user) = result { | |
let pverify_handle = tokio::task::spawn_blocking(move || { | |
bcrypt::verify(req.password, &user.hashed_password) | |
}); | |
let result = pverify_handle | |
.await | |
.map_err(|_| AuthError::InternalServerError)? | |
.map_err(|_| AuthError::PasswordValidationError)?; | |
if !result { | |
return Err(AuthError::WrongCredentials); | |
} | |
} else { | |
return Err(AuthError::WrongCredentials); | |
} | |
let claims = Claims { | |
sub: req.email, | |
exp: jsonwebtoken::get_current_timestamp() + 86_400, | |
}; | |
let token = encode(&Header::default(), &claims, &KEYS.encoding) | |
.map_err(|_| AuthError::TokenCreation)?; | |
Ok(Json(AuthBody::new(token))) | |
} | |
async fn handler_404() -> impl IntoResponse { | |
(StatusCode::NOT_FOUND, "nothing to see here") | |
} | |
/// JSON body accepted by `POST /register`.
///
/// With `#[serde(default)]`, fields missing from the JSON take their values from
/// `RegisterRequest::default()` (empty strings / `None`). The `validator` rules below are then
/// expected to reject the empty required fields; `register` calls `validate()` first.
#[derive(Serialize, Deserialize, Debug, Default, Validate)] | |
#[serde(default)] | |
struct RegisterRequest { | |
    /// Display name; length must be between 1 and 27.
    #[validate(length(min = 1, max = 27))] | |
    name: String, | |
    /// Account key in the store; must pass `validator`'s email check.
    #[validate(email)] | |
    email: String, | |
    /// Plain-text password, length 8 to 27; `register` stores only its bcrypt hash.
    #[validate(length(min = 8, max = 27))] | |
    password: String, | |
    /// Optional password hint, length 1 to 27 (presumably only checked when present; confirm
    /// against the `validator` version in use).
    #[validate(length(min = 1, max = 27))] | |
    hint: Option<String>, | |
} | |
/// JSON body accepted by `POST /login`.
///
/// Missing fields default to empty strings (`#[serde(default)]`). The rules mirror those on
/// `RegisterRequest`, so malformed input is rejected before any account lookup.
#[derive(Serialize, Deserialize, Debug, Default, Validate)] | |
#[serde(default)] | |
struct LoginRequest { | |
    /// Email the account was registered under; used as the store key.
    #[validate(email)] | |
    email: String, | |
    /// Plain-text password, checked against the stored bcrypt hash.
    #[validate(length(min = 8, max = 27))] | |
    password: String, | |
} | |
/// Success body for `POST /register`; `register` always sends `success: true`, and failures
/// are reported as `AuthError` responses instead.
#[derive(Serialize, Deserialize, Debug, Default)] | |
#[serde(default)] | |
struct RegisterResponse { | |
    success: bool, | |
} | |
#[derive(Debug, Serialize)] | |
struct AuthBody { | |
access_token: String, | |
token_type: String, | |
} | |
impl AuthBody { | |
fn new(access_token: String) -> Self { | |
Self { | |
access_token, | |
token_type: "Bearer".to_string(), | |
} | |
} | |
} | |
/// Every failure the handlers and the `Claims` extractor can report.
///
/// The `IntoResponse` impl for this type maps each variant to an HTTP status code and a JSON
/// `{"error": ...}` body.
#[derive(Debug)] | |
enum AuthError { | |
    /// `login` found no account for the email, or the password did not match (both cases
    /// use this one variant).
    WrongCredentials, | |
    /// Signing the JWT in `login` failed.
    TokenCreation, | |
    /// Missing or malformed `Authorization: Bearer` header, or a token that failed to decode
    /// or validate.
    InvalidToken, | |
    /// A request body failed its `validator` rules (used by both `register` and `login`).
    ValidationError, | |
    /// `register` was called with an email that already has an account.
    UserAlreadyExists, | |
    /// `protected` received a valid token whose subject has no account in the store.
    UserDoesNotExist, | |
    /// A `spawn_blocking` bcrypt task could not be joined.
    InternalServerError, | |
    /// bcrypt returned an error while hashing a new password in `register`.
    PasswordHashError, | |
    /// bcrypt returned an error while verifying a password in `login`.
    PasswordValidationError, | |
} | |
impl IntoResponse for AuthError { | |
fn into_response(self) -> Response { | |
let (status, error_message) = match self { | |
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), | |
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), | |
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), | |
AuthError::ValidationError => { | |
(StatusCode::BAD_REQUEST, "Invalid register request payload") | |
} | |
AuthError::UserAlreadyExists => (StatusCode::BAD_REQUEST, "User already exists"), | |
AuthError::UserDoesNotExist => (StatusCode::BAD_REQUEST, "User does not exist"), | |
AuthError::InternalServerError => { | |
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error") | |
} | |
AuthError::PasswordHashError => { | |
(StatusCode::INTERNAL_SERVER_ERROR, "Password hash error") | |
} | |
AuthError::PasswordValidationError => ( | |
StatusCode::INTERNAL_SERVER_ERROR, | |
"Password validation error", | |
), | |
}; | |
let body = Json(json!({ | |
"error": error_message, | |
})); | |
(status, body).into_response() | |
} | |
} | |
/// JWT payload: issued by `login` and recovered from the `Authorization` header by the
/// `FromRequest` impl for this type.
#[derive(Debug, Serialize, Deserialize)] | |
struct Claims { | |
    /// Subject: the email of the authenticated account (set from the login request).
    sub: String, | |
    /// Expiry as a Unix timestamp from `jsonwebtoken::get_current_timestamp()`; `login` sets
    /// it 86_400 ahead, and the extractor's `Validation::default()` is expected to enforce it.
    exp: u64, | |
} | |
struct Keys { | |
encoding: EncodingKey, | |
decoding: DecodingKey, | |
} | |
impl Keys { | |
fn new(secret: &[u8]) -> Self { | |
Self { | |
encoding: EncodingKey::from_secret(secret), | |
decoding: DecodingKey::from_secret(secret), | |
} | |
} | |
} | |
#[async_trait] | |
impl<B> FromRequest<B> for Claims | |
where | |
B: Send, | |
{ | |
type Rejection = AuthError; | |
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { | |
let TypedHeader(Authorization(bearer)) = | |
TypedHeader::<Authorization<Bearer>>::from_request(req) | |
.await | |
.map_err(|_| AuthError::InvalidToken)?; | |
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default()) | |
.map_err(|_| AuthError::InvalidToken)?; | |
Ok(token_data.claims) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment