use argon2::{ password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHash, PasswordHasher, PasswordVerifier, }; use axum::{ extract::State, http::StatusCode, response::IntoResponse, routing::{get, post}, }; use axum_extra::{ headers::{authorization::Basic, Authorization}, TypedHeader, }; use uuid::Uuid; use crate::state::AppState; use self::{error::Error, jwt::JWT}; pub use self::claims::{AccessClaims, RefreshClaims}; pub mod claims; pub mod error; pub mod jwt; pub fn router(state: AppState) -> axum::Router { axum::Router::new() .route("/create", post(create)) .route("/issue", get(issue)) .route("/refresh", get(refresh)) .with_state(state) } pub async fn create( State(state): State, TypedHeader(Authorization(basic)): TypedHeader>, ) -> Result<(StatusCode, (AccessClaims, RefreshClaims)), Error> { let uuid = Uuid::try_parse(basic.username())?; let salt = SaltString::generate(&mut OsRng); let password_hash = Argon2::default().hash_password(basic.password().as_bytes(), &salt)?; let rows_affected = sqlx::query("INSERT INTO credentials (uuid,password_hash) VALUES ($1, $2)") .bind(uuid) .bind(password_hash.to_string()) .execute(&state.pool) .await? .rows_affected(); if rows_affected == 0 { Err(Error::Registration) } else { Ok(( StatusCode::CREATED, issue(State(state), TypedHeader(Authorization(basic))).await?, )) } } pub async fn issue( State(state): State, TypedHeader(Authorization(basic)): TypedHeader>, ) -> Result<(AccessClaims, RefreshClaims), Error> { let uuid = basic.username().try_into()?; let p: String = sqlx::query_scalar("SELECT password_hash FROM credentials WHERE uuid = $1") .bind(uuid) .fetch_optional(&state.pool) .await? .ok_or(Error::LoginInvalid)?; Argon2::default().verify_password(basic.password().as_bytes(), &PasswordHash::new(&p)?)?; let claims = RefreshClaims::new(uuid); Ok((claims.refresh(), claims)) } pub async fn refresh(claims: RefreshClaims) -> impl IntoResponse { claims.refresh() } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, http::{header::AUTHORIZATION, Request, StatusCode}, }; use axum_extra::headers::authorization::Credentials; use sqlx::PgPool; use tower::ServiceExt; use crate::tests::{setup_test_env, TestResult}; #[test] fn test_jwt_encode_decode() -> TestResult { setup_test_env(); let claims = AccessClaims::new(uuid::Uuid::new_v4()); let token = JWT.encode(&claims)?; let decoded = JWT.decode(&token)?.claims; assert_eq!(claims, decoded); Ok(()) } #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] async fn test_issue_ok(pool: PgPool) -> TestResult { setup_test_env(); let router = router(AppState { pool }); let auth = Authorization::basic( "4c14f795-86f0-4361-a02f-0edb966fb145", "solongandthanksforallthefish", ); let request = Request::builder() .uri("/issue") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let response = router.oneshot(request).await?; println!("{response:?}"); assert_eq!(StatusCode::OK, response.status()); Ok(()) } #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] async fn test_issue_unauthorized(pool: PgPool) -> TestResult { setup_test_env(); let router = router(AppState { pool }); let auth = Authorization::basic("4c14f795-86f0-4361-a02f-0edb966fb145", "hunter2"); let request = Request::builder() .uri("/issue") .method("GET") .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; let response = router.oneshot(request).await?; assert_eq!(StatusCode::UNAUTHORIZED, response.status()); Ok(()) } }