use axum::{ async_trait, extract::{FromRequestParts, State}, http::request::Parts, routing::get, RequestPartsExt, Router, }; use axum_extra::{ either::Either, extract::CookieJar, headers::{authorization::Basic, Authorization}, TypedHeader, }; use sqlx::PgPool; use crate::{ auth::{AccessClaims, Account, RefreshClaims}, state::AppState, }; use super::error::Error; pub fn router() -> Router { axum::Router::new() .route("/login", get(login)) .route("/logout", get(logout)) } pub async fn login( State(pool): State, auth: Either, ) -> Result<(AccessClaims, RefreshClaims), crate::auth::error::Error> { match auth { Either::E1(token) => Ok((token.refresh(), token)), Either::E2(Login(account)) => crate::auth::issue(State(pool), account).await, } } pub async fn logout(claims: AccessClaims, jar: CookieJar) -> Result { Ok(jar.remove(("token", crate::auth::jwt::JWT.encode(&claims)?))) } #[derive(Debug, Clone, PartialEq, Eq)] pub struct Login(Account); #[async_trait] impl FromRequestParts for Login { type Rejection = Error; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result { let TypedHeader(Authorization(basic)) = parts.extract::>>().await?; sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username()) .fetch_optional(&state.pool) .await? .ok_or(Error::UserNotFound) .map(|id| Account { id, password: basic.password().to_string(), }) .map(Self) } } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, http::{ header::{AUTHORIZATION, COOKIE, SET_COOKIE}, Request, StatusCode, }, Router, }; use axum_extra::headers::{authorization::Credentials, Authorization}; use http_body_util::BodyExt; use tower::ServiceExt; use uuid::Uuid; use crate::{ auth::AccessClaims, tests::{setup_test_env, TestResult}, }; const USER_ID: Uuid = uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"); const USER_EMAIL: &str = "adent@earth.sol"; const USER_PASSWORD: &str = "solongandthanksforallthefish"; #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_login_ok(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let auth = Authorization::basic(USER_EMAIL, USER_PASSWORD); let request = Request::builder() .uri("/login") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let (mut parts, body) = router.oneshot(request).await?.into_parts(); assert_eq!(StatusCode::OK, parts.status); let body_bytes = body.collect().await?.to_bytes(); let body = std::str::from_utf8(&body_bytes)?; let refresh_claims: RefreshClaims = crate::auth::jwt::JWT.decode(body)?.claims; assert_eq!(USER_ID, refresh_claims.sub); let set_cookie = parts .headers .get(SET_COOKIE) .expect("Failed to get set-header cookie"); parts.headers.insert(COOKIE, set_cookie.clone()); let jar = CookieJar::from_headers(&parts.headers); let cookie = jar .get("token") .expect("'token' cookie not found in response cookie jar"); let access_claims: AccessClaims = crate::auth::jwt::JWT.decode(cookie.value())?.claims; assert_eq!(USER_ID, access_claims.sub); Ok(()) } #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_login_unauthorized(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let auth = Authorization::basic(USER_EMAIL, "hunter2"); let request = Request::builder() .uri("/login") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let response = router.oneshot(request).await?; assert_eq!(StatusCode::UNAUTHORIZED, response.status()); Ok(()) } #[sqlx::test] async fn test_login_not_found(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let auth = Authorization::basic(USER_EMAIL, USER_PASSWORD); let request = Request::builder() .uri("/login") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let response = router.oneshot(request).await?; assert_eq!(StatusCode::NOT_FOUND, response.status()); Ok(()) } #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_logout_ok(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let request = Request::builder() .uri("/logout") .method("GET") .header( COOKIE, AccessClaims::issue(USER_ID).as_cookie()?.to_string(), ) .body(Body::empty())?; let (mut parts, _) = router.oneshot(request).await?.into_parts(); assert_eq!(StatusCode::OK, parts.status); let set_cookie = parts .headers .get(SET_COOKIE) .expect("Failed to get set-header cookie"); parts.headers.insert(COOKIE, set_cookie.clone()); let jar = CookieJar::from_headers(&parts.headers); let cookie = jar .get("token") .expect("'token' cookie not found in response cookie jar"); assert_eq!(cookie.value(), ""); assert_eq!(cookie.max_age(), None); Ok(()) } }