diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/routes/jwt.rs | 122 | ||||
-rw-r--r-- | src/routes/user.rs | 28 |
2 files changed, 83 insertions, 67 deletions
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs index ccce13e..902b494 100644 --- a/src/routes/jwt.rs +++ b/src/routes/jwt.rs @@ -15,9 +15,9 @@ use axum_extra::{ routing::{RouterExt, TypedPath}, TypedHeader, }; -use jsonwebtoken::{decode, DecodingKey, EncodingKey}; +use jsonwebtoken::{DecodingKey, EncodingKey, TokenData}; use once_cell::sync::Lazy; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; @@ -35,7 +35,6 @@ static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| { JwtEnv::new(secret.as_bytes()) }); -#[derive(Clone)] struct JwtEnv { encoding: EncodingKey, decoding: DecodingKey, @@ -52,9 +51,23 @@ impl JwtEnv { validation: Default::default(), } } + + pub fn encode<T>(&self, claims: &T) -> Result<String, AuthError> + where + T: Serialize, + { + jsonwebtoken::encode(&self.header, claims, &self.encoding).map_err(Into::into) + } + + pub fn decode<T>(&self, token: &str) -> Result<TokenData<T>, AuthError> + where + T: DeserializeOwned, + { + jsonwebtoken::decode(token, &self.decoding, &self.validation).map_err(Into::into) + } } -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub struct Claims<const LIFETIME: i64 = ACCESS> { pub sub: Uuid, pub iat: i64, @@ -72,17 +85,24 @@ impl<const LIFETIME: i64> Claims<LIFETIME> { jti: uuid::Uuid::new_v4(), } } +} - pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> { - jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding) +// 1 day in seconds +const ACCESS: i64 = 86400; + +pub type AccessClaims = Claims<ACCESS>; + +impl From<RefreshClaims> for AccessClaims { + fn from(value: RefreshClaims) -> Self { + Claims::new(value.sub) } } -impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> { +impl TryFrom<AccessClaims> for Cookie<'_> { type Error = Error; - fn try_from(value: Claims<L>) -> Result<Self, Self::Error> { - Ok(Cookie::build(("token", value.encode()?)) + fn try_from(value: AccessClaims) -> Result<Self, Self::Error> { + Ok(Cookie::build(("token", JWT_ENV.encode(&value)?)) .expires(OffsetDateTime::from_unix_timestamp(value.exp)?) .secure(true) .http_only(true) @@ -90,29 +110,17 @@ impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> { } } -impl<const L: i64> TryFrom<Claims<L>> for HeaderValue { +impl TryFrom<AccessClaims> for HeaderValue { type Error = Error; - fn try_from(value: Claims<L>) -> Result<Self, Self::Error> { + fn try_from(value: AccessClaims) -> Result<Self, Self::Error> { Cookie::try_from(value)? - .encoded() .to_string() .parse() .map_err(Into::into) } } -// 1 day in seconds -const ACCESS: i64 = 86400; - -pub type AccessClaims = Claims<ACCESS>; - -impl From<RefreshClaims> for AccessClaims { - fn from(value: RefreshClaims) -> Self { - Claims::new(value.sub) - } -} - impl IntoResponse for AccessClaims { fn into_response(self) -> axum::response::Response { (self, ()).into_response() @@ -140,18 +148,15 @@ where { type Rejection = AuthError; - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { - let token = parts + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { + let jar = parts .extract::<CookieJar>() .await - .map_err(|_| AuthError::JwtNotFound)? - .get("token") - .ok_or(AuthError::JwtNotFound)? - .to_string(); + .expect("Infallable result was in fact, fallable"); - decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation) - .map(|d| d.claims) - .map_err(Into::into) + JWT_ENV + .decode(jar.get("token").ok_or(AuthError::JwtNotFound)?.value()) + .map(|t| t.claims) } } @@ -166,11 +171,14 @@ impl RefreshClaims { } } -//impl IntoResponse for RefreshClaims { -// fn into_response(self) -> axum::response::Response { -// (self.refresh(), self).into_response() -// } -//} +impl IntoResponse for RefreshClaims { + fn into_response(self) -> axum::response::Response { + match JWT_ENV.encode(&self) { + Ok(token) => token.into_response(), + Err(err) => Error::from(err).into_response(), + } + } +} #[async_trait] impl<S> FromRequestParts<S> for RefreshClaims @@ -179,15 +187,13 @@ where { type Rejection = AuthError; - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { let TypedHeader(Authorization(bearer)) = parts .extract::<TypedHeader<Authorization<Bearer>>>() .await .map_err(|_| AuthError::JwtNotFound)?; - decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation) - .map(|d| d.claims) - .map_err(Into::into) + Ok(JWT_ENV.decode(bearer.token())?.claims) } } @@ -196,7 +202,6 @@ where pub struct Issue; impl Issue { - #[tracing::instrument(skip_all)] pub async fn get( self, State(state): State<AppState>, @@ -222,7 +227,7 @@ impl Issue { let claims = Claims::<REFRESH>::new(uuid); - Ok((claims.refresh(), claims.encode()?)) + Ok((claims.refresh(), claims)) } } @@ -231,7 +236,6 @@ impl Issue { pub struct Refresh; impl Refresh { - #[tracing::instrument(skip_all)] pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse { claims.refresh() } @@ -254,15 +258,25 @@ mod tests { tests::{setup_test_env, TestResult}, }; + #[test] + fn test_jwt_encode_decode() -> TestResult { + setup_test_env(); + + let claims = AccessClaims::new(uuid::Uuid::new_v4()); + let token = JWT_ENV.encode(&claims)?; + let decoded = JWT_ENV.decode(&token)?.claims; + assert_eq!(claims, decoded); + Ok(()) + } + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_issue_unauthorized(pool: PgPool) -> TestResult { + async fn test_issue_ok(pool: PgPool) -> TestResult { setup_test_env(); let state = AppState { pool }; let router = init_router(state.clone()); - let auth = Authorization::basic("adent@earth.sol", "hunter2"); - tracing::debug!(?auth, "Auth"); + let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish"); let request = Request::builder() .uri("/api/auth/issue") @@ -270,23 +284,23 @@ mod tests { .header(AUTHORIZATION, auth.0.encode()) .body(Body::empty())?; - let response = router.oneshot(dbg!(request)).await?; - - tracing::error!(?response); + let response = router.oneshot(request).await?; + println!("{response:?}"); - assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + assert_eq!(StatusCode::OK, response.status()); Ok(()) } #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_login_ok(pool: PgPool) -> TestResult { + async fn test_issue_unauthorized(pool: PgPool) -> TestResult { setup_test_env(); let state = AppState { pool }; let router = init_router(state.clone()); - let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish"); + let auth = Authorization::basic("adent@earth.sol", "hunter2"); + tracing::debug!(?auth, "Auth"); let request = Request::builder() .uri("/api/auth/issue") @@ -298,7 +312,7 @@ mod tests { tracing::error!(?response); - assert_eq!(StatusCode::OK, response.status()); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); Ok(()) } diff --git a/src/routes/user.rs b/src/routes/user.rs index 31cd5cb..d6dd0da 100644 --- a/src/routes/user.rs +++ b/src/routes/user.rs @@ -1,10 +1,10 @@ -use axum::{extract::State, response::IntoResponse, Extension, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use axum_extra::routing::TypedPath; use serde::Deserialize; use crate::{model::UserSchema, state::AppState, Error}; -use super::jwt::Claims; +use super::jwt::AccessClaims; #[derive(Debug, Deserialize, TypedPath)] #[typed_path("/api/user/:uuid")] @@ -37,7 +37,7 @@ impl User { pub async fn get( self, State(state): State<AppState>, - Extension(Claims { sub, .. }): Extension<Claims>, + AccessClaims { sub, .. }: AccessClaims, ) -> Result<impl IntoResponse, Error> { sqlx::query_as!( UserSchema, @@ -73,8 +73,6 @@ mod tests { #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] async fn test_user_uuid_ok(pool: PgPool) -> TestResult { - std::env::set_var("JWT_SECRET", JWT_SECRET); - let state = AppState { pool }; let router = init_router(state.clone()); @@ -130,19 +128,23 @@ mod tests { Ok(()) } - #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + #[test_log::test(sqlx::test(fixtures(path = "../../fixtures", scripts("users"))))] async fn test_user_ok(pool: PgPool) -> TestResult { std::env::set_var("JWT_SECRET", JWT_SECRET); let state = AppState { pool }; let router = init_router(state.clone()); + let user = UserSchema { + uuid: UUID, + name: "Arthur Dent".to_string(), + email: "adent@earth.sol".to_string(), + ..Default::default() + }; + let request = Request::builder() .uri("/api/user") - .header( - COOKIE, - HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?, - ) + .header(COOKIE, HeaderValue::try_from(AccessClaims::new(user.uuid))?) .body(Body::empty())?; let response = router.oneshot(request).await?; @@ -162,7 +164,7 @@ mod tests { } #[sqlx::test] - async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult { + async fn test_user_unauthorized_invalid_token_signature(pool: PgPool) -> TestResult { std::env::set_var("JWT_SECRET", JWT_SECRET); let state = AppState { pool }; @@ -184,7 +186,7 @@ mod tests { } #[sqlx::test] - async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult { + async fn test_user_unauthorized_invalid_token_format(pool: PgPool) -> TestResult { std::env::set_var("JWT_SECRET", JWT_SECRET); let state = AppState { pool }; @@ -213,7 +215,7 @@ mod tests { let response = router.oneshot(request).await?; - assert_eq!(StatusCode::BAD_REQUEST, response.status()); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); Ok(()) } |