From ce961ca85ba96813ccdca9be1d18ee11e4e0d25c Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Tue, 26 Mar 2024 21:04:02 -0500 Subject: feat: add user database and registration --- src/routes.rs | 167 ++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 144 insertions(+), 23 deletions(-) (limited to 'src/routes.rs') diff --git a/src/routes.rs b/src/routes.rs index 0a81317..0bf34b2 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,22 +1,33 @@ use std::sync::Arc; +use argon2::{ + password_hash::{rand_core::OsRng, SaltString}, + Argon2, PasswordHasher, +}; use axum::{ extract::State, http::{StatusCode, Uri}, + response::IntoResponse, Json, }; use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; -use crate::{model::User, state::AppState, Error}; +use crate::{ + model::{RegisterSchema, User}, + state::AppState, + Error, +}; -pub fn router(state: AppState) -> axum::Router { +#[tracing::instrument] +pub fn router(state: Arc) -> axum::Router { axum::Router::new() // .route("/api/user", get(get_user)) .typed_get(HealthCheck::get) - .typed_get(UserId::get) + .typed_get(UserUuid::get) + .typed_post(Register::post) .fallback(fallback) - .with_state(Arc::new(state)) + .with_state(state) } #[derive(Debug, Deserialize, TypedPath)] @@ -24,7 +35,8 @@ pub fn router(state: AppState) -> axum::Router { pub struct HealthCheck; impl HealthCheck { - pub async fn get(self) -> Json { + #[tracing::instrument] + pub async fn get(self) -> impl IntoResponse { const MESSAGE: &str = "Unnamed server"; let json_response = serde_json::json!({ @@ -38,26 +50,62 @@ impl HealthCheck { #[derive(Debug, Deserialize, TypedPath)] #[typed_path("/api/user/:uuid")] -pub struct UserId { +pub struct UserUuid { pub uuid: uuid::Uuid, } -impl UserId { - /// Get a user via their `id` - #[tracing::instrument(ret, skip(data))] - pub async fn get( - self, - State(data): State>, - ) -> Result, Error> { +impl UserUuid { + /// Get a user with a specific `uuid` + #[tracing::instrument] + pub async fn get(self, State(state): State>) -> impl IntoResponse { sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid) - .fetch_optional(&data.db_pool) + .fetch_optional(&state.pool) .await? - .ok_or_else(|| Error::UserNotFound(self.uuid)) - .map(User::into_query_response) + .ok_or_else(|| Error::UserNotFound) + .map(Json) + } +} + +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/api/user/register")] +pub struct Register; + +impl Register { + #[tracing::instrument(skip(register_schema))] + pub async fn post( + self, + State(state): State>, + Json(register_schema): Json, + ) -> impl IntoResponse { + let exists: Option = + sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)") + .bind(register_schema.email.to_ascii_lowercase()) + .fetch_one(&state.pool) + .await?; + + if exists.is_some_and(|b| b) { + return Err(Error::EmailExists); + } + + let salt = SaltString::generate(&mut OsRng); + let hashed_password = + Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?; + + let user = sqlx::query_as!( + User, + "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *", + register_schema.name, + register_schema.email.to_ascii_lowercase(), + hashed_password.to_string() + ) + .fetch_one(&state.pool) + .await?; + + Ok((StatusCode::CREATED, Json(user))) } } -pub async fn fallback(uri: Uri) -> (StatusCode, String) { +pub async fn fallback(uri: Uri) -> impl IntoResponse { (StatusCode::NOT_FOUND, format!("Route not found: {uri}")) } @@ -65,15 +113,88 @@ pub async fn fallback(uri: Uri) -> (StatusCode, String) { mod tests { use super::*; - use axum_test::TestServer; + use axum::{ + body::Body, + http::{header, Request, StatusCode}, + }; + use http_body_util::BodyExt; + use sqlx::PgPool; + use tower::ServiceExt; + + #[sqlx::test] + async fn test_fallback(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { pool }); + let router = router(state.clone()); + + let response = router + .oneshot( + Request::builder() + .uri("/does-not-exist") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(StatusCode::NOT_FOUND, response.status()); - #[tokio::test] - async fn test_fallback() -> Result<(), Box> { - let server = TestServer::new(axum::Router::new().fallback(fallback))?; + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_user(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { pool }); + let router = router(state.clone()); + + let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1") + .fetch_one(&state.pool) + .await?; - let response = server.get("/fallback").await; + let response = router + .oneshot( + Request::builder() + .uri(format!("/api/user/{}", user.id)) + .body(Body::empty())?, + ) + .await + .unwrap(); + + assert_eq!(StatusCode::OK, response.status()); + + Ok(()) + } - assert_eq!(StatusCode::NOT_FOUND, response.status_code()); + #[sqlx::test] + async fn test_user_register(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { pool }); + let router = router(state.clone()); + + let register_user = RegisterSchema { + name: "Ford Prefect".to_string(), + email: "fprefect@heartofgold.galaxy".to_string(), + password: "42".to_string(), + }; + + let response = router + .oneshot( + Request::builder() + .uri("/api/user/register") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from( + serde_json::to_vec(&serde_json::json!(register_user)).unwrap(), + ))?, + ) + .await + .unwrap(); + + assert_eq!(StatusCode::CREATED, response.status()); + + let body_bytes = response.into_body().collect().await?.to_bytes(); + let user: User = serde_json::from_slice(&body_bytes)?; + + assert_eq!(register_user.name, user.name); + assert_eq!(register_user.email, user.email); Ok(()) } -- cgit v1.2.3-70-g09d2