summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock183
-rw-r--r--Cargo.toml5
-rw-r--r--fixtures/users.sql6
-rw-r--r--migrations/20240321225523_init.down.sql2
-rw-r--r--migrations/20240321225523_init.up.sql6
-rw-r--r--src/error.rs28
-rw-r--r--src/lib.rs2
-rw-r--r--src/main.rs25
-rw-r--r--src/model.rs51
-rw-r--r--src/routes.rs297
-rw-r--r--src/state.rs19
11 files changed, 511 insertions, 113 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 3194485..d6f7678 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -140,16 +140,18 @@ dependencies = [
[[package]]
name = "axum-extra"
-version = "0.9.2"
+version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "895ff42f72016617773af68fb90da2a9677d89c62338ec09162d4909d86fdd8f"
+checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733"
dependencies = [
"axum",
"axum-core",
"axum-macros",
"bytes",
+ "cookie",
"form_urlencoded",
"futures-util",
+ "headers",
"http",
"http-body",
"http-body-util",
@@ -161,6 +163,7 @@ dependencies = [
"tower",
"tower-layer",
"tower-service",
+ "tracing",
]
[[package]]
@@ -236,6 +239,12 @@ dependencies = [
]
[[package]]
+name = "bumpalo"
+version = "3.15.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa"
+
+[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -266,6 +275,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
[[package]]
+name = "cookie"
+version = "0.18.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
+dependencies = [
+ "percent-encoding",
+ "time",
+ "version_check",
+]
+
+[[package]]
name = "cpufeatures"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -531,8 +551,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [
"cfg-if",
+ "js-sys",
"libc",
"wasi",
+ "wasm-bindgen",
]
[[package]]
@@ -580,6 +602,30 @@ dependencies = [
]
[[package]]
+name = "headers"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
+dependencies = [
+ "base64",
+ "bytes",
+ "headers-core",
+ "http",
+ "httpdate",
+ "mime",
+ "sha1",
+]
+
+[[package]]
+name = "headers-core"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
+dependencies = [
+ "http",
+]
+
+[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -745,6 +791,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
+name = "js-sys"
+version = "0.3.69"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
+dependencies = [
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "jsonwebtoken"
+version = "9.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f"
+dependencies = [
+ "base64",
+ "js-sys",
+ "pem",
+ "ring",
+ "serde",
+ "serde_json",
+ "simple_asn1",
+]
+
+[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -882,6 +952,17 @@ dependencies = [
]
[[package]]
+name = "num-bigint"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
name = "num-bigint-dig"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1006,6 +1087,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
+name = "pem"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310"
+dependencies = [
+ "base64",
+ "serde",
+]
+
+[[package]]
name = "pem-rfc7468"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1205,6 +1296,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
[[package]]
+name = "ring"
+version = "0.17.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
+dependencies = [
+ "cc",
+ "cfg-if",
+ "getrandom",
+ "libc",
+ "spin 0.9.8",
+ "untrusted",
+ "windows-sys 0.52.0",
+]
+
+[[package]]
name = "rsa"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1378,6 +1484,18 @@ dependencies = [
]
[[package]]
+name = "simple_asn1"
+version = "0.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085"
+dependencies = [
+ "num-bigint",
+ "num-traits",
+ "thiserror",
+ "time",
+]
+
+[[package]]
name = "slab"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1967,6 +2085,7 @@ dependencies = [
"dotenvy",
"email_address",
"http-body-util",
+ "jsonwebtoken",
"mime",
"pgtemp",
"serde",
@@ -1982,6 +2101,12 @@ dependencies = [
]
[[package]]
+name = "untrusted"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
+
+[[package]]
name = "url"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2038,6 +2163,60 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
[[package]]
+name = "wasm-bindgen"
+version = "0.2.92"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
+dependencies = [
+ "cfg-if",
+ "wasm-bindgen-macro",
+]
+
+[[package]]
+name = "wasm-bindgen-backend"
+version = "0.2.92"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
+dependencies = [
+ "bumpalo",
+ "log",
+ "once_cell",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.52",
+ "wasm-bindgen-shared",
+]
+
+[[package]]
+name = "wasm-bindgen-macro"
+version = "0.2.92"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
+dependencies = [
+ "quote",
+ "wasm-bindgen-macro-support",
+]
+
+[[package]]
+name = "wasm-bindgen-macro-support"
+version = "0.2.92"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.52",
+ "wasm-bindgen-backend",
+ "wasm-bindgen-shared",
+]
+
+[[package]]
+name = "wasm-bindgen-shared"
+version = "0.2.92"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
+
+[[package]]
name = "whoami"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 975ef91..1ae7a29 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,14 +8,15 @@ edition = "2021"
[dependencies]
argon2 = { version = "0.5.3", features = ["std"] }
axum = "0.7.4"
-axum-extra = { version = "0.9.2", features = ["typed-routing"] }
+axum-extra = { version = "0.9.3", features = ["typed-routing", "cookie", "typed-header"] }
dotenvy = "0.15.7"
email_address = "0.2.4"
+jsonwebtoken = "9.3.0"
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio", "uuid", "time"] }
thiserror = "1.0.58"
-time = { version = "0.3.34", features = ["serde"] }
+time = { version = "0.3.34", features = ["serde", "serde-human-readable"] }
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread", "signal"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
diff --git a/fixtures/users.sql b/fixtures/users.sql
index fa47f61..70c8689 100644
--- a/fixtures/users.sql
+++ b/fixtures/users.sql
@@ -1,9 +1,11 @@
INSERT INTO users (
+ uuid,
name,
email,
- password
+ password_hash
) VALUES(
+ '4c14f795-86f0-4361-a02f-0edb966fb145',
'Arthur Dent',
'adent@earth.sol',
- 'solongandthanksforallthefish'
+ '$argon2id$v=19$m=19456,t=2,p=1$31LeWXQsq0wwHT0MgAliVA$V6pQ0nKpgcq+nOWT6p4AuyVM0zy/09Ct9XpSPHq3wSo'
);
diff --git a/migrations/20240321225523_init.down.sql b/migrations/20240321225523_init.down.sql
index 15a9a45..ec52e0b 100644
--- a/migrations/20240321225523_init.down.sql
+++ b/migrations/20240321225523_init.down.sql
@@ -1,3 +1 @@
---- Add down migration script here
-
DROP TABLE IF EXISTS "users";
diff --git a/migrations/20240321225523_init.up.sql b/migrations/20240321225523_init.up.sql
index a744b99..7ae6aab 100644
--- a/migrations/20240321225523_init.up.sql
+++ b/migrations/20240321225523_init.up.sql
@@ -1,10 +1,10 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE users (
- id UUID NOT NULL PRIMARY KEY DEFAULT (uuid_generate_v4()),
+ uuid UUID NOT NULL PRIMARY KEY DEFAULT (uuid_generate_v4()),
name VARCHAR(100) NOT NULL,
email VARCHAR(255) NOT NULL UNIQUE,
- password VARCHAR(100) NOT NULL,
+ password_hash VARCHAR(100) NOT NULL,
created_at TIMESTAMP
WITH
TIME ZONE DEFAULT NOW(),
@@ -12,5 +12,3 @@ CREATE TABLE users (
WITH
TIME ZONE DEFAULT NOW()
);
-
-CREATE INDEX users_email_idx ON users (email);
diff --git a/src/error.rs b/src/error.rs
index 351c01a..2824e49 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -20,6 +20,9 @@ pub enum Error {
#[error("Json error: {0}")]
Json(#[from] serde_json::Error),
+ #[error("JWT error: {0}")]
+ JWT(#[from] jsonwebtoken::errors::Error),
+
#[error("Database error: {0}")]
Sqlx(#[from] sqlx::Error),
@@ -27,7 +30,7 @@ pub enum Error {
Migration(#[from] sqlx::migrate::MigrateError),
#[error("Failed to hash password: {0}")]
- PasswordHash(#[from] argon2::password_hash::Error),
+ PasswordHash(#[source] argon2::password_hash::Error),
#[error("User not found")]
UserNotFound,
@@ -45,12 +48,22 @@ pub enum Error {
Other(String),
}
+impl From<argon2::password_hash::Error> for Error {
+ fn from(value: argon2::password_hash::Error) -> Self {
+ match value {
+ argon2::password_hash::Error::Password => Self::LoginInvalid,
+ _ => Self::PasswordHash(value),
+ }
+ }
+}
+
impl From<&Error> for StatusCode {
fn from(value: &Error) -> Self {
match value {
Error::UserNotFound => StatusCode::NOT_FOUND,
Error::EmailExists => StatusCode::CONFLICT,
Error::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY,
+ Error::LoginInvalid => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
@@ -60,10 +73,13 @@ impl axum::response::IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
// TODO: implement [rfc7807](https://www.rfc-editor.org/rfc/rfc7807.html)
- Json(json!({
- "status": StatusCode::from(&self).to_string(),
- "detail": self.to_string(),
- }))
- .into_response()
+ (
+ StatusCode::from(&self),
+ Json(json!({
+ "status": StatusCode::from(&self).to_string(),
+ "detail": self.to_string(),
+ })),
+ )
+ .into_response()
}
}
diff --git a/src/lib.rs b/src/lib.rs
index 231c1c1..e7502f9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,5 @@
pub use error::{Error, Result};
-pub use routes::router;
+pub use routes::init_router;
pub mod error;
pub mod model;
diff --git a/src/main.rs b/src/main.rs
index 1edf738..a926916 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,6 @@
use std::sync::Arc;
+use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use unnamed_server::{state::AppState, Error};
@@ -15,12 +16,21 @@ async fn main() -> Result<(), Error> {
.with(tracing_subscriber::fmt::layer())
.init();
+ // TODO: Migrate all of these into a struct parsed from env, cli, and file.
let _ = dotenvy::dotenv();
let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string());
+ let jwt_max_age: time::Duration = time::Duration::HOUR;
+ // serde_json::from_str(&std::env::var("JWT_MAX_AGE").unwrap_or_else(|_| "1h".to_string()))?;
+ let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET is not set");
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set");
- let state = Arc::new(AppState::init(&database_url).await?);
- let app = unnamed_server::router(state);
+ let pool = init_db(&database_url).await?;
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret,
+ jwt_max_age,
+ });
+ let app = unnamed_server::init_router(state);
let listener = TcpListener::bind(listen_addr).await?;
@@ -28,3 +38,14 @@ async fn main() -> Result<(), Error> {
axum::serve(listener, app).await.map_err(From::from)
}
+
+async fn init_db(uri: &str) -> Result<Pool<Postgres>, sqlx::Error> {
+ let pool = PgPoolOptions::new()
+ .max_connections(10)
+ .connect(uri)
+ .await?;
+
+ sqlx::migrate!().run(&pool).await?;
+
+ Ok(pool)
+}
diff --git a/src/model.rs b/src/model.rs
index 51ce493..395cdd1 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -3,33 +3,41 @@ use std::str::FromStr;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use time::OffsetDateTime;
+use uuid::Uuid;
use crate::Error;
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)]
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)]
#[serde(rename_all = "camelCase")]
pub struct User {
- pub id: uuid::Uuid,
+ pub uuid: Uuid,
pub name: String,
pub email: String,
#[serde(default, skip_serializing)]
- pub password: String,
+ pub password_hash: String,
pub created_at: Option<OffsetDateTime>,
pub updated_at: Option<OffsetDateTime>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TokenClaims {
- pub sub: String,
- pub iat: usize,
- pub exp: usize,
+ pub sub: Uuid,
+ pub exp: i64,
+}
+
+impl TokenClaims {
+ pub fn new(sub: Uuid, max_age: time::Duration) -> Self {
+ Self {
+ sub,
+ exp: (time::OffsetDateTime::now_utc() + max_age).unix_timestamp(),
+ }
+ }
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RegisterSchema {
pub name: String,
pub email: String,
- #[serde(default, skip_serializing)]
pub password: String,
}
@@ -43,27 +51,14 @@ impl RegisterSchema {
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginSchema {
pub email: String,
- #[serde(default, skip_serializing)]
pub password: String,
}
-macro_rules! impl_from_superset {
- ($from:tt, $to:ty, $($field:tt)*) => {
- impl From<$from> for $to {
- fn from(value: $from) -> Self {
- let $from {
- $($field)*,
- ..
- } = value;
-
- Self {
- $($field)*,
- }
- }
- }
- };
+impl From<RegisterSchema> for LoginSchema {
+ fn from(value: RegisterSchema) -> Self {
+ let RegisterSchema {
+ email, password, ..
+ } = value;
+ Self { email, password }
+ }
}
-
-impl_from_superset!(User, RegisterSchema, name, email, password);
-impl_from_superset!(User, LoginSchema, email, password);
-impl_from_superset!(RegisterSchema, LoginSchema, email, password);
diff --git a/src/routes.rs b/src/routes.rs
index 2692f1a..1ec4e30 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -1,31 +1,36 @@
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
- Argon2, PasswordHasher,
+ Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
};
use axum::{
extract::State,
- http::{StatusCode, Uri},
+ http::{header::SET_COOKIE, StatusCode, Uri},
response::IntoResponse,
Json,
};
-use axum_extra::routing::{RouterExt, TypedPath};
+use axum_extra::{
+ extract::cookie::{Cookie, SameSite},
+ routing::{RouterExt, TypedPath},
+};
+use jsonwebtoken::{EncodingKey, Header};
use serde::Deserialize;
use crate::{
- model::{RegisterSchema, User},
+ model::{LoginSchema, RegisterSchema, TokenClaims, User},
state::AppState,
Error,
};
#[tracing::instrument]
-pub fn router(state: Arc<AppState>) -> axum::Router {
+pub fn init_router(state: Arc<AppState>) -> axum::Router {
axum::Router::new()
// .route("/api/user", get(get_user))
.typed_get(HealthCheck::get)
.typed_get(UserUuid::get)
.typed_post(Register::post)
+ .typed_post(Login::post)
.fallback(fallback)
.with_state(state)
}
@@ -58,7 +63,7 @@ impl UserUuid {
/// Get a user with a specific `uuid`
#[tracing::instrument]
pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse {
- sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid)
+ sqlx::query_as!(User, "SELECT * FROM users WHERE uuid = $1", self.uuid)
.fetch_optional(&state.pool)
.await?
.ok_or_else(|| Error::UserNotFound)
@@ -67,21 +72,25 @@ impl UserUuid {
}
#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/user/register")]
+#[typed_path("/api/register")]
pub struct Register;
impl Register {
- #[tracing::instrument(skip(register_schema))]
+ #[tracing::instrument(skip(password))]
pub async fn post(
self,
State(state): State<Arc<AppState>>,
- Json(register_schema): Json<RegisterSchema>,
+ Json(RegisterSchema {
+ name,
+ email,
+ password,
+ }): Json<RegisterSchema>,
) -> impl IntoResponse {
- register_schema.validate()?;
+ email_address::EmailAddress::from_str(&email)?;
let exists: Option<bool> =
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)")
- .bind(register_schema.email.to_ascii_lowercase())
+ .bind(email.to_ascii_lowercase())
.fetch_one(&state.pool)
.await?;
@@ -90,15 +99,14 @@ impl Register {
}
let salt = SaltString::generate(&mut OsRng);
- let hashed_password =
- Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?;
+ let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?;
let user = sqlx::query_as!(
User,
- "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *",
- register_schema.name,
- register_schema.email.to_ascii_lowercase(),
- hashed_password.to_string()
+ "INSERT INTO users (name,email,password_hash) VALUES ($1, $2, $3) RETURNING *",
+ name,
+ email.to_ascii_lowercase(),
+ password_hash.to_string()
)
.fetch_one(&state.pool)
.await?;
@@ -107,6 +115,56 @@ impl Register {
}
}
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/api/login")]
+pub struct Login;
+
+impl Login {
+ #[tracing::instrument(skip(state, password))]
+ pub async fn post(
+ self,
+ State(state): State<Arc<AppState>>,
+ Json(LoginSchema { email, password }): Json<LoginSchema>,
+ ) -> Result<impl IntoResponse, Error> {
+ let User {
+ uuid,
+ password_hash,
+ ..
+ } = sqlx::query_as!(
+ User,
+ "SELECT * FROM users WHERE email = $1",
+ email.to_ascii_lowercase()
+ )
+ .fetch_optional(&state.pool)
+ .await?
+ .ok_or(Error::LoginInvalid)?;
+
+ Argon2::default()
+ .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?;
+
+ let token = jsonwebtoken::encode(
+ &Header::default(),
+ &TokenClaims::new(uuid, state.jwt_max_age),
+ &EncodingKey::from_secret(state.jwt_secret.as_ref()),
+ )?;
+
+ let cookie = Cookie::build(("token", token.to_owned()))
+ .path("/")
+ .max_age(state.jwt_max_age)
+ .same_site(SameSite::Lax)
+ .http_only(true)
+ .build();
+
+ let mut response = Json(token).into_response();
+
+ response
+ .headers_mut()
+ .insert(SET_COOKIE, cookie.to_string().parse().unwrap());
+
+ Ok(response)
+ }
+}
+
pub async fn fallback(uri: Uri) -> impl IntoResponse {
(StatusCode::NOT_FOUND, format!("Route not found: {uri}"))
}
@@ -123,10 +181,17 @@ mod tests {
use sqlx::PgPool;
use tower::ServiceExt;
+ const JWT_SECRET: &str = "test-jwt-secret-token";
+ const JWT_MAX_AGE: time::Duration = time::Duration::HOUR;
+
#[sqlx::test]
- async fn test_fallback(pool: PgPool) -> Result<(), Error> {
- let state = Arc::new(AppState { pool });
- let router = router(state.clone());
+ async fn test_route_not_found(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
let response = router
.oneshot(
@@ -144,59 +209,183 @@ mod tests {
}
#[sqlx::test(fixtures(path = "../fixtures", scripts("users")))]
- async fn test_user(pool: PgPool) -> Result<(), Error> {
- let state = Arc::new(AppState { pool });
- let router = router(state.clone());
+ async fn test_user_ok(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
- let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1")
- .fetch_one(&state.pool)
- .await?;
+ let user = User {
+ uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"),
+ name: "Arthur Dent".to_string(),
+ email: "adent@earth.sol".to_string(),
+ ..Default::default()
+ };
- let response = router
- .oneshot(
- Request::builder()
- .uri(format!("/api/user/{}", user.id))
- .body(Body::empty())?,
- )
- .await
- .unwrap();
+ let request = Request::builder()
+ .uri(format!("/api/user/{}", user.uuid))
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await.unwrap();
assert_eq!(StatusCode::OK, response.status());
+ let body_bytes = response.into_body().collect().await?.to_bytes();
+ let User {
+ uuid, name, email, ..
+ } = serde_json::from_slice(&body_bytes)?;
+
+ assert_eq!(user.uuid, uuid);
+ assert_eq!(user.name, name);
+ assert_eq!(user.email, email);
+
Ok(())
}
#[sqlx::test]
- async fn test_user_register(pool: PgPool) -> Result<(), Error> {
- let state = Arc::new(AppState { pool });
- let router = router(state.clone());
-
- let register_user = RegisterSchema {
- name: "Ford Prefect".to_string(),
- email: "fprefect@heartofgold.galaxy".to_string(),
- password: "42".to_string(),
+ async fn test_user_not_found(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let user = User {
+ uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"),
+ name: "Arthur Dent".to_string(),
+ email: "adent@earth.sol".to_string(),
+ ..Default::default()
+ };
+
+ let request = Request::builder()
+ .uri(format!("/api/user/{}", user.uuid))
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await.unwrap();
+
+ assert_eq!(StatusCode::NOT_FOUND, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test]
+ async fn test_register_created(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let user = RegisterSchema {
+ name: "Arthur Dent".to_string(),
+ email: "adent@earth.sol".to_string(),
+ password: "solongandthanksforallthefish".to_string(),
+ };
+
+ let request = Request::builder()
+ .uri("/api/register")
+ .method("POST")
+ .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
+ .body(Body::from(serde_json::to_vec(&user).unwrap()))?;
+
+ let response = router.oneshot(request).await.unwrap();
+
+ assert_eq!(StatusCode::CREATED, response.status());
+
+ let body_bytes = response.into_body().collect().await?.to_bytes();
+ let User { name, email, .. } = serde_json::from_slice(&body_bytes)?;
+
+ assert_eq!(user.name, name);
+ assert_eq!(user.email, email);
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))]
+ async fn test_register_conflict(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let user = RegisterSchema {
+ name: "Arthur Dent".to_string(),
+ email: "adent@earth.sol".to_string(),
+ password: "solongandthanksforallthefish".to_string(),
+ };
+
+ let request = Request::builder()
+ .uri("/api/register")
+ .method("POST")
+ .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
+ .body(Body::from(serde_json::to_vec(&user).unwrap()))?;
+
+ let response = router.oneshot(request).await.unwrap();
+
+ assert_eq!(StatusCode::CONFLICT, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))]
+ async fn test_login_unauthorized(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let user = LoginSchema {
+ email: "adent@earth.sol".to_string(),
+ password: "hunter2".to_string(),
+ };
+
+ let request = Request::builder()
+ .uri("/api/login")
+ .method("POST")
+ .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
+ .body(Body::from(serde_json::to_vec(&user).unwrap()))?;
+
+ let response = router.oneshot(request).await.unwrap();
+
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))]
+ async fn test_login_ok(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let user = LoginSchema {
+ email: "adent@earth.sol".to_string(),
+ password: "solongandthanksforallthefish".to_string(),
};
let response = router
.oneshot(
Request::builder()
- .uri("/api/user/register")
+ .uri("/api/login")
.method("POST")
.header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
- .body(Body::from(
- serde_json::to_vec(&serde_json::json!(register_user)).unwrap(),
- ))?,
+ .body(Body::from(serde_json::to_vec(&user).unwrap()))?,
)
.await
.unwrap();
- assert_eq!(StatusCode::CREATED, response.status());
-
- let body_bytes = response.into_body().collect().await?.to_bytes();
- let user: User = serde_json::from_slice(&body_bytes)?;
-
- assert_eq!(register_user.name, user.name);
- assert_eq!(register_user.email, user.email);
+ assert_eq!(StatusCode::OK, response.status());
Ok(())
}
diff --git a/src/state.rs b/src/state.rs
index 614688b..508aaa4 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,19 +1,18 @@
-use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
+use sqlx::{Pool, Postgres};
#[derive(Debug)]
pub struct AppState {
pub pool: Pool<Postgres>,
+ pub jwt_secret: String,
+ pub jwt_max_age: time::Duration,
}
impl AppState {
- pub async fn init(database_uri: &str) -> Result<Self, sqlx::Error> {
- let pool = PgPoolOptions::new()
- .max_connections(10)
- .connect(database_uri)
- .await?;
-
- sqlx::migrate!().run(&pool).await?;
-
- Ok(Self { pool })
+ pub fn new(pool: Pool<Postgres>, jwt_secret: String, jwt_max_age: time::Duration) -> Self {
+ Self {
+ pool,
+ jwt_secret,
+ jwt_max_age,
+ }
}
}