diff options
-rw-r--r-- | src/api/error.rs | 26 | ||||
-rw-r--r-- | src/api/users.rs | 10 | ||||
-rw-r--r-- | src/auth/claims.rs | 31 | ||||
-rw-r--r-- | src/auth/error.rs | 3 |
4 files changed, 65 insertions, 5 deletions
diff --git a/src/api/error.rs b/src/api/error.rs index bd43ce3..9048c20 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -9,6 +9,12 @@ pub enum Error { #[error("User not found")] UserNotFound, + #[error("Required header not found: {0}")] + HeaderNotFound(axum::http::HeaderName), + + #[error("Failed to parse header: {0} (wrong token type?)")] + Header(axum_extra::typed_header::TypedHeaderRejection), + #[error("Invalid user token")] InvalidToken, @@ -21,19 +27,35 @@ pub enum Error { #[error("Failed to reach authentication server: {0}")] AuthRequest(#[from] axum::http::Error), + #[error("Not authorization values found")] + Unauthorized, + #[error("Authentication error: {0}")] Auth(#[from] crate::auth::error::Error), } +impl From<axum_extra::typed_header::TypedHeaderRejection> for Error { + fn from(value: axum_extra::typed_header::TypedHeaderRejection) -> Self { + if value.is_missing() { + Self::HeaderNotFound(value.name().clone()) + } else { + Self::Header(value) + } + } +} + impl axum::response::IntoResponse for Error { fn into_response(self) -> axum::response::Response { + use axum::http::header::AUTHORIZATION; use axum::http::StatusCode; let status = match self { Self::RouteNotFound(_) | Self::UserNotFound => StatusCode::NOT_FOUND, Self::EmailExists => StatusCode::CONFLICT, - Self::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY, - Self::InvalidToken => StatusCode::UNAUTHORIZED, + Self::InvalidToken | Self::Unauthorized => StatusCode::UNAUTHORIZED, + Self::HeaderNotFound(ref h) if h == AUTHORIZATION => StatusCode::UNAUTHORIZED, + Self::HeaderNotFound(_) => StatusCode::BAD_REQUEST, + Self::EmailInvalid(_) | Self::Header(_) => StatusCode::UNPROCESSABLE_ENTITY, Self::AuthRequest(_) | Self::Sqlx(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::Auth(err) => return err.into_response(), }; diff --git a/src/api/users.rs b/src/api/users.rs index 0cac406..6eb2a39 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -9,6 +9,7 @@ use axum::{ use axum_extra::{ headers::{authorization::Basic, Authorization}, routing::Resource, + typed_header::TypedHeaderRejection, TypedHeader, }; use serde::{Deserialize, Serialize}; @@ -48,8 +49,15 @@ pub struct RegisterSchema { pub async fn login( State(state): State<AppState>, - TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>, + auth: Result<TypedHeader<Authorization<Basic>>, TypedHeaderRejection>, + claims: Option<RefreshClaims>, ) -> Result<(AccessClaims, RefreshClaims), Error> { + if let Some(refresh_claims) = claims { + return Ok((refresh_claims.refresh(), refresh_claims)); + } + + let TypedHeader(Authorization(basic)) = auth?; + let user_id = sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username()) .fetch_optional(&state.pool) .await? diff --git a/src/auth/claims.rs b/src/auth/claims.rs index ff582a3..bee1c35 100644 --- a/src/auth/claims.rs +++ b/src/auth/claims.rs @@ -14,13 +14,14 @@ use axum_extra::{ headers::{authorization::Bearer, Authorization}, TypedHeader, }; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use time::OffsetDateTime; use uuid::Uuid; use super::{Error, JWT}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(remote = "Self")] pub struct Claims<const LIFETIME: i64 = ACCESS> { pub sub: Uuid, pub iat: i64, @@ -40,6 +41,32 @@ impl<const LIFETIME: i64> Claims<LIFETIME> { } } +impl<const LIFETIME: i64> Serialize for Claims<LIFETIME> { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + Self::serialize(self, serializer) + } +} + +impl<'de, const LIFETIME: i64> Deserialize<'de> for Claims<LIFETIME> { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + let claims = Self::deserialize(deserializer)?; + + if claims.exp - claims.iat != LIFETIME { + return Err(serde::de::Error::custom( + "Lifetime is invalid for Claim type", + )); + } + + Ok(claims) + } +} + // 1 day in seconds const ACCESS: i64 = 86400; @@ -158,6 +185,6 @@ where .await .map_err(|_| Error::JwtNotFound)?; - Ok(JWT.decode(bearer.token())?.claims) + JWT.decode(bearer.token()).map(|jwt| jwt.claims) } } diff --git a/src/auth/error.rs b/src/auth/error.rs index 17cf6d1..8b1bb4c 100644 --- a/src/auth/error.rs +++ b/src/auth/error.rs @@ -33,6 +33,9 @@ pub enum Error { #[error("Authorization token not found")] JwtNotFound, + #[error("Token found was invalid type")] + InvalidTokenType, + #[error("The user belonging to this token no longer exists")] UserNotFound, } |