diff options
-rw-r--r-- | Cargo.lock | 183 | ||||
-rw-r--r-- | Cargo.toml | 5 | ||||
-rw-r--r-- | fixtures/users.sql | 6 | ||||
-rw-r--r-- | migrations/20240321225523_init.down.sql | 2 | ||||
-rw-r--r-- | migrations/20240321225523_init.up.sql | 6 | ||||
-rw-r--r-- | src/error.rs | 28 | ||||
-rw-r--r-- | src/lib.rs | 2 | ||||
-rw-r--r-- | src/main.rs | 25 | ||||
-rw-r--r-- | src/model.rs | 51 | ||||
-rw-r--r-- | src/routes.rs | 297 | ||||
-rw-r--r-- | src/state.rs | 19 |
11 files changed, 511 insertions, 113 deletions
@@ -140,16 +140,18 @@ dependencies = [ [[package]] name = "axum-extra" -version = "0.9.2" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "895ff42f72016617773af68fb90da2a9677d89c62338ec09162d4909d86fdd8f" +checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733" dependencies = [ "axum", "axum-core", "axum-macros", "bytes", + "cookie", "form_urlencoded", "futures-util", + "headers", "http", "http-body", "http-body-util", @@ -161,6 +163,7 @@ dependencies = [ "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -236,6 +239,12 @@ dependencies = [ ] [[package]] +name = "bumpalo" +version = "3.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" + +[[package]] name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -266,6 +275,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" [[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] name = "cpufeatures" version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -531,8 +551,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -580,6 +602,30 @@ dependencies = [ ] [[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + +[[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -745,6 +791,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "jsonwebtoken" +version = "9.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" +dependencies = [ + "base64", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + +[[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -882,6 +952,17 @@ dependencies = [ ] [[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] name = "num-bigint-dig" version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1006,6 +1087,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] +name = "pem" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" +dependencies = [ + "base64", + "serde", +] + +[[package]] name = "pem-rfc7468" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1205,6 +1296,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] name = "rsa" version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1378,6 +1484,18 @@ dependencies = [ ] [[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + +[[package]] name = "slab" version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1967,6 +2085,7 @@ dependencies = [ "dotenvy", "email_address", "http-body-util", + "jsonwebtoken", "mime", "pgtemp", "serde", @@ -1982,6 +2101,12 @@ dependencies = [ ] [[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] name = "url" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2038,6 +2163,60 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.52", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] name = "whoami" version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -8,14 +8,15 @@ edition = "2021" [dependencies] argon2 = { version = "0.5.3", features = ["std"] } axum = "0.7.4" -axum-extra = { version = "0.9.2", features = ["typed-routing"] } +axum-extra = { version = "0.9.3", features = ["typed-routing", "cookie", "typed-header"] } dotenvy = "0.15.7" email_address = "0.2.4" +jsonwebtoken = "9.3.0" serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio", "uuid", "time"] } thiserror = "1.0.58" -time = { version = "0.3.34", features = ["serde"] } +time = { version = "0.3.34", features = ["serde", "serde-human-readable"] } tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread", "signal"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/fixtures/users.sql b/fixtures/users.sql index fa47f61..70c8689 100644 --- a/fixtures/users.sql +++ b/fixtures/users.sql @@ -1,9 +1,11 @@ INSERT INTO users ( + uuid, name, email, - password + password_hash ) VALUES( + '4c14f795-86f0-4361-a02f-0edb966fb145', 'Arthur Dent', 'adent@earth.sol', - 'solongandthanksforallthefish' + '$argon2id$v=19$m=19456,t=2,p=1$31LeWXQsq0wwHT0MgAliVA$V6pQ0nKpgcq+nOWT6p4AuyVM0zy/09Ct9XpSPHq3wSo' ); diff --git a/migrations/20240321225523_init.down.sql b/migrations/20240321225523_init.down.sql index 15a9a45..ec52e0b 100644 --- a/migrations/20240321225523_init.down.sql +++ b/migrations/20240321225523_init.down.sql @@ -1,3 +1 @@ ---- Add down migration script here - DROP TABLE IF EXISTS "users"; diff --git a/migrations/20240321225523_init.up.sql b/migrations/20240321225523_init.up.sql index a744b99..7ae6aab 100644 --- a/migrations/20240321225523_init.up.sql +++ b/migrations/20240321225523_init.up.sql @@ -1,10 +1,10 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE TABLE users ( - id UUID NOT NULL PRIMARY KEY DEFAULT (uuid_generate_v4()), + uuid UUID NOT NULL PRIMARY KEY DEFAULT (uuid_generate_v4()), name VARCHAR(100) NOT NULL, email VARCHAR(255) NOT NULL UNIQUE, - password VARCHAR(100) NOT NULL, + password_hash VARCHAR(100) NOT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), @@ -12,5 +12,3 @@ CREATE TABLE users ( WITH TIME ZONE DEFAULT NOW() ); - -CREATE INDEX users_email_idx ON users (email); diff --git a/src/error.rs b/src/error.rs index 351c01a..2824e49 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,6 +20,9 @@ pub enum Error { #[error("Json error: {0}")] Json(#[from] serde_json::Error), + #[error("JWT error: {0}")] + JWT(#[from] jsonwebtoken::errors::Error), + #[error("Database error: {0}")] Sqlx(#[from] sqlx::Error), @@ -27,7 +30,7 @@ pub enum Error { Migration(#[from] sqlx::migrate::MigrateError), #[error("Failed to hash password: {0}")] - PasswordHash(#[from] argon2::password_hash::Error), + PasswordHash(#[source] argon2::password_hash::Error), #[error("User not found")] UserNotFound, @@ -45,12 +48,22 @@ pub enum Error { Other(String), } +impl From<argon2::password_hash::Error> for Error { + fn from(value: argon2::password_hash::Error) -> Self { + match value { + argon2::password_hash::Error::Password => Self::LoginInvalid, + _ => Self::PasswordHash(value), + } + } +} + impl From<&Error> for StatusCode { fn from(value: &Error) -> Self { match value { Error::UserNotFound => StatusCode::NOT_FOUND, Error::EmailExists => StatusCode::CONFLICT, Error::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY, + Error::LoginInvalid => StatusCode::UNAUTHORIZED, _ => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -60,10 +73,13 @@ impl axum::response::IntoResponse for Error { fn into_response(self) -> axum::response::Response { // TODO: implement [rfc7807](https://www.rfc-editor.org/rfc/rfc7807.html) - Json(json!({ - "status": StatusCode::from(&self).to_string(), - "detail": self.to_string(), - })) - .into_response() + ( + StatusCode::from(&self), + Json(json!({ + "status": StatusCode::from(&self).to_string(), + "detail": self.to_string(), + })), + ) + .into_response() } } @@ -1,5 +1,5 @@ pub use error::{Error, Result}; -pub use routes::router; +pub use routes::init_router; pub mod error; pub mod model; diff --git a/src/main.rs b/src/main.rs index 1edf738..a926916 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; use tokio::net::TcpListener; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use unnamed_server::{state::AppState, Error}; @@ -15,12 +16,21 @@ async fn main() -> Result<(), Error> { .with(tracing_subscriber::fmt::layer()) .init(); + // TODO: Migrate all of these into a struct parsed from env, cli, and file. let _ = dotenvy::dotenv(); let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string()); + let jwt_max_age: time::Duration = time::Duration::HOUR; + // serde_json::from_str(&std::env::var("JWT_MAX_AGE").unwrap_or_else(|_| "1h".to_string()))?; + let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET is not set"); let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set"); - let state = Arc::new(AppState::init(&database_url).await?); - let app = unnamed_server::router(state); + let pool = init_db(&database_url).await?; + let state = Arc::new(AppState { + pool, + jwt_secret, + jwt_max_age, + }); + let app = unnamed_server::init_router(state); let listener = TcpListener::bind(listen_addr).await?; @@ -28,3 +38,14 @@ async fn main() -> Result<(), Error> { axum::serve(listener, app).await.map_err(From::from) } + +async fn init_db(uri: &str) -> Result<Pool<Postgres>, sqlx::Error> { + let pool = PgPoolOptions::new() + .max_connections(10) + .connect(uri) + .await?; + + sqlx::migrate!().run(&pool).await?; + + Ok(pool) +} diff --git a/src/model.rs b/src/model.rs index 51ce493..395cdd1 100644 --- a/src/model.rs +++ b/src/model.rs @@ -3,33 +3,41 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; use sqlx::FromRow; use time::OffsetDateTime; +use uuid::Uuid; use crate::Error; -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)] +#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)] #[serde(rename_all = "camelCase")] pub struct User { - pub id: uuid::Uuid, + pub uuid: Uuid, pub name: String, pub email: String, #[serde(default, skip_serializing)] - pub password: String, + pub password_hash: String, pub created_at: Option<OffsetDateTime>, pub updated_at: Option<OffsetDateTime>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct TokenClaims { - pub sub: String, - pub iat: usize, - pub exp: usize, + pub sub: Uuid, + pub exp: i64, +} + +impl TokenClaims { + pub fn new(sub: Uuid, max_age: time::Duration) -> Self { + Self { + sub, + exp: (time::OffsetDateTime::now_utc() + max_age).unix_timestamp(), + } + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RegisterSchema { pub name: String, pub email: String, - #[serde(default, skip_serializing)] pub password: String, } @@ -43,27 +51,14 @@ impl RegisterSchema { #[derive(Debug, Serialize, Deserialize)] pub struct LoginSchema { pub email: String, - #[serde(default, skip_serializing)] pub password: String, } -macro_rules! impl_from_superset { - ($from:tt, $to:ty, $($field:tt)*) => { - impl From<$from> for $to { - fn from(value: $from) -> Self { - let $from { - $($field)*, - .. - } = value; - - Self { - $($field)*, - } - } - } - }; +impl From<RegisterSchema> for LoginSchema { + fn from(value: RegisterSchema) -> Self { + let RegisterSchema { + email, password, .. + } = value; + Self { email, password } + } } - -impl_from_superset!(User, RegisterSchema, name, email, password); -impl_from_superset!(User, LoginSchema, email, password); -impl_from_superset!(RegisterSchema, LoginSchema, email, password); diff --git a/src/routes.rs b/src/routes.rs index 2692f1a..1ec4e30 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,31 +1,36 @@ -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use argon2::{ password_hash::{rand_core::OsRng, SaltString}, - Argon2, PasswordHasher, + Argon2, PasswordHash, PasswordHasher, PasswordVerifier, }; use axum::{ extract::State, - http::{StatusCode, Uri}, + http::{header::SET_COOKIE, StatusCode, Uri}, response::IntoResponse, Json, }; -use axum_extra::routing::{RouterExt, TypedPath}; +use axum_extra::{ + extract::cookie::{Cookie, SameSite}, + routing::{RouterExt, TypedPath}, +}; +use jsonwebtoken::{EncodingKey, Header}; use serde::Deserialize; use crate::{ - model::{RegisterSchema, User}, + model::{LoginSchema, RegisterSchema, TokenClaims, User}, state::AppState, Error, }; #[tracing::instrument] -pub fn router(state: Arc<AppState>) -> axum::Router { +pub fn init_router(state: Arc<AppState>) -> axum::Router { axum::Router::new() // .route("/api/user", get(get_user)) .typed_get(HealthCheck::get) .typed_get(UserUuid::get) .typed_post(Register::post) + .typed_post(Login::post) .fallback(fallback) .with_state(state) } @@ -58,7 +63,7 @@ impl UserUuid { /// Get a user with a specific `uuid` #[tracing::instrument] pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse { - sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid) + sqlx::query_as!(User, "SELECT * FROM users WHERE uuid = $1", self.uuid) .fetch_optional(&state.pool) .await? .ok_or_else(|| Error::UserNotFound) @@ -67,21 +72,25 @@ impl UserUuid { } #[derive(Debug, Deserialize, TypedPath)] -#[typed_path("/api/user/register")] +#[typed_path("/api/register")] pub struct Register; impl Register { - #[tracing::instrument(skip(register_schema))] + #[tracing::instrument(skip(password))] pub async fn post( self, State(state): State<Arc<AppState>>, - Json(register_schema): Json<RegisterSchema>, + Json(RegisterSchema { + name, + email, + password, + }): Json<RegisterSchema>, ) -> impl IntoResponse { - register_schema.validate()?; + email_address::EmailAddress::from_str(&email)?; let exists: Option<bool> = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)") - .bind(register_schema.email.to_ascii_lowercase()) + .bind(email.to_ascii_lowercase()) .fetch_one(&state.pool) .await?; @@ -90,15 +99,14 @@ impl Register { } let salt = SaltString::generate(&mut OsRng); - let hashed_password = - Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?; + let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?; let user = sqlx::query_as!( User, - "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *", - register_schema.name, - register_schema.email.to_ascii_lowercase(), - hashed_password.to_string() + "INSERT INTO users (name,email,password_hash) VALUES ($1, $2, $3) RETURNING *", + name, + email.to_ascii_lowercase(), + password_hash.to_string() ) .fetch_one(&state.pool) .await?; @@ -107,6 +115,56 @@ impl Register { } } +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/api/login")] +pub struct Login; + +impl Login { + #[tracing::instrument(skip(state, password))] + pub async fn post( + self, + State(state): State<Arc<AppState>>, + Json(LoginSchema { email, password }): Json<LoginSchema>, + ) -> Result<impl IntoResponse, Error> { + let User { + uuid, + password_hash, + .. + } = sqlx::query_as!( + User, + "SELECT * FROM users WHERE email = $1", + email.to_ascii_lowercase() + ) + .fetch_optional(&state.pool) + .await? + .ok_or(Error::LoginInvalid)?; + + Argon2::default() + .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?; + + let token = jsonwebtoken::encode( + &Header::default(), + &TokenClaims::new(uuid, state.jwt_max_age), + &EncodingKey::from_secret(state.jwt_secret.as_ref()), + )?; + + let cookie = Cookie::build(("token", token.to_owned())) + .path("/") + .max_age(state.jwt_max_age) + .same_site(SameSite::Lax) + .http_only(true) + .build(); + + let mut response = Json(token).into_response(); + + response + .headers_mut() + .insert(SET_COOKIE, cookie.to_string().parse().unwrap()); + + Ok(response) + } +} + pub async fn fallback(uri: Uri) -> impl IntoResponse { (StatusCode::NOT_FOUND, format!("Route not found: {uri}")) } @@ -123,10 +181,17 @@ mod tests { use sqlx::PgPool; use tower::ServiceExt; + const JWT_SECRET: &str = "test-jwt-secret-token"; + const JWT_MAX_AGE: time::Duration = time::Duration::HOUR; + #[sqlx::test] - async fn test_fallback(pool: PgPool) -> Result<(), Error> { - let state = Arc::new(AppState { pool }); - let router = router(state.clone()); + async fn test_route_not_found(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); let response = router .oneshot( @@ -144,59 +209,183 @@ mod tests { } #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] - async fn test_user(pool: PgPool) -> Result<(), Error> { - let state = Arc::new(AppState { pool }); - let router = router(state.clone()); + async fn test_user_ok(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); - let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1") - .fetch_one(&state.pool) - .await?; + let user = User { + uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"), + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + ..Default::default() + }; - let response = router - .oneshot( - Request::builder() - .uri(format!("/api/user/{}", user.id)) - .body(Body::empty())?, - ) - .await - .unwrap(); + let request = Request::builder() + .uri(format!("/api/user/{}", user.uuid)) + .body(Body::empty())?; + + let response = router.oneshot(request).await.unwrap(); assert_eq!(StatusCode::OK, response.status()); + let body_bytes = response.into_body().collect().await?.to_bytes(); + let User { + uuid, name, email, .. + } = serde_json::from_slice(&body_bytes)?; + + assert_eq!(user.uuid, uuid); + assert_eq!(user.name, name); + assert_eq!(user.email, email); + Ok(()) } #[sqlx::test] - async fn test_user_register(pool: PgPool) -> Result<(), Error> { - let state = Arc::new(AppState { pool }); - let router = router(state.clone()); - - let register_user = RegisterSchema { - name: "Ford Prefect".to_string(), - email: "fprefect@heartofgold.galaxy".to_string(), - password: "42".to_string(), + async fn test_user_not_found(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = User { + uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"), + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + ..Default::default() + }; + + let request = Request::builder() + .uri(format!("/api/user/{}", user.uuid)) + .body(Body::empty())?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::NOT_FOUND, response.status()); + + Ok(()) + } + + #[sqlx::test] + async fn test_register_created(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = RegisterSchema { + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + password: "solongandthanksforallthefish".to_string(), + }; + + let request = Request::builder() + .uri("/api/register") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user).unwrap()))?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::CREATED, response.status()); + + let body_bytes = response.into_body().collect().await?.to_bytes(); + let User { name, email, .. } = serde_json::from_slice(&body_bytes)?; + + assert_eq!(user.name, name); + assert_eq!(user.email, email); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_register_conflict(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = RegisterSchema { + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + password: "solongandthanksforallthefish".to_string(), + }; + + let request = Request::builder() + .uri("/api/register") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user).unwrap()))?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::CONFLICT, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_login_unauthorized(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = LoginSchema { + email: "adent@earth.sol".to_string(), + password: "hunter2".to_string(), + }; + + let request = Request::builder() + .uri("/api/login") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_vec(&user).unwrap()))?; + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_login_ok(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let user = LoginSchema { + email: "adent@earth.sol".to_string(), + password: "solongandthanksforallthefish".to_string(), }; let response = router .oneshot( Request::builder() - .uri("/api/user/register") + .uri("/api/login") .method("POST") .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Body::from( - serde_json::to_vec(&serde_json::json!(register_user)).unwrap(), - ))?, + .body(Body::from(serde_json::to_vec(&user).unwrap()))?, ) .await .unwrap(); - assert_eq!(StatusCode::CREATED, response.status()); - - let body_bytes = response.into_body().collect().await?.to_bytes(); - let user: User = serde_json::from_slice(&body_bytes)?; - - assert_eq!(register_user.name, user.name); - assert_eq!(register_user.email, user.email); + assert_eq!(StatusCode::OK, response.status()); Ok(()) } diff --git a/src/state.rs b/src/state.rs index 614688b..508aaa4 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,19 +1,18 @@ -use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; +use sqlx::{Pool, Postgres}; #[derive(Debug)] pub struct AppState { pub pool: Pool<Postgres>, + pub jwt_secret: String, + pub jwt_max_age: time::Duration, } impl AppState { - pub async fn init(database_uri: &str) -> Result<Self, sqlx::Error> { - let pool = PgPoolOptions::new() - .max_connections(10) - .connect(database_uri) - .await?; - - sqlx::migrate!().run(&pool).await?; - - Ok(Self { pool }) + pub fn new(pool: Pool<Postgres>, jwt_secret: String, jwt_max_age: time::Duration) -> Self { + Self { + pool, + jwt_secret, + jwt_max_age, + } } } |