use argon2::{ password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHash, PasswordHasher, PasswordVerifier, }; use axum::{extract::State, http::StatusCode, Router}; use axum_extra::{ headers::{authorization::Basic, Authorization}, routing::Resource, TypedHeader, }; use uuid::Uuid; use crate::state::AppState; use self::{error::Error, jwt::JWT}; pub use self::claims::{AccessClaims, RefreshClaims}; pub mod claims; pub mod error; pub mod jwt; pub fn router() -> Router { axum::Router::new().merge(Resource::named("users").index(issue).create(create)) } pub async fn issue( State(state): State, TypedHeader(Authorization(basic)): TypedHeader>, ) -> Result<(AccessClaims, RefreshClaims), Error> { let uuid = Uuid::try_parse(basic.username())?; let p: String = sqlx::query_scalar!("SELECT password_hash FROM credential WHERE id = $1", uuid) .fetch_optional(&state.pool) .await? .ok_or(Error::LoginInvalid)?; Argon2::default().verify_password(basic.password().as_bytes(), &PasswordHash::new(&p)?)?; let refresh = RefreshClaims::new(uuid); let access = refresh.refresh(); Ok((access, refresh)) } pub async fn create( State(state): State, TypedHeader(Authorization(basic)): TypedHeader>, ) -> Result<(StatusCode, (AccessClaims, RefreshClaims)), Error> { let salt = SaltString::generate(&mut OsRng); let password_hash = Argon2::default().hash_password(basic.password().as_bytes(), &salt)?; let uuid = sqlx::query!( "INSERT INTO credential (password_hash) VALUES ($1) RETURNING id", password_hash.to_string() ) .fetch_optional(&state.pool) .await? .ok_or(Error::Registration)? .id; let refresh = RefreshClaims::new(uuid); let access = refresh.refresh(); Ok((StatusCode::CREATED, (access, refresh))) } pub async fn refresh(claims: RefreshClaims) -> AccessClaims { claims.refresh() } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, http::{header::AUTHORIZATION, Request, StatusCode}, Router, }; use axum_extra::headers::authorization::Credentials; use sqlx::PgPool; use tower::ServiceExt; use crate::tests::{setup_test_env, TestResult}; #[test] fn test_jwt_encode_decode() -> TestResult { setup_test_env(); let claims = AccessClaims::new(uuid::Uuid::new_v4()); let token = JWT.encode(&claims)?; let decoded = JWT.decode(&token)?.claims; assert_eq!(claims, decoded); Ok(()) } #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] async fn test_issue_ok(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let auth = Authorization::basic( "4c14f795-86f0-4361-a02f-0edb966fb145", "solongandthanksforallthefish", ); let request = Request::builder() .uri("/users") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let response = router.oneshot(request).await?; println!("{response:?}"); assert_eq!(StatusCode::OK, response.status()); Ok(()) } #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] async fn test_issue_unauthorized(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let auth = Authorization::basic("4c14f795-86f0-4361-a02f-0edb966fb145", "hunter2"); let request = Request::builder() .uri("/users") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let response = router.oneshot(request).await?; assert_eq!(StatusCode::UNAUTHORIZED, response.status()); Ok(()) } }