summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-13 17:37:08 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-13 17:37:08 -0500
commit49266fab5d12a63ea51708941ac15c286dfc4141 (patch)
treef61de885004d5658739db60fd8a374e7a216f03e /src
parentfecfd74eb29a5e1ddafab48d393c022dfcac3815 (diff)
fix(api,auth): improve token validation and errors
Diffstat (limited to 'src')
-rw-r--r--src/api/error.rs26
-rw-r--r--src/api/users.rs10
-rw-r--r--src/auth/claims.rs31
-rw-r--r--src/auth/error.rs3
4 files changed, 65 insertions, 5 deletions
diff --git a/src/api/error.rs b/src/api/error.rs
index bd43ce3..9048c20 100644
--- a/src/api/error.rs
+++ b/src/api/error.rs
@@ -9,6 +9,12 @@ pub enum Error {
#[error("User not found")]
UserNotFound,
+ #[error("Required header not found: {0}")]
+ HeaderNotFound(axum::http::HeaderName),
+
+ #[error("Failed to parse header: {0} (wrong token type?)")]
+ Header(axum_extra::typed_header::TypedHeaderRejection),
+
#[error("Invalid user token")]
InvalidToken,
@@ -21,19 +27,35 @@ pub enum Error {
#[error("Failed to reach authentication server: {0}")]
AuthRequest(#[from] axum::http::Error),
+ #[error("Not authorization values found")]
+ Unauthorized,
+
#[error("Authentication error: {0}")]
Auth(#[from] crate::auth::error::Error),
}
+impl From<axum_extra::typed_header::TypedHeaderRejection> for Error {
+ fn from(value: axum_extra::typed_header::TypedHeaderRejection) -> Self {
+ if value.is_missing() {
+ Self::HeaderNotFound(value.name().clone())
+ } else {
+ Self::Header(value)
+ }
+ }
+}
+
impl axum::response::IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
+ use axum::http::header::AUTHORIZATION;
use axum::http::StatusCode;
let status = match self {
Self::RouteNotFound(_) | Self::UserNotFound => StatusCode::NOT_FOUND,
Self::EmailExists => StatusCode::CONFLICT,
- Self::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY,
- Self::InvalidToken => StatusCode::UNAUTHORIZED,
+ Self::InvalidToken | Self::Unauthorized => StatusCode::UNAUTHORIZED,
+ Self::HeaderNotFound(ref h) if h == AUTHORIZATION => StatusCode::UNAUTHORIZED,
+ Self::HeaderNotFound(_) => StatusCode::BAD_REQUEST,
+ Self::EmailInvalid(_) | Self::Header(_) => StatusCode::UNPROCESSABLE_ENTITY,
Self::AuthRequest(_) | Self::Sqlx(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::Auth(err) => return err.into_response(),
};
diff --git a/src/api/users.rs b/src/api/users.rs
index 0cac406..6eb2a39 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -9,6 +9,7 @@ use axum::{
use axum_extra::{
headers::{authorization::Basic, Authorization},
routing::Resource,
+ typed_header::TypedHeaderRejection,
TypedHeader,
};
use serde::{Deserialize, Serialize};
@@ -48,8 +49,15 @@ pub struct RegisterSchema {
pub async fn login(
State(state): State<AppState>,
- TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>,
+ auth: Result<TypedHeader<Authorization<Basic>>, TypedHeaderRejection>,
+ claims: Option<RefreshClaims>,
) -> Result<(AccessClaims, RefreshClaims), Error> {
+ if let Some(refresh_claims) = claims {
+ return Ok((refresh_claims.refresh(), refresh_claims));
+ }
+
+ let TypedHeader(Authorization(basic)) = auth?;
+
let user_id = sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username())
.fetch_optional(&state.pool)
.await?
diff --git a/src/auth/claims.rs b/src/auth/claims.rs
index ff582a3..bee1c35 100644
--- a/src/auth/claims.rs
+++ b/src/auth/claims.rs
@@ -14,13 +14,14 @@ use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
use uuid::Uuid;
use super::{Error, JWT};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(remote = "Self")]
pub struct Claims<const LIFETIME: i64 = ACCESS> {
pub sub: Uuid,
pub iat: i64,
@@ -40,6 +41,32 @@ impl<const LIFETIME: i64> Claims<LIFETIME> {
}
}
+impl<const LIFETIME: i64> Serialize for Claims<LIFETIME> {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ Self::serialize(self, serializer)
+ }
+}
+
+impl<'de, const LIFETIME: i64> Deserialize<'de> for Claims<LIFETIME> {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let claims = Self::deserialize(deserializer)?;
+
+ if claims.exp - claims.iat != LIFETIME {
+ return Err(serde::de::Error::custom(
+ "Lifetime is invalid for Claim type",
+ ));
+ }
+
+ Ok(claims)
+ }
+}
+
// 1 day in seconds
const ACCESS: i64 = 86400;
@@ -158,6 +185,6 @@ where
.await
.map_err(|_| Error::JwtNotFound)?;
- Ok(JWT.decode(bearer.token())?.claims)
+ JWT.decode(bearer.token()).map(|jwt| jwt.claims)
}
}
diff --git a/src/auth/error.rs b/src/auth/error.rs
index 17cf6d1..8b1bb4c 100644
--- a/src/auth/error.rs
+++ b/src/auth/error.rs
@@ -33,6 +33,9 @@ pub enum Error {
#[error("Authorization token not found")]
JwtNotFound,
+ #[error("Token found was invalid type")]
+ InvalidTokenType,
+
#[error("The user belonging to this token no longer exists")]
UserNotFound,
}