From a2860a1294b250402114fa016c4639881abc2172 Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Tue, 16 Apr 2024 20:28:46 -0500 Subject: refactor: improve account extractors --- src/api/account.rs | 64 ++++++++++++++++++++++++++++++++++-------------------- src/api/users.rs | 19 ++++++++-------- src/auth.rs | 40 +++++++++++++++++++++++++++------- src/auth/error.rs | 26 +++++++++++++++++++--- 4 files changed, 106 insertions(+), 43 deletions(-) (limited to 'src') diff --git a/src/api/account.rs b/src/api/account.rs index d6a94b5..0087df7 100644 --- a/src/api/account.rs +++ b/src/api/account.rs @@ -1,13 +1,19 @@ -use axum::{extract::State, routing::get, Router}; +use axum::{ + async_trait, + extract::{FromRequestParts, State}, + http::request::Parts, + routing::get, + RequestPartsExt, Router, +}; use axum_extra::{ + either::Either, extract::{cookie::Cookie, CookieJar}, headers::{authorization::Basic, Authorization}, - typed_header::TypedHeaderRejection, TypedHeader, }; use crate::{ - auth::{AccessClaims, RefreshClaims}, + auth::{AccessClaims, Account, RefreshClaims}, state::AppState, }; @@ -21,32 +27,44 @@ pub fn router() -> Router { pub async fn login( State(state): State, - auth: Result>, TypedHeaderRejection>, - claims: Option, -) -> Result<(AccessClaims, RefreshClaims), Error> { - if let Some(refresh_claims) = claims { - return Ok((refresh_claims.refresh(), refresh_claims)); + auth: Either, +) -> Result<(AccessClaims, RefreshClaims), crate::auth::error::Error> { + match auth { + Either::E1(token) => Ok((token.refresh(), token)), + Either::E2(Login(account)) => crate::auth::issue(State(state.clone()), account).await, } - - let TypedHeader(Authorization(basic)) = auth?; - - let user_id = sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username()) - .fetch_optional(&state.pool) - .await? - .ok_or(Error::UserNotFound)?; - - crate::auth::issue( - State(state.clone()), - TypedHeader(Authorization::basic(&user_id.to_string(), basic.password())), - ) - .await - .map_err(Into::into) } pub async fn logout(claims: AccessClaims, jar: CookieJar) -> Result { Ok(jar.remove(Cookie::try_from(claims)?)) } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Login(Account); + +#[async_trait] +impl FromRequestParts for Login { + type Rejection = Error; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let TypedHeader(Authorization(basic)) = + parts.extract::>>().await?; + + sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username()) + .fetch_optional(&state.pool) + .await? + .ok_or(Error::UserNotFound) + .map(|id| Account { + id, + password: basic.password().to_string(), + }) + .map(Self) + } +} + #[cfg(test)] mod tests { use super::*; @@ -60,7 +78,7 @@ mod tests { Router, }; - use axum_extra::headers::authorization::Credentials; + use axum_extra::headers::{authorization::Credentials, Authorization}; use http_body_util::BodyExt; use sqlx::PgPool; use tower::ServiceExt; diff --git a/src/api/users.rs b/src/api/users.rs index c8a390d..6ac0bb8 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -11,7 +11,10 @@ use sqlx::FromRow; use time::OffsetDateTime; use uuid::Uuid; -use crate::{auth::AccessClaims, state::AppState}; +use crate::{ + auth::{credentials::Credential, AccessClaims}, + state::AppState, +}; use super::error::Error; @@ -30,7 +33,7 @@ pub struct User { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct RegisterSchema { +pub struct Registration { pub name: String, pub email: String, pub password: String, @@ -38,11 +41,11 @@ pub struct RegisterSchema { pub async fn create( State(state): State, - Json(RegisterSchema { + Json(Registration { name, email, password, - }): Json, + }): Json, ) -> impl IntoResponse { email_address::EmailAddress::from_str(&email)?; @@ -58,11 +61,9 @@ pub async fn create( } // TODO: Move this into a micro service, possibly behind a feature flag. - let (status, (access, refresh)) = crate::auth::credentials::create( - State(state.clone()), - Json(crate::auth::credentials::Credential { password }), - ) - .await?; + let (status, (access, refresh)) = + crate::auth::credentials::create(State(state.clone()), Json(Credential { password })) + .await?; let user = sqlx::query_as!( User, diff --git a/src/auth.rs b/src/auth.rs index 09494fb..6bf0ddf 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,5 +1,10 @@ use argon2::{Argon2, PasswordHash, PasswordVerifier}; -use axum::extract::State; +use axum::{ + async_trait, + extract::{FromRequestParts, State}, + http::request::Parts, + RequestPartsExt, +}; use axum_extra::{ headers::{authorization::Basic, Authorization}, TypedHeader, @@ -28,18 +33,16 @@ pub fn router() -> axum::Router { pub async fn issue( State(state): State, - TypedHeader(Authorization(basic)): TypedHeader>, + Account { id, password }: Account, ) -> Result<(AccessClaims, RefreshClaims), Error> { - let uuid = Uuid::try_parse(basic.username())?; - - let p: String = sqlx::query_scalar!("SELECT password_hash FROM credential WHERE id = $1", uuid) + let p: String = sqlx::query_scalar!("SELECT password_hash FROM credential WHERE id = $1", id) .fetch_optional(&state.pool) .await? .ok_or(Error::LoginInvalid)?; - Argon2::default().verify_password(basic.password().as_bytes(), &PasswordHash::new(&p)?)?; + Argon2::default().verify_password(password.as_bytes(), &PasswordHash::new(&p)?)?; - let refresh = RefreshClaims::issue(uuid); + let refresh = RefreshClaims::issue(id); let access = refresh.refresh(); Ok((access, refresh)) } @@ -48,6 +51,27 @@ pub async fn refresh(claims: RefreshClaims) -> AccessClaims { claims.refresh() } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Account { + pub id: Uuid, + pub password: String, +} + +#[async_trait] +impl FromRequestParts for Account { + type Rejection = Error; + + async fn from_request_parts(parts: &mut Parts, _: &AppState) -> Result { + let TypedHeader(Authorization(basic)) = + parts.extract::>>().await?; + + Ok(Self { + id: Uuid::try_parse(basic.username())?, + password: basic.password().to_string(), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -57,7 +81,7 @@ mod tests { http::{header::AUTHORIZATION, Request, StatusCode}, Router, }; - use axum_extra::headers::authorization::Credentials; + use axum_extra::headers::{authorization::Credentials, Authorization}; use sqlx::PgPool; use tower::ServiceExt; diff --git a/src/auth/error.rs b/src/auth/error.rs index 17cf6d1..3a111ca 100644 --- a/src/auth/error.rs +++ b/src/auth/error.rs @@ -1,7 +1,13 @@ #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Failed to parse header: {0}")] - Header(#[from] axum::http::header::InvalidHeaderValue), + HeaderValue(#[from] axum::http::header::InvalidHeaderValue), + + #[error("Required header not found: {0}")] + HeaderNotFound(axum::http::HeaderName), + + #[error("Failed to parse header: {0} (wrong token type?)")] + HeaderRejection(axum_extra::typed_header::TypedHeaderRejection), #[error("Database error: {0}")] Sqlx(#[from] sqlx::Error), @@ -37,6 +43,16 @@ pub enum Error { UserNotFound, } +impl From for Error { + fn from(value: axum_extra::typed_header::TypedHeaderRejection) -> Self { + if value.is_missing() { + Self::HeaderNotFound(value.name().clone()) + } else { + Self::HeaderRejection(value) + } + } +} + impl From for Error { fn from(value: argon2::password_hash::Error) -> Self { match value { @@ -73,10 +89,14 @@ impl From for Error { impl axum::response::IntoResponse for Error { fn into_response(self) -> axum::response::Response { - use axum::http::StatusCode; + use axum::http::{header::AUTHORIZATION, StatusCode}; let status = match self { - Error::JwtFormat(_) | Error::Uuid(_) => StatusCode::UNPROCESSABLE_ENTITY, + Self::HeaderNotFound(ref h) if h == AUTHORIZATION => StatusCode::UNAUTHORIZED, + Self::HeaderNotFound(_) => StatusCode::BAD_REQUEST, + Self::HeaderRejection(_) | Error::JwtFormat(_) | Error::Uuid(_) => { + StatusCode::UNPROCESSABLE_ENTITY + } Error::JwtValidation(_) | Error::LoginInvalid | Error::UserNotFound -- cgit v1.2.3-70-g09d2