Skip to content

Instantly share code, notes, and snippets.

@tropicbliss
Last active April 16, 2022 14:22
Show Gist options
  • Save tropicbliss/6af3d5a1aee86947df13199ec3d20ce4 to your computer and use it in GitHub Desktop.
Save tropicbliss/6af3d5a1aee86947df13199ec3d20ce4 to your computer and use it in GitHub Desktop.
Using JWT in Rust
use axum::{
async_trait,
extract::{FromRequest, RequestParts, TypedHeader},
handler::Handler,
headers::{authorization::Bearer, Authorization},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Extension, Json, Router,
};
use bcrypt::DEFAULT_COST;
use dotenv::dotenv;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::sync::RwLock;
use validator::Validate;
/// Process-wide JWT signing and verification keys, built lazily on first access.
///
/// The secret is taken from the `JWT_SECRET` environment variable. `KEYS` is only
/// touched from request-handling code, which runs after `main` has called
/// `dotenv().ok()`, so a value supplied through a `.env` file is visible here.
///
/// SECURITY: when `JWT_SECRET` is unset or empty the original development secret
/// is used so existing setups keep working. That literal is public; anyone who
/// knows it can forge tokens. Always set `JWT_SECRET` outside local development.
static KEYS: Lazy<Keys> = Lazy::new(|| {
    let secret = std::env::var("JWT_SECRET")
        .ok()
        // An empty secret would make every token trivially forgeable; treat it as unset.
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| String::from("lorem ipsum dolor sit amet"));
    Keys::new(secret.as_bytes())
});
/// Entry point: builds the router and serves it on `127.0.0.1:3000`.
///
/// The result of `dotenv()` is discarded, so a failure to load a `.env` file is
/// not an error. The account map starts empty and is handed to the handlers
/// through an `Extension` layer. Panics if the server future resolves to an error.
#[tokio::main]
async fn main() {
    dotenv().ok();

    // Account store shared by all handlers.
    let shared_accounts = Arc::new(RwLock::new(HashMap::<String, User>::new()));

    let router = Router::new()
        .route("/register", post(register))
        .route("/login", post(login))
        .route("/protected", get(protected))
        .layer(Extension(shared_accounts))
        .fallback(handler_404.into_service());

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);

    let server = axum::Server::bind(&addr).serve(router.into_make_service());
    server.await.unwrap();
}
#[derive(Serialize, Clone)]
struct User {
name: String,
hashed_password: String,
hint: Option<String>,
}
/// Shared handle to the account store: a `HashMap<String, User>` behind a Tokio
/// `RwLock`, wrapped in an `Arc` so it can be handed to handlers via `Extension`.
/// `register` inserts entries keyed by the request's `email`.
type SharedAccounts = Arc<RwLock<HashMap<String, User>>>;
async fn register(
Json(req): Json<RegisterRequest>,
Extension(accounts): Extension<SharedAccounts>,
) -> Result<Json<RegisterResponse>, AuthError> {
req.validate().map_err(|_| AuthError::ValidationError)?;
{
let accounts_data = &accounts.read().await;
if accounts_data.contains_key(&req.email) {
return Err(AuthError::UserAlreadyExists);
}
}
let phash_handle = tokio::task::spawn_blocking(|| bcrypt::hash(req.password, DEFAULT_COST));
let hashed_password = phash_handle
.await
.map_err(|_| AuthError::InternalServerError)?
.map_err(|_| AuthError::PasswordHashError)?;
let account = User {
name: req.name,
hashed_password,
hint: req.hint,
};
{
let accounts_data = &mut accounts.write().await;
accounts_data.insert(req.email, account);
}
Ok(Json(RegisterResponse { success: true }))
}
async fn protected(
claims: Claims,
Extension(accounts): Extension<SharedAccounts>,
) -> Result<Json<User>, AuthError> {
let email = claims.sub;
let accounts_data = &accounts.read().await;
let result = accounts_data.get(&email);
if let Some(user) = result {
Ok(Json(user.clone()))
} else {
Err(AuthError::UserDoesNotExist)
}
}
async fn login(
Json(req): Json<LoginRequest>,
Extension(accounts): Extension<SharedAccounts>,
) -> Result<Json<AuthBody>, AuthError> {
req.validate().map_err(|_| AuthError::ValidationError)?;
let result = {
let accounts_data = &accounts.read().await;
accounts_data.get(&req.email).cloned()
};
if let Some(user) = result {
let pverify_handle = tokio::task::spawn_blocking(move || {
bcrypt::verify(req.password, &user.hashed_password)
});
let result = pverify_handle
.await
.map_err(|_| AuthError::InternalServerError)?
.map_err(|_| AuthError::PasswordValidationError)?;
if !result {
return Err(AuthError::WrongCredentials);
}
} else {
return Err(AuthError::WrongCredentials);
}
let claims = Claims {
sub: req.email,
exp: jsonwebtoken::get_current_timestamp() + 86_400,
};
let token = encode(&Header::default(), &claims, &KEYS.encoding)
.map_err(|_| AuthError::TokenCreation)?;
Ok(Json(AuthBody::new(token)))
}
/// Fallback handler for requests that match no route: responds with
/// `404 Not Found` and a short plain-text body.
async fn handler_404() -> impl IntoResponse {
    let status = StatusCode::NOT_FOUND;
    let body = "nothing to see here";
    (status, body)
}
/// JSON body accepted by the `register` handler.
///
/// Because of `#[serde(default)]`, fields missing from the payload are filled
/// from `Default` instead of failing deserialization. The `#[validate]` rules
/// below are only enforced when `validate()` is called on the value.
#[derive(Serialize, Deserialize, Debug, Default, Validate)]
#[serde(default)]
struct RegisterRequest {
// Name stored on the account; length must be between 1 and 27.
#[validate(length(min = 1, max = 27))]
name: String,
// Must be a valid email address; used as the key in the account map.
#[validate(email)]
email: String,
// Plain-text password, hashed with bcrypt in `register`; length must be between 8 and 27.
#[validate(length(min = 8, max = 27))]
password: String,
// Optional hint text; length must be between 1 and 27.
// NOTE(review): presumably the length rule is only applied when this is `Some` - confirm against the validator docs.
#[validate(length(min = 1, max = 27))]
hint: Option<String>,
}
/// JSON body accepted by the `login` handler.
///
/// Because of `#[serde(default)]`, fields missing from the payload are filled
/// from `Default` instead of failing deserialization. The `#[validate]` rules
/// below are only enforced when `validate()` is called on the value.
#[derive(Serialize, Deserialize, Debug, Default, Validate)]
#[serde(default)]
struct LoginRequest {
// Must be a valid email address; used to look up the account.
#[validate(email)]
email: String,
// Plain-text password checked against the stored bcrypt hash; same length bounds as `RegisterRequest::password`.
#[validate(length(min = 8, max = 27))]
password: String,
}
/// JSON body returned by the `register` handler on success.
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default)]
struct RegisterResponse {
// `register` always sets this to `true`; failures are reported through `AuthError` instead.
success: bool,
}
/// JSON body returned by a successful login.
#[derive(Debug, Serialize)]
struct AuthBody {
    /// The encoded JWT.
    access_token: String,
    /// Token scheme; always `"Bearer"` when built through [`AuthBody::new`].
    token_type: String,
}

impl AuthBody {
    /// Wraps `access_token` in a response body with the `Bearer` token type.
    fn new(access_token: String) -> Self {
        let token_type = String::from("Bearer");
        Self {
            access_token,
            token_type,
        }
    }
}
/// Every failure the handlers and the `Claims` extractor can report.
/// Its `IntoResponse` impl maps each variant to an HTTP status and a JSON error message.
#[derive(Debug)]
enum AuthError {
/// `login`: the email is unknown or the password does not match the stored hash.
WrongCredentials,
/// `login`: encoding the JWT failed.
TokenCreation,
/// `Claims` extractor: the bearer `Authorization` header could not be extracted or the token failed to decode.
InvalidToken,
/// `register` / `login`: the request payload failed `validate()`.
ValidationError,
/// `register`: an account with this email is already stored.
UserAlreadyExists,
/// `protected`: no account is stored under the token's `sub`.
UserDoesNotExist,
/// `register` / `login`: the `spawn_blocking` task could not be joined.
InternalServerError,
/// `register`: `bcrypt::hash` returned an error.
PasswordHashError,
/// `login`: `bcrypt::verify` returned an error (as opposed to a mismatch).
PasswordValidationError,
}
impl IntoResponse for AuthError {
    /// Converts the error into an HTTP response whose body is
    /// `{"error": "<message>"}` and whose status depends on the variant.
    ///
    /// Client-caused failures map to `401`/`400`; failures inside the server
    /// (token encoding, task join, bcrypt errors) map to `500`.
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
            AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
            AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
            // Returned by both `register` and `login`, so the message must not
            // name a specific endpoint.
            AuthError::ValidationError => (StatusCode::BAD_REQUEST, "Invalid request payload"),
            AuthError::UserAlreadyExists => (StatusCode::BAD_REQUEST, "User already exists"),
            AuthError::UserDoesNotExist => (StatusCode::BAD_REQUEST, "User does not exist"),
            AuthError::InternalServerError => {
                (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
            }
            AuthError::PasswordHashError => {
                (StatusCode::INTERNAL_SERVER_ERROR, "Password hash error")
            }
            AuthError::PasswordValidationError => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Password validation error",
            ),
        };
        let body = Json(json!({
            "error": error_message,
        }));
        (status, body).into_response()
    }
}
/// Payload carried inside the JWT. Produced by `login` and recovered from the
/// bearer token by the `FromRequest` impl for this type.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
// Subject: `login` sets this to the account email; `protected` uses it as the lookup key.
sub: String,
// Expiry: `login` sets this to `jsonwebtoken::get_current_timestamp() + 86_400`.
// NOTE(review): presumably seconds since the Unix epoch, i.e. a 24-hour lifetime - confirm against the jsonwebtoken docs.
exp: u64,
}
/// Pair of JWT keys derived from one shared secret.
struct Keys {
    /// Key used when encoding (signing) tokens.
    encoding: EncodingKey,
    /// Key used when decoding (verifying) tokens.
    decoding: DecodingKey,
}

impl Keys {
    /// Builds both keys from the same `secret` bytes.
    fn new(secret: &[u8]) -> Self {
        let encoding = EncodingKey::from_secret(secret);
        let decoding = DecodingKey::from_secret(secret);
        Self { encoding, decoding }
    }
}
/// Extractor that turns a request's bearer token into [`Claims`].
///
/// Any handler taking a `Claims` argument is therefore only invoked when the
/// request carries a bearer token that decodes with `KEYS.decoding` under the
/// default validation settings.
#[async_trait]
impl<B> FromRequest<B> for Claims
where
    B: Send,
{
    type Rejection = AuthError;

    /// Reads the `Authorization: Bearer` header and decodes its token.
    ///
    /// Returns `AuthError::InvalidToken` both when the header cannot be
    /// extracted and when decoding the token fails.
    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
        let header = TypedHeader::<Authorization<Bearer>>::from_request(req)
            .await
            .map_err(|_| AuthError::InvalidToken)?;
        let TypedHeader(Authorization(bearer)) = header;

        let validation = Validation::default();
        decode::<Claims>(bearer.token(), &KEYS.decoding, &validation)
            .map(|token_data| token_data.claims)
            .map_err(|_| AuthError::InvalidToken)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment