From b383010105b79fcd4e9d671ce82f5c04d0fc9b13 Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Tue, 7 May 2024 16:45:13 -0500 Subject: fix(api): improve user create flow and move ... ...claim based tests from user module into claims module. --- src/api/error.rs | 19 ++++- src/api/users.rs | 184 +++++++++++++++--------------------------------- src/auth/claims.rs | 79 +++++++++++++++++++++ src/auth/credentials.rs | 6 +- 4 files changed, 156 insertions(+), 132 deletions(-) (limited to 'src') diff --git a/src/api/error.rs b/src/api/error.rs index 2af7228..10b5468 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -1,7 +1,7 @@ #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Database error: {0}")] - Sqlx(#[from] sqlx::Error), + Sqlx(#[source] sqlx::Error), #[error("Route not found: {0}")] RouteNotFound(axum::http::Uri), @@ -22,7 +22,7 @@ pub enum Error { InvalidToken, #[error("User with that email already exists")] - EmailExists, + UserExists, #[error("Invalid email: {0}")] EmailInvalid(#[from] email_address::Error), @@ -34,6 +34,19 @@ pub enum Error { Auth(#[from] crate::auth::error::Error), } +impl From for Error { + fn from(value: sqlx::Error) -> Self { + match value { + sqlx::Error::Database(db_err) + if db_err.is_unique_violation() && db_err.table().is_some_and(|s| s == "user_") => + { + Error::UserExists + } + err => Error::Sqlx(err), + } + } +} + impl From for Error { fn from(value: axum_extra::typed_header::TypedHeaderRejection) -> Self { if value.is_missing() { @@ -53,7 +66,7 @@ impl axum::response::IntoResponse for Error { Self::RouteNotFound(_) | Self::UserNotFound | Self::TaskNotFound => { StatusCode::NOT_FOUND } - Self::EmailExists => StatusCode::CONFLICT, + Self::UserExists => StatusCode::CONFLICT, Self::InvalidToken => StatusCode::UNAUTHORIZED, Self::HeaderNotFound(ref h) if h == AUTHORIZATION => StatusCode::UNAUTHORIZED, Self::HeaderNotFound(_) => StatusCode::BAD_REQUEST, diff --git a/src/api/users.rs b/src/api/users.rs index e07bf7e..d4e5d57 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -46,36 +46,29 @@ pub async fn create( email, password, }): Json, -) -> impl IntoResponse { +) -> Result { email_address::EmailAddress::from_str(&email)?; - let exists: Option = sqlx::query_scalar!( - "SELECT EXISTS(SELECT 1 FROM user_ WHERE email = $1 LIMIT 1)", - email.to_ascii_lowercase() - ) - .fetch_one(&pool) - .await?; - - if exists.is_some_and(|b| b) { - return Err(Error::EmailExists); - } - - // TODO: Move this into a micro service, possibly behind a feature flag. - let (status, (access, refresh)) = - crate::auth::credentials::create(State(pool.clone()), Json(Credential { password })) - .await?; - let user = sqlx::query_as!( User, - "INSERT INTO user_ (id,name,email) VALUES ($1, $2, $3) RETURNING *", - refresh.sub, + "INSERT INTO user_ (name,email) VALUES ($1, $2) RETURNING *", name, email.to_ascii_lowercase(), ) .fetch_one(&pool) .await?; - Ok((status, access, refresh, Json(user))) + // TODO: Move this into a micro service, possibly behind a feature flag. + crate::auth::credentials::create( + State(pool.clone()), + Json(Credential { + id: user.id, + password, + }), + ) + .await + .map(|(status, claims)| (status, claims, Json(user))) + .map_err(Into::into) } pub async fn show( @@ -116,30 +109,31 @@ mod tests { const USER_EMAIL: &str = "adent@earth.sol"; const USER_PASSWORD: &str = "solongandthanksforallthefish"; - #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_get_ok_self(pool: PgPool) -> TestResult { + #[sqlx::test] + async fn test_users_post_created(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); + let user = serde_json::json!( { + "name": USER_NAME, + "email": USER_EMAIL, + "password": USER_PASSWORD, + }); + let request = Request::builder() - .uri(format!("/users/{}", USER_ID)) - .header( - COOKIE, - AccessClaims::issue(USER_ID).as_cookie()?.to_string(), - ) - .body(Body::empty())?; + .uri("/users") + .method("POST") + .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user)?))?; let response = router.oneshot(request).await?; - assert_eq!(StatusCode::OK, response.status()); + assert_eq!(StatusCode::CREATED, response.status()); let body_bytes = response.into_body().collect().await?.to_bytes(); - let User { - id, name, email, .. - } = serde_json::from_slice(&body_bytes)?; + let User { name, email, .. } = serde_json::from_slice(&body_bytes)?; - assert_eq!(USER_ID, id); assert_eq!(USER_NAME, name); assert_eq!(USER_EMAIL, email); @@ -147,7 +141,32 @@ mod tests { } #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_get_ok_other(pool: PgPool) -> TestResult { + async fn test_users_post_conflict(pool: PgPool) -> TestResult { + setup_test_env(); + + let router = Router::new().merge(router()).with_state(AppState { pool }); + + let user = serde_json::json!( { + "name": USER_NAME, + "email": USER_EMAIL, + "password": USER_PASSWORD, + }); + + let request = Request::builder() + .uri("/users") + .method("POST") + .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user)?))?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::CONFLICT, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_users_get_ok(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); @@ -179,7 +198,7 @@ mod tests { } #[sqlx::test] - async fn test_get_not_found(pool: PgPool) -> TestResult { + async fn test_users_get_not_found(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); @@ -188,7 +207,9 @@ mod tests { .uri(format!("/users/{}", USER_ID)) .header( COOKIE, - AccessClaims::issue(USER_ID).as_cookie()?.to_string(), + AccessClaims::issue(uuid::Uuid::new_v4()) + .as_cookie()? + .to_string(), ) .body(Body::empty())?; @@ -198,95 +219,4 @@ mod tests { Ok(()) } - - #[sqlx::test] - async fn test_get_unauthorized_invalid_token_format(pool: PgPool) -> TestResult { - setup_test_env(); - - let router = Router::new().merge(router()).with_state(AppState { pool }); - - let request = Request::builder() - .uri(format!("/users/{}", USER_ID)) - .header(COOKIE, "token=sadfasdfsdfs") - .body(Body::empty())?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::UNPROCESSABLE_ENTITY, response.status()); - - Ok(()) - } - - #[sqlx::test] - async fn test_get_unauthorized_missing_token(pool: PgPool) -> TestResult { - setup_test_env(); - - let router = Router::new().merge(router()).with_state(AppState { pool }); - - let request = Request::builder() - .uri(format!("/users/{}", USER_ID)) - .body(Body::empty())?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::UNAUTHORIZED, response.status()); - - Ok(()) - } - - #[sqlx::test] - async fn test_post_created(pool: PgPool) -> TestResult { - setup_test_env(); - - let router = Router::new().merge(router()).with_state(AppState { pool }); - - let user = serde_json::json!( { - "name": USER_NAME, - "email": USER_EMAIL, - "password": USER_PASSWORD, - }); - - let request = Request::builder() - .uri("/users") - .method("POST") - .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Body::from(serde_json::to_vec(&user)?))?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::CREATED, response.status()); - - let body_bytes = response.into_body().collect().await?.to_bytes(); - let User { name, email, .. } = serde_json::from_slice(&body_bytes)?; - - assert_eq!(USER_NAME, name); - assert_eq!(USER_EMAIL, email); - - Ok(()) - } - - #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_post_conflict(pool: PgPool) -> TestResult { - setup_test_env(); - - let router = Router::new().merge(router()).with_state(AppState { pool }); - - let user = serde_json::json!( { - "name": USER_NAME, - "email": USER_EMAIL, - "password": USER_PASSWORD, - }); - - let request = Request::builder() - .uri("/users") - .method("POST") - .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Body::from(serde_json::to_vec(&user)?))?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::CONFLICT, response.status()); - - Ok(()) - } } diff --git a/src/auth/claims.rs b/src/auth/claims.rs index 652a47f..6940844 100644 --- a/src/auth/claims.rs +++ b/src/auth/claims.rs @@ -210,3 +210,82 @@ where JWT.decode(bearer.token()).map(|jwt| jwt.claims) } } + +#[cfg(test)] +mod tests { + use super::*; + + use axum::{ + body::Body, + http::{header::COOKIE, Request, StatusCode}, + }; + + use tower::ServiceExt; + + use crate::tests::{setup_test_env, TestResult}; + + pub fn router() -> axum::Router<()> { + use axum::routing::get; + + axum::Router::new().route( + "/test", + get(|_: AccessClaims| async { axum::http::StatusCode::OK }), + ) + } + + #[tokio::test] + async fn test_claims_get_ok() -> TestResult { + setup_test_env(); + + let router = router(); + + let request = Request::builder() + .uri("/test") + .header( + COOKIE, + AccessClaims::issue(uuid::Uuid::new_v4()) + .as_cookie()? + .to_string(), + ) + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::OK, response.status()); + + Ok(()) + } + + #[tokio::test] + async fn test_claims_get_unauthorized() -> TestResult { + setup_test_env(); + + let router = router(); + + let request = Request::builder().uri("/test").body(Body::empty())?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[tokio::test] + async fn test_claims_get_unprocessable_entity() -> TestResult { + setup_test_env(); + + let router = router(); + + let request = Request::builder() + .uri("/test") + .header(COOKIE, "token=sadfasdfsdfs") + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::UNPROCESSABLE_ENTITY, response.status()); + + Ok(()) + } +} diff --git a/src/auth/credentials.rs b/src/auth/credentials.rs index 2ba3f29..718749f 100644 --- a/src/auth/credentials.rs +++ b/src/auth/credentials.rs @@ -19,6 +19,7 @@ use super::{error::Error, AccessClaims, RefreshClaims}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Credential { + pub id: Uuid, pub password: String, } @@ -30,13 +31,14 @@ pub fn router() -> Resource { pub async fn create( State(pool): State, - Json(Credential { password }): Json, + Json(Credential { id, password }): Json, ) -> Result<(StatusCode, (AccessClaims, RefreshClaims)), Error> { let salt = SaltString::generate(&mut OsRng); let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?; let uuid = sqlx::query!( - "INSERT INTO credential (password_hash) VALUES ($1) RETURNING id", + "INSERT INTO credential (id, password_hash) VALUES ($1, $2) RETURNING id", + id, password_hash.to_string() ) .fetch_optional(&pool) -- cgit v1.2.3-70-g09d2