diff options
author | Toby Vincent <tobyv@tobyvin.dev> | 2024-04-10 20:23:14 -0500 |
---|---|---|
committer | Toby Vincent <tobyv@tobyvin.dev> | 2024-04-11 23:51:06 -0500 |
commit | 8c56000a3090e0843a1f218a00c3503767658e83 (patch) | |
tree | bbcbf4ba4d10468ed8a6e891035ffa4646b77a7c /src/routes/jwt.rs | |
parent | eb8a597d310d8948d0b5a02911dd2002f00cfb39 (diff) |
wip: more work on jwt handling
Diffstat (limited to 'src/routes/jwt.rs')
-rw-r--r-- | src/routes/jwt.rs | 339 |
1 files changed, 265 insertions, 74 deletions
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs index 6a229a3..ccce13e 100644 --- a/src/routes/jwt.rs +++ b/src/routes/jwt.rs @@ -1,114 +1,305 @@ -use std::sync::Arc; - +use argon2::{Argon2, PasswordHash, PasswordVerifier}; use axum::{ - extract::{Request, State}, - response::IntoResponse, + async_trait, + extract::{FromRequestParts, State}, + http::{header::SET_COOKIE, request::Parts, HeaderValue}, + response::{IntoResponse, IntoResponseParts}, + RequestPartsExt, }; use axum_extra::{ extract::{cookie::Cookie, CookieJar}, - headers::{authorization::Bearer, Authorization}, - routing::TypedPath, + headers::{ + authorization::{Basic, Bearer}, + Authorization, + }, + routing::{RouterExt, TypedPath}, TypedHeader, }; -use jsonwebtoken::{DecodingKey, Validation}; +use jsonwebtoken::{decode, DecodingKey, EncodingKey}; +use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; -use crate::{error::AuthError, state::AppState, Error}; +use crate::{error::AuthError, model::UserSchema, state::AppState, Error}; + +pub fn init_router(state: AppState) -> axum::Router<AppState> { + axum::Router::new() + .typed_get(Issue::get) + .typed_get(Refresh::get) + .with_state(state) +} + +static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| { + let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); + JwtEnv::new(secret.as_bytes()) +}); + +#[derive(Clone)] +struct JwtEnv { + encoding: EncodingKey, + decoding: DecodingKey, + header: jsonwebtoken::Header, + validation: jsonwebtoken::Validation, +} + +impl JwtEnv { + fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), + header: Default::default(), + validation: Default::default(), + } + } +} #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct Claims { +pub struct Claims<const LIFETIME: i64 = ACCESS> { pub sub: Uuid, pub iat: i64, pub exp: i64, pub jti: Uuid, } -impl Claims { - const MAX_AGE: i64 = 3600; +impl<const LIFETIME: i64> Claims<LIFETIME> { + pub fn new(uuid: Uuid) -> Self { + let now = OffsetDateTime::now_utc().unix_timestamp(); + Self { + sub: uuid, + iat: now, + exp: now + LIFETIME, + jti: uuid::Uuid::new_v4(), + } + } - pub fn new(sub: Uuid) -> Self { - let iat = OffsetDateTime::now_utc().unix_timestamp(); - let exp = iat + Self::MAX_AGE; - let jti = uuid::Uuid::new_v4(); - Self { sub, iat, exp, jti } + pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> { + jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding) } +} - pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> { - jsonwebtoken::encode( - &jsonwebtoken::Header::default(), - self, - &jsonwebtoken::EncodingKey::from_secret(secret), - ) +impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> { + type Error = Error; + + fn try_from(value: Claims<L>) -> Result<Self, Self::Error> { + Ok(Cookie::build(("token", value.encode()?)) + .expires(OffsetDateTime::from_unix_timestamp(value.exp)?) + .secure(true) + .http_only(true) + .build()) } } -impl From<Uuid> for Claims { - fn from(value: Uuid) -> Self { - Self::new(value) +impl<const L: i64> TryFrom<Claims<L>> for HeaderValue { + type Error = Error; + + fn try_from(value: Claims<L>) -> Result<Self, Self::Error> { + Cookie::try_from(value)? + .encoded() + .to_string() + .parse() + .map_err(Into::into) } } -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -struct Session { - jti: Uuid, - uuid: Uuid, +// 1 day in seconds +const ACCESS: i64 = 86400; + +pub type AccessClaims = Claims<ACCESS>; + +impl From<RefreshClaims> for AccessClaims { + fn from(value: RefreshClaims) -> Self { + Claims::new(value.sub) + } +} + +impl IntoResponse for AccessClaims { + fn into_response(self) -> axum::response::Response { + (self, ()).into_response() + } +} + +impl IntoResponseParts for AccessClaims { + type Error = Error; + + fn into_response_parts( + self, + mut res: axum::response::ResponseParts, + ) -> Result<axum::response::ResponseParts, Self::Error> { + res.headers_mut() + .append(SET_COOKIE, HeaderValue::try_from(self)?); + + Ok(res) + } +} + +#[async_trait] +impl<S> FromRequestParts<S> for AccessClaims +where + S: Send + Sync, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + let token = parts + .extract::<CookieJar>() + .await + .map_err(|_| AuthError::JwtNotFound)? + .get("token") + .ok_or(AuthError::JwtNotFound)? + .to_string(); + + decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation) + .map(|d| d.claims) + .map_err(Into::into) + } +} + +// 30 days in seconds +const REFRESH: i64 = 2_592_000; + +pub type RefreshClaims = Claims<REFRESH>; + +impl RefreshClaims { + pub fn refresh(self) -> AccessClaims { + self.into() + } +} + +//impl IntoResponse for RefreshClaims { +// fn into_response(self) -> axum::response::Response { +// (self.refresh(), self).into_response() +// } +//} + +#[async_trait] +impl<S> FromRequestParts<S> for RefreshClaims +where + S: Send + Sync, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + let TypedHeader(Authorization(bearer)) = parts + .extract::<TypedHeader<Authorization<Bearer>>>() + .await + .map_err(|_| AuthError::JwtNotFound)?; + + decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation) + .map(|d| d.claims) + .map_err(Into::into) + } } #[derive(Debug, Deserialize, TypedPath)] -#[typed_path("/api/auth/refresh")] -pub struct Refresh; +#[typed_path("/issue")] +pub struct Issue; -impl Refresh { - #[tracing::instrument] - pub async fn post( +impl Issue { + #[tracing::instrument(skip_all)] + pub async fn get( self, - State(state): State<Arc<AppState>>, - TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, - cookie_jar: CookieJar, + State(state): State<AppState>, + TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>, ) -> Result<impl IntoResponse, Error> { - let Claims { sub, .. } = jsonwebtoken::decode::<Claims>( - bearer.token(), - &DecodingKey::from_secret(state.jwt_secret.as_ref()), - &Validation::default(), - )? - .claims; - - let claims = Claims::from(sub); - - let token = jsonwebtoken::encode( - &jsonwebtoken::Header::default(), - &claims, - &jsonwebtoken::EncodingKey::from_secret(state.jwt_secret.as_ref()), + let UserSchema { + uuid, + password_hash, + .. + } = sqlx::query_as!( + UserSchema, + "SELECT * FROM users WHERE email = $1 LIMIT 1", + basic.username().to_ascii_lowercase() + ) + .fetch_optional(&state.pool) + .await? + .ok_or(AuthError::LoginInvalid)?; + + Argon2::default().verify_password( + basic.password().as_bytes(), + &PasswordHash::new(&password_hash)?, )?; - let cookie = Cookie::build(("token", token)) - .expires(OffsetDateTime::from_unix_timestamp(claims.exp)?) - .secure(true) - .http_only(true); + let claims = Claims::<REFRESH>::new(uuid); - Ok(cookie_jar.add(cookie)) + Ok((claims.refresh(), claims.encode()?)) } } -pub async fn authenticate( - State(state): State<Arc<AppState>>, - cookie_jar: CookieJar, - mut req: Request, -) -> Result<Request, AuthError> { - let token = cookie_jar - .get("token") - .ok_or(AuthError::JwtNotFound)? - .to_string(); - - let claims = jsonwebtoken::decode::<Claims>( - &token, - &DecodingKey::from_secret(state.jwt_secret.as_ref()), - &Validation::default(), - )? - .claims; - - req.extensions_mut().insert(claims); - Ok(req) +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/refresh")] +pub struct Refresh; + +impl Refresh { + #[tracing::instrument(skip_all)] + pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse { + claims.refresh() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use axum::{ + body::Body, + http::{header::AUTHORIZATION, Request, StatusCode}, + }; + use axum_extra::headers::authorization::Credentials; + use sqlx::PgPool; + use tower::ServiceExt; + + use crate::{ + init_router, + tests::{setup_test_env, TestResult}, + }; + + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_issue_unauthorized(pool: PgPool) -> TestResult { + setup_test_env(); + + let state = AppState { pool }; + let router = init_router(state.clone()); + + let auth = Authorization::basic("adent@earth.sol", "hunter2"); + tracing::debug!(?auth, "Auth"); + + let request = Request::builder() + .uri("/api/auth/issue") + .method("GET") + .header(AUTHORIZATION, auth.0.encode()) + .body(Body::empty())?; + + let response = router.oneshot(dbg!(request)).await?; + + tracing::error!(?response); + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_login_ok(pool: PgPool) -> TestResult { + setup_test_env(); + + let state = AppState { pool }; + let router = init_router(state.clone()); + + let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish"); + + let request = Request::builder() + .uri("/api/auth/issue") + .method("GET") + .header(AUTHORIZATION, auth.0.encode()) + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + + tracing::error!(?response); + + assert_eq!(StatusCode::OK, response.status()); + + Ok(()) + } } |