From a2860a1294b250402114fa016c4639881abc2172 Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Tue, 16 Apr 2024 20:28:46 -0500 Subject: refactor: improve account extractors --- src/api/account.rs | 64 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 23 deletions(-) (limited to 'src/api/account.rs') diff --git a/src/api/account.rs b/src/api/account.rs index d6a94b5..0087df7 100644 --- a/src/api/account.rs +++ b/src/api/account.rs @@ -1,13 +1,19 @@ -use axum::{extract::State, routing::get, Router}; +use axum::{ + async_trait, + extract::{FromRequestParts, State}, + http::request::Parts, + routing::get, + RequestPartsExt, Router, +}; use axum_extra::{ + either::Either, extract::{cookie::Cookie, CookieJar}, headers::{authorization::Basic, Authorization}, - typed_header::TypedHeaderRejection, TypedHeader, }; use crate::{ - auth::{AccessClaims, RefreshClaims}, + auth::{AccessClaims, Account, RefreshClaims}, state::AppState, }; @@ -21,32 +27,44 @@ pub fn router() -> Router { pub async fn login( State(state): State, - auth: Result>, TypedHeaderRejection>, - claims: Option, -) -> Result<(AccessClaims, RefreshClaims), Error> { - if let Some(refresh_claims) = claims { - return Ok((refresh_claims.refresh(), refresh_claims)); + auth: Either, +) -> Result<(AccessClaims, RefreshClaims), crate::auth::error::Error> { + match auth { + Either::E1(token) => Ok((token.refresh(), token)), + Either::E2(Login(account)) => crate::auth::issue(State(state.clone()), account).await, } - - let TypedHeader(Authorization(basic)) = auth?; - - let user_id = sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username()) - .fetch_optional(&state.pool) - .await? - .ok_or(Error::UserNotFound)?; - - crate::auth::issue( - State(state.clone()), - TypedHeader(Authorization::basic(&user_id.to_string(), basic.password())), - ) - .await - .map_err(Into::into) } pub async fn logout(claims: AccessClaims, jar: CookieJar) -> Result { Ok(jar.remove(Cookie::try_from(claims)?)) } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Login(Account); + +#[async_trait] +impl FromRequestParts for Login { + type Rejection = Error; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let TypedHeader(Authorization(basic)) = + parts.extract::>>().await?; + + sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username()) + .fetch_optional(&state.pool) + .await? + .ok_or(Error::UserNotFound) + .map(|id| Account { + id, + password: basic.password().to_string(), + }) + .map(Self) + } +} + #[cfg(test)] mod tests { use super::*; @@ -60,7 +78,7 @@ mod tests { Router, }; - use axum_extra::headers::authorization::Credentials; + use axum_extra::headers::{authorization::Credentials, Authorization}; use http_body_util::BodyExt; use sqlx::PgPool; use tower::ServiceExt; -- cgit v1.2.3-70-g09d2