use std::{str::FromStr, sync::Arc}; use argon2::{ password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHasher, }; use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; use axum_extra::routing::TypedPath; use serde::Deserialize; use crate::{ model::{RegisterSchema, User}, state::AppState, Error, }; #[derive(Debug, Deserialize, TypedPath)] #[typed_path("/api/register")] pub struct Register; impl Register { #[tracing::instrument(skip(password))] pub async fn post( self, State(state): State>, Json(RegisterSchema { name, email, password, }): Json, ) -> impl IntoResponse { email_address::EmailAddress::from_str(&email)?; let exists: Option = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)") .bind(email.to_ascii_lowercase()) .fetch_one(&state.pool) .await?; if exists.is_some_and(|b| b) { return Err(Error::EmailExists); } let salt = SaltString::generate(&mut OsRng); let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?; let user = sqlx::query_as!( User, "INSERT INTO users (name,email,password_hash) VALUES ($1, $2, $3) RETURNING *", name, email.to_ascii_lowercase(), password_hash.to_string() ) .fetch_one(&state.pool) .await?; Ok((StatusCode::CREATED, Json(user))) } } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, http::{header, Request, StatusCode}, }; use http_body_util::BodyExt; use sqlx::PgPool; use tower::ServiceExt; use crate::init_router; const JWT_SECRET: &str = "test-jwt-secret-token"; const JWT_MAX_AGE: time::Duration = time::Duration::HOUR; type TestResult> = std::result::Result; #[sqlx::test] async fn test_register_created(pool: PgPool) -> TestResult { let state = Arc::new(AppState { pool, jwt_secret: JWT_SECRET.to_string(), jwt_max_age: JWT_MAX_AGE, }); let router = init_router(state.clone()); let user = RegisterSchema { name: "Arthur Dent".to_string(), email: "adent@earth.sol".to_string(), password: "solongandthanksforallthefish".to_string(), }; let request = Request::builder() .uri("/api/register") .method("POST") .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) .body(Body::from(serde_json::to_vec(&user)?))?; let response = router.oneshot(request).await?; assert_eq!(StatusCode::CREATED, response.status()); let body_bytes = response.into_body().collect().await?.to_bytes(); let User { name, email, .. } = serde_json::from_slice(&body_bytes)?; assert_eq!(user.name, name); assert_eq!(user.email, email); Ok(()) } #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_register_conflict(pool: PgPool) -> TestResult { let state = Arc::new(AppState { pool, jwt_secret: JWT_SECRET.to_string(), jwt_max_age: JWT_MAX_AGE, }); let router = init_router(state.clone()); let user = RegisterSchema { name: "Arthur Dent".to_string(), email: "adent@earth.sol".to_string(), password: "solongandthanksforallthefish".to_string(), }; let request = Request::builder() .uri("/api/register") .method("POST") .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) .body(Body::from(serde_json::to_vec(&user)?))?; let response = router.oneshot(request).await?; assert_eq!(StatusCode::CONFLICT, response.status()); Ok(()) } }