diff options
author | Toby Vincent <tobyv@tobyvin.dev> | 2024-04-10 20:23:14 -0500 |
---|---|---|
committer | Toby Vincent <tobyv@tobyvin.dev> | 2024-04-11 23:51:06 -0500 |
commit | 8c56000a3090e0843a1f218a00c3503767658e83 (patch) | |
tree | bbcbf4ba4d10468ed8a6e891035ffa4646b77a7c /src/routes | |
parent | eb8a597d310d8948d0b5a02911dd2002f00cfb39 (diff) |
wip: more work on jwt handling
Diffstat (limited to 'src/routes')
-rw-r--r-- | src/routes/jwt.rs | 339 | ||||
-rw-r--r-- | src/routes/login.rs | 131 | ||||
-rw-r--r-- | src/routes/register.rs | 27 | ||||
-rw-r--r-- | src/routes/user.rs | 68 |
4 files changed, 308 insertions, 257 deletions
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs index 6a229a3..ccce13e 100644 --- a/src/routes/jwt.rs +++ b/src/routes/jwt.rs @@ -1,114 +1,305 @@ -use std::sync::Arc; - +use argon2::{Argon2, PasswordHash, PasswordVerifier}; use axum::{ - extract::{Request, State}, - response::IntoResponse, + async_trait, + extract::{FromRequestParts, State}, + http::{header::SET_COOKIE, request::Parts, HeaderValue}, + response::{IntoResponse, IntoResponseParts}, + RequestPartsExt, }; use axum_extra::{ extract::{cookie::Cookie, CookieJar}, - headers::{authorization::Bearer, Authorization}, - routing::TypedPath, + headers::{ + authorization::{Basic, Bearer}, + Authorization, + }, + routing::{RouterExt, TypedPath}, TypedHeader, }; -use jsonwebtoken::{DecodingKey, Validation}; +use jsonwebtoken::{decode, DecodingKey, EncodingKey}; +use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; -use crate::{error::AuthError, state::AppState, Error}; +use crate::{error::AuthError, model::UserSchema, state::AppState, Error}; + +pub fn init_router(state: AppState) -> axum::Router<AppState> { + axum::Router::new() + .typed_get(Issue::get) + .typed_get(Refresh::get) + .with_state(state) +} + +static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| { + let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); + JwtEnv::new(secret.as_bytes()) +}); + +#[derive(Clone)] +struct JwtEnv { + encoding: EncodingKey, + decoding: DecodingKey, + header: jsonwebtoken::Header, + validation: jsonwebtoken::Validation, +} + +impl JwtEnv { + fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), + header: Default::default(), + validation: Default::default(), + } + } +} #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct Claims { +pub struct Claims<const LIFETIME: i64 = ACCESS> { pub sub: Uuid, pub iat: i64, pub exp: i64, pub jti: Uuid, } -impl Claims { - const MAX_AGE: i64 = 3600; +impl<const LIFETIME: i64> Claims<LIFETIME> { + pub fn new(uuid: Uuid) -> Self { + let now = OffsetDateTime::now_utc().unix_timestamp(); + Self { + sub: uuid, + iat: now, + exp: now + LIFETIME, + jti: uuid::Uuid::new_v4(), + } + } - pub fn new(sub: Uuid) -> Self { - let iat = OffsetDateTime::now_utc().unix_timestamp(); - let exp = iat + Self::MAX_AGE; - let jti = uuid::Uuid::new_v4(); - Self { sub, iat, exp, jti } + pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> { + jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding) } +} - pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> { - jsonwebtoken::encode( - &jsonwebtoken::Header::default(), - self, - &jsonwebtoken::EncodingKey::from_secret(secret), - ) +impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> { + type Error = Error; + + fn try_from(value: Claims<L>) -> Result<Self, Self::Error> { + Ok(Cookie::build(("token", value.encode()?)) + .expires(OffsetDateTime::from_unix_timestamp(value.exp)?) + .secure(true) + .http_only(true) + .build()) } } -impl From<Uuid> for Claims { - fn from(value: Uuid) -> Self { - Self::new(value) +impl<const L: i64> TryFrom<Claims<L>> for HeaderValue { + type Error = Error; + + fn try_from(value: Claims<L>) -> Result<Self, Self::Error> { + Cookie::try_from(value)? + .encoded() + .to_string() + .parse() + .map_err(Into::into) } } -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -struct Session { - jti: Uuid, - uuid: Uuid, +// 1 day in seconds +const ACCESS: i64 = 86400; + +pub type AccessClaims = Claims<ACCESS>; + +impl From<RefreshClaims> for AccessClaims { + fn from(value: RefreshClaims) -> Self { + Claims::new(value.sub) + } +} + +impl IntoResponse for AccessClaims { + fn into_response(self) -> axum::response::Response { + (self, ()).into_response() + } +} + +impl IntoResponseParts for AccessClaims { + type Error = Error; + + fn into_response_parts( + self, + mut res: axum::response::ResponseParts, + ) -> Result<axum::response::ResponseParts, Self::Error> { + res.headers_mut() + .append(SET_COOKIE, HeaderValue::try_from(self)?); + + Ok(res) + } +} + +#[async_trait] +impl<S> FromRequestParts<S> for AccessClaims +where + S: Send + Sync, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + let token = parts + .extract::<CookieJar>() + .await + .map_err(|_| AuthError::JwtNotFound)? + .get("token") + .ok_or(AuthError::JwtNotFound)? + .to_string(); + + decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation) + .map(|d| d.claims) + .map_err(Into::into) + } +} + +// 30 days in seconds +const REFRESH: i64 = 2_592_000; + +pub type RefreshClaims = Claims<REFRESH>; + +impl RefreshClaims { + pub fn refresh(self) -> AccessClaims { + self.into() + } +} + +//impl IntoResponse for RefreshClaims { +// fn into_response(self) -> axum::response::Response { +// (self.refresh(), self).into_response() +// } +//} + +#[async_trait] +impl<S> FromRequestParts<S> for RefreshClaims +where + S: Send + Sync, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + let TypedHeader(Authorization(bearer)) = parts + .extract::<TypedHeader<Authorization<Bearer>>>() + .await + .map_err(|_| AuthError::JwtNotFound)?; + + decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation) + .map(|d| d.claims) + .map_err(Into::into) + } } #[derive(Debug, Deserialize, TypedPath)] -#[typed_path("/api/auth/refresh")] -pub struct Refresh; +#[typed_path("/issue")] +pub struct Issue; -impl Refresh { - #[tracing::instrument] - pub async fn post( +impl Issue { + #[tracing::instrument(skip_all)] + pub async fn get( self, - State(state): State<Arc<AppState>>, - TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, - cookie_jar: CookieJar, + State(state): State<AppState>, + TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>, ) -> Result<impl IntoResponse, Error> { - let Claims { sub, .. } = jsonwebtoken::decode::<Claims>( - bearer.token(), - &DecodingKey::from_secret(state.jwt_secret.as_ref()), - &Validation::default(), - )? - .claims; - - let claims = Claims::from(sub); - - let token = jsonwebtoken::encode( - &jsonwebtoken::Header::default(), - &claims, - &jsonwebtoken::EncodingKey::from_secret(state.jwt_secret.as_ref()), + let UserSchema { + uuid, + password_hash, + .. + } = sqlx::query_as!( + UserSchema, + "SELECT * FROM users WHERE email = $1 LIMIT 1", + basic.username().to_ascii_lowercase() + ) + .fetch_optional(&state.pool) + .await? + .ok_or(AuthError::LoginInvalid)?; + + Argon2::default().verify_password( + basic.password().as_bytes(), + &PasswordHash::new(&password_hash)?, )?; - let cookie = Cookie::build(("token", token)) - .expires(OffsetDateTime::from_unix_timestamp(claims.exp)?) - .secure(true) - .http_only(true); + let claims = Claims::<REFRESH>::new(uuid); - Ok(cookie_jar.add(cookie)) + Ok((claims.refresh(), claims.encode()?)) } } -pub async fn authenticate( - State(state): State<Arc<AppState>>, - cookie_jar: CookieJar, - mut req: Request, -) -> Result<Request, AuthError> { - let token = cookie_jar - .get("token") - .ok_or(AuthError::JwtNotFound)? - .to_string(); - - let claims = jsonwebtoken::decode::<Claims>( - &token, - &DecodingKey::from_secret(state.jwt_secret.as_ref()), - &Validation::default(), - )? - .claims; - - req.extensions_mut().insert(claims); - Ok(req) +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/refresh")] +pub struct Refresh; + +impl Refresh { + #[tracing::instrument(skip_all)] + pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse { + claims.refresh() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use axum::{ + body::Body, + http::{header::AUTHORIZATION, Request, StatusCode}, + }; + use axum_extra::headers::authorization::Credentials; + use sqlx::PgPool; + use tower::ServiceExt; + + use crate::{ + init_router, + tests::{setup_test_env, TestResult}, + }; + + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_issue_unauthorized(pool: PgPool) -> TestResult { + setup_test_env(); + + let state = AppState { pool }; + let router = init_router(state.clone()); + + let auth = Authorization::basic("adent@earth.sol", "hunter2"); + tracing::debug!(?auth, "Auth"); + + let request = Request::builder() + .uri("/api/auth/issue") + .method("GET") + .header(AUTHORIZATION, auth.0.encode()) + .body(Body::empty())?; + + let response = router.oneshot(dbg!(request)).await?; + + tracing::error!(?response); + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_login_ok(pool: PgPool) -> TestResult { + setup_test_env(); + + let state = AppState { pool }; + let router = init_router(state.clone()); + + let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish"); + + let request = Request::builder() + .uri("/api/auth/issue") + .method("GET") + .header(AUTHORIZATION, auth.0.encode()) + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + + tracing::error!(?response); + + assert_eq!(StatusCode::OK, response.status()); + + Ok(()) + } } diff --git a/src/routes/login.rs b/src/routes/login.rs deleted file mode 100644 index 0e1e825..0000000 --- a/src/routes/login.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::sync::Arc; - -use argon2::{Argon2, PasswordHash, PasswordVerifier}; -use axum::{extract::State, response::IntoResponse, Json}; -use axum_extra::{headers::Authorization, routing::TypedPath, TypedHeader}; -use serde::Deserialize; - -use crate::{ - error::AuthError, - model::{LoginSchema, UserSchema}, - state::AppState, - Error, -}; - -use super::jwt::Claims; - -#[derive(Debug, Deserialize, TypedPath)] -#[typed_path("/api/login")] -pub struct Login; - -impl Login { - #[tracing::instrument(skip(state, password))] - pub async fn post( - self, - State(state): State<Arc<AppState>>, - Json(LoginSchema { email, password }): Json<LoginSchema>, - ) -> Result<impl IntoResponse, Error> { - let UserSchema { - uuid, - password_hash, - .. - } = sqlx::query_as!( - UserSchema, - "SELECT * FROM users WHERE email = $1 LIMIT 1", - email.to_ascii_lowercase() - ) - .fetch_optional(&state.pool) - .await? - .ok_or(AuthError::LoginInvalid)?; - - Argon2::default() - .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?; - - let token = Claims::from(uuid).encode(state.jwt_secret.as_ref())?; - - Authorization::bearer(&token) - .map(TypedHeader) - .map_err(Into::into) - } -} - -#[derive(Debug, Deserialize, TypedPath)] -#[typed_path("/api/logout")] -pub struct Logout; - -impl Logout { - #[tracing::instrument] - pub async fn get(self) -> impl IntoResponse { - todo!("Invalidate jwt somehow..."); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use axum::{ - body::Body, - http::{header, Request, StatusCode}, - }; - use sqlx::PgPool; - use tower::ServiceExt; - - use crate::init_router; - - const JWT_SECRET: &str = "test-jwt-secret-token"; - - type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>; - - #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_login_unauthorized(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); - let router = init_router(state.clone()); - - let user = LoginSchema { - email: "adent@earth.sol".to_string(), - password: "hunter2".to_string(), - }; - - let request = Request::builder() - .uri("/api/login") - .method("POST") - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Body::from(serde_json::to_vec(&user)?))?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::UNAUTHORIZED, response.status()); - - Ok(()) - } - - #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_login_ok(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); - let router = init_router(state.clone()); - - let user = LoginSchema { - email: "adent@earth.sol".to_string(), - password: "solongandthanksforallthefish".to_string(), - }; - - let request = Request::builder() - .uri("/api/login") - .method("POST") - .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Body::from(serde_json::to_vec(&user)?))?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::OK, response.status()); - - Ok(()) - } -} diff --git a/src/routes/register.rs b/src/routes/register.rs index 286e70f..75819b0 100644 --- a/src/routes/register.rs +++ b/src/routes/register.rs @@ -1,4 +1,4 @@ -use std::{str::FromStr, sync::Arc}; +use std::str::FromStr; use argon2::{ password_hash::{rand_core::OsRng, SaltString}, @@ -25,7 +25,7 @@ impl Register { #[tracing::instrument(skip(password))] pub async fn post( self, - State(state): State<Arc<AppState>>, + State(state): State<AppState>, Json(RegisterSchema { name, email, @@ -73,18 +73,16 @@ mod tests { use sqlx::PgPool; use tower::ServiceExt; - use crate::init_router; - - const JWT_SECRET: &str = "test-jwt-secret-token"; - - type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>; + use crate::{ + init_router, + tests::{setup_test_env, TestResult}, + }; #[sqlx::test] async fn test_register_created(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); + setup_test_env(); + + let state = AppState { pool }; let router = init_router(state.clone()); let user = RegisterSchema { @@ -114,10 +112,9 @@ mod tests { #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_register_conflict(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); + setup_test_env(); + + let state = AppState { pool }; let router = init_router(state.clone()); let user = RegisterSchema { diff --git a/src/routes/user.rs b/src/routes/user.rs index 73eef04..31cd5cb 100644 --- a/src/routes/user.rs +++ b/src/routes/user.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use axum::{extract::State, response::IntoResponse, Extension, Json}; use axum_extra::routing::TypedPath; use serde::Deserialize; @@ -17,7 +15,7 @@ pub struct UserUuid { impl UserUuid { /// Get a user with a specific `uuid` #[tracing::instrument] - pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse { + pub async fn get(self, State(state): State<AppState>) -> impl IntoResponse { sqlx::query_as!( UserSchema, "SELECT * FROM users WHERE uuid = $1 LIMIT 1", @@ -38,7 +36,7 @@ impl User { #[tracing::instrument] pub async fn get( self, - State(state): State<Arc<AppState>>, + State(state): State<AppState>, Extension(Claims { sub, .. }): Extension<Claims>, ) -> Result<impl IntoResponse, Error> { sqlx::query_as!( @@ -59,14 +57,14 @@ mod tests { use axum::{ body::Body, - http::{header::AUTHORIZATION, Request, StatusCode}, + http::{header::COOKIE, HeaderValue, Request, StatusCode}, }; use http_body_util::BodyExt; use sqlx::PgPool; use tower::ServiceExt; - use crate::{init_router, model::UserSchema}; + use crate::{init_router, model::UserSchema, routes::jwt::AccessClaims}; const JWT_SECRET: &str = "test-jwt-secret-token"; const UUID: uuid::Uuid = uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"); @@ -75,10 +73,9 @@ mod tests { #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_user_uuid_ok(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); + std::env::set_var("JWT_SECRET", JWT_SECRET); + + let state = AppState { pool }; let router = init_router(state.clone()); let user = UserSchema { @@ -110,10 +107,9 @@ mod tests { #[sqlx::test] async fn test_user_uuid_not_found(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); + std::env::set_var("JWT_SECRET", JWT_SECRET); + + let state = AppState { pool }; let router = init_router(state.clone()); let user = UserSchema { @@ -136,17 +132,17 @@ mod tests { #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_user_ok(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); - let router = init_router(state.clone()); + std::env::set_var("JWT_SECRET", JWT_SECRET); - let token = Claims::from(UUID).encode(JWT_SECRET.as_ref())?; + let state = AppState { pool }; + let router = init_router(state.clone()); let request = Request::builder() .uri("/api/user") - .header(AUTHORIZATION, format!("Bearer {token}")) + .header( + COOKIE, + HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?, + ) .body(Body::empty())?; let response = router.oneshot(request).await?; @@ -167,17 +163,17 @@ mod tests { #[sqlx::test] async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); - let router = init_router(state.clone()); + std::env::set_var("JWT_SECRET", JWT_SECRET); - let token = Claims::from(UUID).encode("BAD_SECRET".as_ref())?; + let state = AppState { pool }; + let router = init_router(state.clone()); let request = Request::builder() .uri("/api/user") - .header(AUTHORIZATION, format!("Bearer {token}")) + .header( + COOKIE, + HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?, + ) .body(Body::empty())?; let response = router.oneshot(request).await?; @@ -189,15 +185,14 @@ mod tests { #[sqlx::test] async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); + std::env::set_var("JWT_SECRET", JWT_SECRET); + + let state = AppState { pool }; let router = init_router(state.clone()); let request = Request::builder() .uri("/api/user") - .header(AUTHORIZATION, "Bearer invalidtoken") + .header(COOKIE, "token=sadfasdfsdfs") .body(Body::empty())?; let response = router.oneshot(request).await?; @@ -209,10 +204,9 @@ mod tests { #[sqlx::test] async fn test_user_unauthorized_missing_token(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - }); + std::env::set_var("JWT_SECRET", JWT_SECRET); + + let state = AppState { pool }; let router = init_router(state.clone()); let request = Request::builder().uri("/api/user").body(Body::empty())?; |