summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-10 20:23:14 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-11 23:51:06 -0500
commit8c56000a3090e0843a1f218a00c3503767658e83 (patch)
treebbcbf4ba4d10468ed8a6e891035ffa4646b77a7c
parenteb8a597d310d8948d0b5a02911dd2002f00cfb39 (diff)
wip: more work on jwt handling
-rw-r--r--Cargo.lock2
-rw-r--r--Cargo.toml3
-rw-r--r--src/config.rs96
-rw-r--r--src/lib.rs24
-rw-r--r--src/main.rs27
-rw-r--r--src/routes.rs22
-rw-r--r--src/routes/jwt.rs339
-rw-r--r--src/routes/login.rs131
-rw-r--r--src/routes/register.rs27
-rw-r--r--src/routes/user.rs68
-rw-r--r--src/state.rs17
11 files changed, 413 insertions, 343 deletions
diff --git a/Cargo.lock b/Cargo.lock
index a154846..7cf36b7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2025,6 +2025,7 @@ dependencies = [
"pin-project-lite",
"tower-layer",
"tower-service",
+ "tracing",
]
[[package]]
@@ -2153,6 +2154,7 @@ dependencies = [
"jsonwebtoken",
"main_error",
"mime",
+ "once_cell",
"pgtemp",
"serde",
"serde_json",
diff --git a/Cargo.toml b/Cargo.toml
index 4031058..5414596 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ dotenvy = "0.15.7"
email_address = "0.2.4"
jsonwebtoken = "9.3.0"
main_error = "0.1.2"
+once_cell = "1.19.0"
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio", "uuid", "time"] }
@@ -20,7 +21,7 @@ thiserror = "1.0.58"
time = { version = "0.3.34", features = ["serde", "serde-human-readable"] }
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread", "signal"] }
toml = "0.8.12"
-tower-http = { version = "0.5.2", features = ["cors"] }
+tower-http = { version = "0.5.2", features = ["cors", "trace"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.8.0", features = ["serde", "v4"] }
diff --git a/src/config.rs b/src/config.rs
index d36b8fd..09ec997 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,23 +1,34 @@
-use std::{net::SocketAddr, sync::Arc};
+use std::net::SocketAddr;
-use axum::Router;
use serde::{Deserialize, Serialize};
-use tokio::net::TcpListener;
-use unnamed_server::{state::AppState, Error};
+use unnamed_server::Error;
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone)]
pub struct Config {
- listen_addr: Option<SocketAddr>,
- jwt_secret: Option<String>,
- database_url: Option<String>,
+ pub listen_addr: SocketAddr,
+ pub jwt_secret: String,
+ pub database_url: String,
}
impl Config {
- pub fn new() -> Self {
- Self::default()
+ pub fn builder() -> ConfigBuilder {
+ ConfigBuilder::default()
}
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ConfigBuilder {
+ listen_addr: Option<String>,
+ jwt_secret: Option<String>,
+ database_url: Option<String>,
+}
+
+impl ConfigBuilder {
+ pub fn file(self) -> Result<Self, Error> {
+ let file = std::env::args()
+ .nth(1)
+ .unwrap_or("/etc/unnamed_server.toml".to_string());
- pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> {
match std::fs::read_to_string(file) {
Ok(s) => Ok(self.merge(toml::from_str(&s)?)),
Err(err) => {
@@ -27,33 +38,14 @@ impl Config {
}
}
- pub fn env(self, prefix: &str) -> Result<Self, Error> {
- Ok(self.merge(Self {
- listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR"))
- .ok()
- .and_then(|v| v.parse().ok()),
- jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(),
- database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(),
- }))
- }
+ pub fn env(self) -> Self {
+ let _ = dotenvy::dotenv();
- pub async fn build(self) -> Result<(TcpListener, Router), Error> {
- macro_rules! try_extract {
- ($($i:ident),+) => {
- $(let Some($i) = self.$i else {
- return Err(Error::Config(format!("Missing value: {}", stringify!($i))))
- };)+
- };
- }
-
- try_extract!(listen_addr, jwt_secret, database_url);
-
- let listener = TcpListener::bind(listen_addr).await?;
- let pool = init_db(&database_url).await?;
- let app_state = Arc::new(AppState { pool, jwt_secret });
- let app = unnamed_server::init_router(app_state);
-
- Ok((listener, app))
+ self.merge(Self {
+ listen_addr: std::env::var("LISTEN_ADDR").ok(),
+ jwt_secret: std::env::var("JWT_SECRET").ok(),
+ database_url: std::env::var("DATABASE_URL").ok(),
+ })
}
/// Merge self with other, overwriting any existing values on self with other's.
@@ -64,25 +56,29 @@ impl Config {
database_url: other.database_url.or(self.database_url),
}
}
+
+ pub fn build(self) -> Result<Config, Error> {
+ Ok(Config {
+ listen_addr: self
+ .listen_addr
+ .and_then(|s| s.parse().ok())
+ .ok_or_else(|| Error::Config("listen_addr".to_string()))?,
+ jwt_secret: self
+ .jwt_secret
+ .ok_or_else(|| Error::Config("jwt_secret".to_string()))?,
+ database_url: self
+ .database_url
+ .ok_or_else(|| Error::Config("database_url".to_string()))?,
+ })
+ }
}
-impl Default for Config {
+impl Default for ConfigBuilder {
fn default() -> Self {
Self {
- listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))),
+ listen_addr: Some("127.0.0.1:30000".to_string()),
jwt_secret: None,
database_url: None,
}
}
}
-
-async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> {
- let pool = sqlx::postgres::PgPoolOptions::new()
- .max_connections(10)
- .connect(uri)
- .await?;
-
- sqlx::migrate!().run(&pool).await?;
-
- Ok(pool)
-}
diff --git a/src/lib.rs b/src/lib.rs
index e7502f9..13b05c0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,3 +5,27 @@ pub mod error;
pub mod model;
pub mod routes;
pub mod state;
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use std::sync::Once;
+
+ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+ pub type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
+
+ pub const JWT_SECRET: &str = "test-jwt-secret-token";
+
+ static INIT: Once = Once::new();
+
+ pub fn setup_test_env() {
+ INIT.call_once(|| {
+ tracing_subscriber::registry()
+ .with(tracing_subscriber::EnvFilter::from_default_env())
+ .with(tracing_subscriber::fmt::layer().with_test_writer())
+ .init();
+
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+ });
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 67aa54a..f8d09f1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,6 @@
-use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+use tokio::net::TcpListener;
+use tracing_subscriber::{layer::SubscriberExt, registry, util::SubscriberInitExt, EnvFilter};
+use unnamed_server::state::AppState;
use crate::config::Config;
@@ -7,27 +9,18 @@ mod config;
#[tokio::main]
#[tracing::instrument]
async fn main() -> Result<(), main_error::MainError> {
- let _ = dotenvy::dotenv();
-
- tracing_subscriber::registry()
- .with(
- tracing_subscriber::EnvFilter::try_from_default_env()
- .unwrap_or_else(|_| "unnamed_server=debug".into()),
- )
+ registry()
+ .with(EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.init();
- let config_file = std::env::args()
- .nth(1)
- .unwrap_or("/etc/unnamed_server.toml".to_string());
+ let config = Config::builder().file()?.env().build()?;
- let (listener, router) = Config::new()
- .file(config_file)?
- .env("UNNAMED_")?
- .build()
- .await?;
+ let listener = TcpListener::bind(config.listen_addr).await?;
+ let app_state = AppState::new(config.database_url).await?;
+ let router = unnamed_server::init_router(app_state);
- tracing::info!("Server listening on http://{}", listener.local_addr()?);
+ tracing::info!("Listening on http://{}", listener.local_addr()?);
axum::serve(listener, router).await.map_err(From::from)
}
diff --git a/src/routes.rs b/src/routes.rs
index ad00b1e..39a0976 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -1,34 +1,27 @@
-use std::sync::Arc;
-
use axum::{
http::{StatusCode, Uri},
- middleware::map_request_with_state,
response::IntoResponse,
};
use axum_extra::routing::RouterExt;
-use tower_http::cors::CorsLayer;
+use tower_http::{cors::CorsLayer, trace::TraceLayer};
use crate::state::AppState;
-use self::jwt::authenticate;
-
mod healthcheck;
mod jwt;
-mod login;
mod register;
mod user;
#[tracing::instrument]
-pub fn init_router(state: Arc<AppState>) -> axum::Router {
+pub fn init_router(state: AppState) -> axum::Router {
axum::Router::new()
.typed_get(user::User::get)
- .typed_get(login::Logout::get)
- .route_layer(map_request_with_state(state.clone(), authenticate))
.typed_get(healthcheck::HealthCheck::get)
.typed_get(user::UserUuid::get)
.typed_post(register::Register::post)
- .typed_post(login::Login::post)
+ .nest("/api/auth", jwt::init_router(state.clone()))
.layer(CorsLayer::permissive())
+ .layer(TraceLayer::new_for_http())
.fallback(fallback)
.with_state(state)
}
@@ -54,10 +47,9 @@ mod tests {
#[sqlx::test]
async fn test_route_not_found(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let request = Request::builder()
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs
index 6a229a3..ccce13e 100644
--- a/src/routes/jwt.rs
+++ b/src/routes/jwt.rs
@@ -1,114 +1,305 @@
-use std::sync::Arc;
-
+use argon2::{Argon2, PasswordHash, PasswordVerifier};
use axum::{
- extract::{Request, State},
- response::IntoResponse,
+ async_trait,
+ extract::{FromRequestParts, State},
+ http::{header::SET_COOKIE, request::Parts, HeaderValue},
+ response::{IntoResponse, IntoResponseParts},
+ RequestPartsExt,
};
use axum_extra::{
extract::{cookie::Cookie, CookieJar},
- headers::{authorization::Bearer, Authorization},
- routing::TypedPath,
+ headers::{
+ authorization::{Basic, Bearer},
+ Authorization,
+ },
+ routing::{RouterExt, TypedPath},
TypedHeader,
};
-use jsonwebtoken::{DecodingKey, Validation};
+use jsonwebtoken::{decode, DecodingKey, EncodingKey};
+use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use uuid::Uuid;
-use crate::{error::AuthError, state::AppState, Error};
+use crate::{error::AuthError, model::UserSchema, state::AppState, Error};
+
+pub fn init_router(state: AppState) -> axum::Router<AppState> {
+ axum::Router::new()
+ .typed_get(Issue::get)
+ .typed_get(Refresh::get)
+ .with_state(state)
+}
+
+static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| {
+ let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
+ JwtEnv::new(secret.as_bytes())
+});
+
+#[derive(Clone)]
+struct JwtEnv {
+ encoding: EncodingKey,
+ decoding: DecodingKey,
+ header: jsonwebtoken::Header,
+ validation: jsonwebtoken::Validation,
+}
+
+impl JwtEnv {
+ fn new(secret: &[u8]) -> Self {
+ Self {
+ encoding: EncodingKey::from_secret(secret),
+ decoding: DecodingKey::from_secret(secret),
+ header: Default::default(),
+ validation: Default::default(),
+ }
+ }
+}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct Claims {
+pub struct Claims<const LIFETIME: i64 = ACCESS> {
pub sub: Uuid,
pub iat: i64,
pub exp: i64,
pub jti: Uuid,
}
-impl Claims {
- const MAX_AGE: i64 = 3600;
+impl<const LIFETIME: i64> Claims<LIFETIME> {
+ pub fn new(uuid: Uuid) -> Self {
+ let now = OffsetDateTime::now_utc().unix_timestamp();
+ Self {
+ sub: uuid,
+ iat: now,
+ exp: now + LIFETIME,
+ jti: uuid::Uuid::new_v4(),
+ }
+ }
- pub fn new(sub: Uuid) -> Self {
- let iat = OffsetDateTime::now_utc().unix_timestamp();
- let exp = iat + Self::MAX_AGE;
- let jti = uuid::Uuid::new_v4();
- Self { sub, iat, exp, jti }
+ pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
+ jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding)
}
+}
- pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> {
- jsonwebtoken::encode(
- &jsonwebtoken::Header::default(),
- self,
- &jsonwebtoken::EncodingKey::from_secret(secret),
- )
+impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
+ type Error = Error;
+
+ fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ Ok(Cookie::build(("token", value.encode()?))
+ .expires(OffsetDateTime::from_unix_timestamp(value.exp)?)
+ .secure(true)
+ .http_only(true)
+ .build())
}
}
-impl From<Uuid> for Claims {
- fn from(value: Uuid) -> Self {
- Self::new(value)
+impl<const L: i64> TryFrom<Claims<L>> for HeaderValue {
+ type Error = Error;
+
+ fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ Cookie::try_from(value)?
+ .encoded()
+ .to_string()
+ .parse()
+ .map_err(Into::into)
}
}
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-struct Session {
- jti: Uuid,
- uuid: Uuid,
+// 1 day in seconds
+const ACCESS: i64 = 86400;
+
+pub type AccessClaims = Claims<ACCESS>;
+
+impl From<RefreshClaims> for AccessClaims {
+ fn from(value: RefreshClaims) -> Self {
+ Claims::new(value.sub)
+ }
+}
+
+impl IntoResponse for AccessClaims {
+ fn into_response(self) -> axum::response::Response {
+ (self, ()).into_response()
+ }
+}
+
+impl IntoResponseParts for AccessClaims {
+ type Error = Error;
+
+ fn into_response_parts(
+ self,
+ mut res: axum::response::ResponseParts,
+ ) -> Result<axum::response::ResponseParts, Self::Error> {
+ res.headers_mut()
+ .append(SET_COOKIE, HeaderValue::try_from(self)?);
+
+ Ok(res)
+ }
+}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for AccessClaims
+where
+ S: Send + Sync,
+{
+ type Rejection = AuthError;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let token = parts
+ .extract::<CookieJar>()
+ .await
+ .map_err(|_| AuthError::JwtNotFound)?
+ .get("token")
+ .ok_or(AuthError::JwtNotFound)?
+ .to_string();
+
+ decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation)
+ .map(|d| d.claims)
+ .map_err(Into::into)
+ }
+}
+
+// 30 days in seconds
+const REFRESH: i64 = 2_592_000;
+
+pub type RefreshClaims = Claims<REFRESH>;
+
+impl RefreshClaims {
+ pub fn refresh(self) -> AccessClaims {
+ self.into()
+ }
+}
+
+//impl IntoResponse for RefreshClaims {
+// fn into_response(self) -> axum::response::Response {
+// (self.refresh(), self).into_response()
+// }
+//}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for RefreshClaims
+where
+ S: Send + Sync,
+{
+ type Rejection = AuthError;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(bearer)) = parts
+ .extract::<TypedHeader<Authorization<Bearer>>>()
+ .await
+ .map_err(|_| AuthError::JwtNotFound)?;
+
+ decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation)
+ .map(|d| d.claims)
+ .map_err(Into::into)
+ }
}
#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/auth/refresh")]
-pub struct Refresh;
+#[typed_path("/issue")]
+pub struct Issue;
-impl Refresh {
- #[tracing::instrument]
- pub async fn post(
+impl Issue {
+ #[tracing::instrument(skip_all)]
+ pub async fn get(
self,
- State(state): State<Arc<AppState>>,
- TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
- cookie_jar: CookieJar,
+ State(state): State<AppState>,
+ TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>,
) -> Result<impl IntoResponse, Error> {
- let Claims { sub, .. } = jsonwebtoken::decode::<Claims>(
- bearer.token(),
- &DecodingKey::from_secret(state.jwt_secret.as_ref()),
- &Validation::default(),
- )?
- .claims;
-
- let claims = Claims::from(sub);
-
- let token = jsonwebtoken::encode(
- &jsonwebtoken::Header::default(),
- &claims,
- &jsonwebtoken::EncodingKey::from_secret(state.jwt_secret.as_ref()),
+ let UserSchema {
+ uuid,
+ password_hash,
+ ..
+ } = sqlx::query_as!(
+ UserSchema,
+ "SELECT * FROM users WHERE email = $1 LIMIT 1",
+ basic.username().to_ascii_lowercase()
+ )
+ .fetch_optional(&state.pool)
+ .await?
+ .ok_or(AuthError::LoginInvalid)?;
+
+ Argon2::default().verify_password(
+ basic.password().as_bytes(),
+ &PasswordHash::new(&password_hash)?,
)?;
- let cookie = Cookie::build(("token", token))
- .expires(OffsetDateTime::from_unix_timestamp(claims.exp)?)
- .secure(true)
- .http_only(true);
+ let claims = Claims::<REFRESH>::new(uuid);
- Ok(cookie_jar.add(cookie))
+ Ok((claims.refresh(), claims.encode()?))
}
}
-pub async fn authenticate(
- State(state): State<Arc<AppState>>,
- cookie_jar: CookieJar,
- mut req: Request,
-) -> Result<Request, AuthError> {
- let token = cookie_jar
- .get("token")
- .ok_or(AuthError::JwtNotFound)?
- .to_string();
-
- let claims = jsonwebtoken::decode::<Claims>(
- &token,
- &DecodingKey::from_secret(state.jwt_secret.as_ref()),
- &Validation::default(),
- )?
- .claims;
-
- req.extensions_mut().insert(claims);
- Ok(req)
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/refresh")]
+pub struct Refresh;
+
+impl Refresh {
+ #[tracing::instrument(skip_all)]
+ pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse {
+ claims.refresh()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use axum::{
+ body::Body,
+ http::{header::AUTHORIZATION, Request, StatusCode},
+ };
+ use axum_extra::headers::authorization::Credentials;
+ use sqlx::PgPool;
+ use tower::ServiceExt;
+
+ use crate::{
+ init_router,
+ tests::{setup_test_env, TestResult},
+ };
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
+ setup_test_env();
+
+ let state = AppState { pool };
+ let router = init_router(state.clone());
+
+ let auth = Authorization::basic("adent@earth.sol", "hunter2");
+ tracing::debug!(?auth, "Auth");
+
+ let request = Request::builder()
+ .uri("/api/auth/issue")
+ .method("GET")
+ .header(AUTHORIZATION, auth.0.encode())
+ .body(Body::empty())?;
+
+ let response = router.oneshot(dbg!(request)).await?;
+
+ tracing::error!(?response);
+
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_login_ok(pool: PgPool) -> TestResult {
+ setup_test_env();
+
+ let state = AppState { pool };
+ let router = init_router(state.clone());
+
+ let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
+
+ let request = Request::builder()
+ .uri("/api/auth/issue")
+ .method("GET")
+ .header(AUTHORIZATION, auth.0.encode())
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
+ tracing::error!(?response);
+
+ assert_eq!(StatusCode::OK, response.status());
+
+ Ok(())
+ }
}
diff --git a/src/routes/login.rs b/src/routes/login.rs
deleted file mode 100644
index 0e1e825..0000000
--- a/src/routes/login.rs
+++ /dev/null
@@ -1,131 +0,0 @@
-use std::sync::Arc;
-
-use argon2::{Argon2, PasswordHash, PasswordVerifier};
-use axum::{extract::State, response::IntoResponse, Json};
-use axum_extra::{headers::Authorization, routing::TypedPath, TypedHeader};
-use serde::Deserialize;
-
-use crate::{
- error::AuthError,
- model::{LoginSchema, UserSchema},
- state::AppState,
- Error,
-};
-
-use super::jwt::Claims;
-
-#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/login")]
-pub struct Login;
-
-impl Login {
- #[tracing::instrument(skip(state, password))]
- pub async fn post(
- self,
- State(state): State<Arc<AppState>>,
- Json(LoginSchema { email, password }): Json<LoginSchema>,
- ) -> Result<impl IntoResponse, Error> {
- let UserSchema {
- uuid,
- password_hash,
- ..
- } = sqlx::query_as!(
- UserSchema,
- "SELECT * FROM users WHERE email = $1 LIMIT 1",
- email.to_ascii_lowercase()
- )
- .fetch_optional(&state.pool)
- .await?
- .ok_or(AuthError::LoginInvalid)?;
-
- Argon2::default()
- .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?;
-
- let token = Claims::from(uuid).encode(state.jwt_secret.as_ref())?;
-
- Authorization::bearer(&token)
- .map(TypedHeader)
- .map_err(Into::into)
- }
-}
-
-#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/logout")]
-pub struct Logout;
-
-impl Logout {
- #[tracing::instrument]
- pub async fn get(self) -> impl IntoResponse {
- todo!("Invalidate jwt somehow...");
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use axum::{
- body::Body,
- http::{header, Request, StatusCode},
- };
- use sqlx::PgPool;
- use tower::ServiceExt;
-
- use crate::init_router;
-
- const JWT_SECRET: &str = "test-jwt-secret-token";
-
- type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
-
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_login_unauthorized(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
-
- let user = LoginSchema {
- email: "adent@earth.sol".to_string(),
- password: "hunter2".to_string(),
- };
-
- let request = Request::builder()
- .uri("/api/login")
- .method("POST")
- .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
- .body(Body::from(serde_json::to_vec(&user)?))?;
-
- let response = router.oneshot(request).await?;
-
- assert_eq!(StatusCode::UNAUTHORIZED, response.status());
-
- Ok(())
- }
-
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_login_ok(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
-
- let user = LoginSchema {
- email: "adent@earth.sol".to_string(),
- password: "solongandthanksforallthefish".to_string(),
- };
-
- let request = Request::builder()
- .uri("/api/login")
- .method("POST")
- .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
- .body(Body::from(serde_json::to_vec(&user)?))?;
-
- let response = router.oneshot(request).await?;
-
- assert_eq!(StatusCode::OK, response.status());
-
- Ok(())
- }
-}
diff --git a/src/routes/register.rs b/src/routes/register.rs
index 286e70f..75819b0 100644
--- a/src/routes/register.rs
+++ b/src/routes/register.rs
@@ -1,4 +1,4 @@
-use std::{str::FromStr, sync::Arc};
+use std::str::FromStr;
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
@@ -25,7 +25,7 @@ impl Register {
#[tracing::instrument(skip(password))]
pub async fn post(
self,
- State(state): State<Arc<AppState>>,
+ State(state): State<AppState>,
Json(RegisterSchema {
name,
email,
@@ -73,18 +73,16 @@ mod tests {
use sqlx::PgPool;
use tower::ServiceExt;
- use crate::init_router;
-
- const JWT_SECRET: &str = "test-jwt-secret-token";
-
- type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
+ use crate::{
+ init_router,
+ tests::{setup_test_env, TestResult},
+ };
#[sqlx::test]
async fn test_register_created(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ setup_test_env();
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = RegisterSchema {
@@ -114,10 +112,9 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_register_conflict(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ setup_test_env();
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = RegisterSchema {
diff --git a/src/routes/user.rs b/src/routes/user.rs
index 73eef04..31cd5cb 100644
--- a/src/routes/user.rs
+++ b/src/routes/user.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use axum::{extract::State, response::IntoResponse, Extension, Json};
use axum_extra::routing::TypedPath;
use serde::Deserialize;
@@ -17,7 +15,7 @@ pub struct UserUuid {
impl UserUuid {
/// Get a user with a specific `uuid`
#[tracing::instrument]
- pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse {
+ pub async fn get(self, State(state): State<AppState>) -> impl IntoResponse {
sqlx::query_as!(
UserSchema,
"SELECT * FROM users WHERE uuid = $1 LIMIT 1",
@@ -38,7 +36,7 @@ impl User {
#[tracing::instrument]
pub async fn get(
self,
- State(state): State<Arc<AppState>>,
+ State(state): State<AppState>,
Extension(Claims { sub, .. }): Extension<Claims>,
) -> Result<impl IntoResponse, Error> {
sqlx::query_as!(
@@ -59,14 +57,14 @@ mod tests {
use axum::{
body::Body,
- http::{header::AUTHORIZATION, Request, StatusCode},
+ http::{header::COOKIE, HeaderValue, Request, StatusCode},
};
use http_body_util::BodyExt;
use sqlx::PgPool;
use tower::ServiceExt;
- use crate::{init_router, model::UserSchema};
+ use crate::{init_router, model::UserSchema, routes::jwt::AccessClaims};
const JWT_SECRET: &str = "test-jwt-secret-token";
const UUID: uuid::Uuid = uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145");
@@ -75,10 +73,9 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_user_uuid_ok(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = UserSchema {
@@ -110,10 +107,9 @@ mod tests {
#[sqlx::test]
async fn test_user_uuid_not_found(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = UserSchema {
@@ -136,17 +132,17 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_user_ok(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
- let token = Claims::from(UUID).encode(JWT_SECRET.as_ref())?;
+ let state = AppState { pool };
+ let router = init_router(state.clone());
let request = Request::builder()
.uri("/api/user")
- .header(AUTHORIZATION, format!("Bearer {token}"))
+ .header(
+ COOKIE,
+ HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?,
+ )
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -167,17 +163,17 @@ mod tests {
#[sqlx::test]
async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
- let token = Claims::from(UUID).encode("BAD_SECRET".as_ref())?;
+ let state = AppState { pool };
+ let router = init_router(state.clone());
let request = Request::builder()
.uri("/api/user")
- .header(AUTHORIZATION, format!("Bearer {token}"))
+ .header(
+ COOKIE,
+ HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?,
+ )
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -189,15 +185,14 @@ mod tests {
#[sqlx::test]
async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let request = Request::builder()
.uri("/api/user")
- .header(AUTHORIZATION, "Bearer invalidtoken")
+ .header(COOKIE, "token=sadfasdfsdfs")
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -209,10 +204,9 @@ mod tests {
#[sqlx::test]
async fn test_user_unauthorized_missing_token(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let request = Request::builder().uri("/api/user").body(Body::empty())?;
diff --git a/src/state.rs b/src/state.rs
index 22234f3..2646489 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -12,12 +12,23 @@ use crate::Error;
#[derive(Debug, Clone)]
pub struct AppState {
pub pool: Pool<Postgres>,
- pub jwt_secret: String,
}
impl AppState {
- pub fn new(pool: Pool<Postgres>, jwt_secret: String) -> Self {
- Self { pool, jwt_secret }
+ #[tracing::instrument]
+ pub async fn new(uri: String) -> Result<Self, Error> {
+ tracing::debug!("Attempting to connect to database...");
+
+ let pool = sqlx::postgres::PgPoolOptions::new()
+ .max_connections(10)
+ .connect(&uri)
+ .await?;
+
+ tracing::info!("Connected to database");
+
+ sqlx::migrate!().run(&pool).await?;
+
+ Ok(Self { pool })
}
}