diff options
Diffstat (limited to 'src/routes.rs')
-rw-r--r-- | src/routes.rs | 297 |
1 files changed, 243 insertions, 54 deletions
diff --git a/src/routes.rs b/src/routes.rs index 2692f1a..1ec4e30 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,31 +1,36 @@ -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use argon2::{ password_hash::{rand_core::OsRng, SaltString}, - Argon2, PasswordHasher, + Argon2, PasswordHash, PasswordHasher, PasswordVerifier, }; use axum::{ extract::State, - http::{StatusCode, Uri}, + http::{header::SET_COOKIE, StatusCode, Uri}, response::IntoResponse, Json, }; -use axum_extra::routing::{RouterExt, TypedPath}; +use axum_extra::{ + extract::cookie::{Cookie, SameSite}, + routing::{RouterExt, TypedPath}, +}; +use jsonwebtoken::{EncodingKey, Header}; use serde::Deserialize; use crate::{ - model::{RegisterSchema, User}, + model::{LoginSchema, RegisterSchema, TokenClaims, User}, state::AppState, Error, }; #[tracing::instrument] -pub fn router(state: Arc<AppState>) -> axum::Router { +pub fn init_router(state: Arc<AppState>) -> axum::Router { axum::Router::new() // .route("/api/user", get(get_user)) .typed_get(HealthCheck::get) .typed_get(UserUuid::get) .typed_post(Register::post) + .typed_post(Login::post) .fallback(fallback) .with_state(state) } @@ -58,7 +63,7 @@ impl UserUuid { /// Get a user with a specific `uuid` #[tracing::instrument] pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse { - sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid) + sqlx::query_as!(User, "SELECT * FROM users WHERE uuid = $1", self.uuid) .fetch_optional(&state.pool) .await? .ok_or_else(|| Error::UserNotFound) @@ -67,21 +72,25 @@ impl UserUuid { } #[derive(Debug, Deserialize, TypedPath)] -#[typed_path("/api/user/register")] +#[typed_path("/api/register")] pub struct Register; impl Register { - #[tracing::instrument(skip(register_schema))] + #[tracing::instrument(skip(password))] pub async fn post( self, State(state): State<Arc<AppState>>, - Json(register_schema): Json<RegisterSchema>, + Json(RegisterSchema { + name, + email, + password, + }): Json<RegisterSchema>, ) -> impl IntoResponse { - register_schema.validate()?; + email_address::EmailAddress::from_str(&email)?; let exists: Option<bool> = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)") - .bind(register_schema.email.to_ascii_lowercase()) + .bind(email.to_ascii_lowercase()) .fetch_one(&state.pool) .await?; @@ -90,15 +99,14 @@ impl Register { } let salt = SaltString::generate(&mut OsRng); - let hashed_password = - Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?; + let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?; let user = sqlx::query_as!( User, - "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *", - register_schema.name, - register_schema.email.to_ascii_lowercase(), - hashed_password.to_string() + "INSERT INTO users (name,email,password_hash) VALUES ($1, $2, $3) RETURNING *", + name, + email.to_ascii_lowercase(), + password_hash.to_string() ) .fetch_one(&state.pool) .await?; @@ -107,6 +115,56 @@ impl Register { } } +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/api/login")] +pub struct Login; + +impl Login { + #[tracing::instrument(skip(state, password))] + pub async fn post( + self, + State(state): State<Arc<AppState>>, + Json(LoginSchema { email, password }): Json<LoginSchema>, + ) -> Result<impl IntoResponse, Error> { + let User { + uuid, + password_hash, + .. + } = sqlx::query_as!( + User, + "SELECT * FROM users WHERE email = $1", + email.to_ascii_lowercase() + ) + .fetch_optional(&state.pool) + .await? + .ok_or(Error::LoginInvalid)?; + + Argon2::default() + .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?; + + let token = jsonwebtoken::encode( + &Header::default(), + &TokenClaims::new(uuid, state.jwt_max_age), + &EncodingKey::from_secret(state.jwt_secret.as_ref()), + )?; + + let cookie = Cookie::build(("token", token.to_owned())) + .path("/") + .max_age(state.jwt_max_age) + .same_site(SameSite::Lax) + .http_only(true) + .build(); + + let mut response = Json(token).into_response(); + + response + .headers_mut() + .insert(SET_COOKIE, cookie.to_string().parse().unwrap()); + + Ok(response) + } +} + pub async fn fallback(uri: Uri) -> impl IntoResponse { (StatusCode::NOT_FOUND, format!("Route not found: {uri}")) } @@ -123,10 +181,17 @@ mod tests { use sqlx::PgPool; use tower::ServiceExt; + const JWT_SECRET: &str = "test-jwt-secret-token"; + const JWT_MAX_AGE: time::Duration = time::Duration::HOUR; + #[sqlx::test] - async fn test_fallback(pool: PgPool) -> Result<(), Error> { - let state = Arc::new(AppState { pool }); - let router = router(state.clone()); + async fn test_route_not_found(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); let response = router .oneshot( @@ -144,59 +209,183 @@ mod tests { } #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] - async fn test_user(pool: PgPool) -> Result<(), Error> { - let state = Arc::new(AppState { pool }); - let router = router(state.clone()); + async fn test_user_ok(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); - let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1") - .fetch_one(&state.pool) - .await?; + let user = User { + uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"), + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + ..Default::default() + }; - let response = router - .oneshot( - Request::builder() - .uri(format!("/api/user/{}", user.id)) - .body(Body::empty())?, - ) - .await - .unwrap(); + let request = Request::builder() + .uri(format!("/api/user/{}", user.uuid)) + .body(Body::empty())?; + + let response = router.oneshot(request).await.unwrap(); assert_eq!(StatusCode::OK, response.status()); + let body_bytes = response.into_body().collect().await?.to_bytes(); + let User { + uuid, name, email, .. + } = serde_json::from_slice(&body_bytes)?; + + assert_eq!(user.uuid, uuid); + assert_eq!(user.name, name); + assert_eq!(user.email, email); + Ok(()) } #[sqlx::test] - async fn test_user_register(pool: PgPool) -> Result<(), Error> { - let state = Arc::new(AppState { pool }); - let router = router(state.clone()); - - let register_user = RegisterSchema { - name: "Ford Prefect".to_string(), - email: "fprefect@heartofgold.galaxy".to_string(), - password: "42".to_string(), + async fn test_user_not_found(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = User { + uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"), + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + ..Default::default() + }; + + let request = Request::builder() + .uri(format!("/api/user/{}", user.uuid)) + .body(Body::empty())?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::NOT_FOUND, response.status()); + + Ok(()) + } + + #[sqlx::test] + async fn test_register_created(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = RegisterSchema { + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + password: "solongandthanksforallthefish".to_string(), + }; + + let request = Request::builder() + .uri("/api/register") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user).unwrap()))?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::CREATED, response.status()); + + let body_bytes = response.into_body().collect().await?.to_bytes(); + let User { name, email, .. } = serde_json::from_slice(&body_bytes)?; + + assert_eq!(user.name, name); + assert_eq!(user.email, email); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_register_conflict(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = RegisterSchema { + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + password: "solongandthanksforallthefish".to_string(), + }; + + let request = Request::builder() + .uri("/api/register") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user).unwrap()))?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::CONFLICT, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_login_unauthorized(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = LoginSchema { + email: "adent@earth.sol".to_string(), + password: "hunter2".to_string(), + }; + + let request = Request::builder() + .uri("/api/login") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user).unwrap()))?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_login_ok(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = LoginSchema { + email: "adent@earth.sol".to_string(), + password: "solongandthanksforallthefish".to_string(), }; let response = router .oneshot( Request::builder() - .uri("/api/user/register") + .uri("/api/login") .method("POST") .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Body::from( - serde_json::to_vec(&serde_json::json!(register_user)).unwrap(), - ))?, + .body(Body::from(serde_json::to_vec(&user).unwrap()))?, ) .await .unwrap(); - assert_eq!(StatusCode::CREATED, response.status()); - - let body_bytes = response.into_body().collect().await?.to_bytes(); - let user: User = serde_json::from_slice(&body_bytes)?; - - assert_eq!(register_user.name, user.name); - assert_eq!(register_user.email, user.email); + assert_eq!(StatusCode::OK, response.status()); Ok(()) } |