summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-16 20:28:46 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-16 20:28:46 -0500
commita2860a1294b250402114fa016c4639881abc2172 (patch)
tree96609af1fe197ec0186f0f409ccfbb8b1012951d
parent917293785bffdd64d467e7d69b5645099b21d5e9 (diff)
refactor: improve account extractors
-rw-r--r--src/api/account.rs64
-rw-r--r--src/api/users.rs19
-rw-r--r--src/auth.rs40
-rw-r--r--src/auth/error.rs26
4 files changed, 106 insertions, 43 deletions
diff --git a/src/api/account.rs b/src/api/account.rs
index d6a94b5..0087df7 100644
--- a/src/api/account.rs
+++ b/src/api/account.rs
@@ -1,13 +1,19 @@
-use axum::{extract::State, routing::get, Router};
+use axum::{
+ async_trait,
+ extract::{FromRequestParts, State},
+ http::request::Parts,
+ routing::get,
+ RequestPartsExt, Router,
+};
use axum_extra::{
+ either::Either,
extract::{cookie::Cookie, CookieJar},
headers::{authorization::Basic, Authorization},
- typed_header::TypedHeaderRejection,
TypedHeader,
};
use crate::{
- auth::{AccessClaims, RefreshClaims},
+ auth::{AccessClaims, Account, RefreshClaims},
state::AppState,
};
@@ -21,32 +27,44 @@ pub fn router() -> Router<AppState> {
pub async fn login(
State(state): State<AppState>,
- auth: Result<TypedHeader<Authorization<Basic>>, TypedHeaderRejection>,
- claims: Option<RefreshClaims>,
-) -> Result<(AccessClaims, RefreshClaims), Error> {
- if let Some(refresh_claims) = claims {
- return Ok((refresh_claims.refresh(), refresh_claims));
+ auth: Either<RefreshClaims, Login>,
+) -> Result<(AccessClaims, RefreshClaims), crate::auth::error::Error> {
+ match auth {
+ Either::E1(token) => Ok((token.refresh(), token)),
+ Either::E2(Login(account)) => crate::auth::issue(State(state.clone()), account).await,
}
-
- let TypedHeader(Authorization(basic)) = auth?;
-
- let user_id = sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username())
- .fetch_optional(&state.pool)
- .await?
- .ok_or(Error::UserNotFound)?;
-
- crate::auth::issue(
- State(state.clone()),
- TypedHeader(Authorization::basic(&user_id.to_string(), basic.password())),
- )
- .await
- .map_err(Into::into)
}
pub async fn logout(claims: AccessClaims, jar: CookieJar) -> Result<CookieJar, Error> {
Ok(jar.remove(Cookie::try_from(claims)?))
}
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Login(Account);
+
+#[async_trait]
+impl FromRequestParts<AppState> for Login {
+ type Rejection = Error;
+
+ async fn from_request_parts(
+ parts: &mut Parts,
+ state: &AppState,
+ ) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(basic)) =
+ parts.extract::<TypedHeader<Authorization<Basic>>>().await?;
+
+ sqlx::query_scalar!("SELECT id FROM user_ WHERE email = $1", basic.username())
+ .fetch_optional(&state.pool)
+ .await?
+ .ok_or(Error::UserNotFound)
+ .map(|id| Account {
+ id,
+ password: basic.password().to_string(),
+ })
+ .map(Self)
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -60,7 +78,7 @@ mod tests {
Router,
};
- use axum_extra::headers::authorization::Credentials;
+ use axum_extra::headers::{authorization::Credentials, Authorization};
use http_body_util::BodyExt;
use sqlx::PgPool;
use tower::ServiceExt;
diff --git a/src/api/users.rs b/src/api/users.rs
index c8a390d..6ac0bb8 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -11,7 +11,10 @@ use sqlx::FromRow;
use time::OffsetDateTime;
use uuid::Uuid;
-use crate::{auth::AccessClaims, state::AppState};
+use crate::{
+ auth::{credentials::Credential, AccessClaims},
+ state::AppState,
+};
use super::error::Error;
@@ -30,7 +33,7 @@ pub struct User {
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
-pub struct RegisterSchema {
+pub struct Registration {
pub name: String,
pub email: String,
pub password: String,
@@ -38,11 +41,11 @@ pub struct RegisterSchema {
pub async fn create(
State(state): State<AppState>,
- Json(RegisterSchema {
+ Json(Registration {
name,
email,
password,
- }): Json<RegisterSchema>,
+ }): Json<Registration>,
) -> impl IntoResponse {
email_address::EmailAddress::from_str(&email)?;
@@ -58,11 +61,9 @@ pub async fn create(
}
// TODO: Move this into a micro service, possibly behind a feature flag.
- let (status, (access, refresh)) = crate::auth::credentials::create(
- State(state.clone()),
- Json(crate::auth::credentials::Credential { password }),
- )
- .await?;
+ let (status, (access, refresh)) =
+ crate::auth::credentials::create(State(state.clone()), Json(Credential { password }))
+ .await?;
let user = sqlx::query_as!(
User,
diff --git a/src/auth.rs b/src/auth.rs
index 09494fb..6bf0ddf 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -1,5 +1,10 @@
use argon2::{Argon2, PasswordHash, PasswordVerifier};
-use axum::extract::State;
+use axum::{
+ async_trait,
+ extract::{FromRequestParts, State},
+ http::request::Parts,
+ RequestPartsExt,
+};
use axum_extra::{
headers::{authorization::Basic, Authorization},
TypedHeader,
@@ -28,18 +33,16 @@ pub fn router() -> axum::Router<AppState> {
pub async fn issue(
State(state): State<AppState>,
- TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>,
+ Account { id, password }: Account,
) -> Result<(AccessClaims, RefreshClaims), Error> {
- let uuid = Uuid::try_parse(basic.username())?;
-
- let p: String = sqlx::query_scalar!("SELECT password_hash FROM credential WHERE id = $1", uuid)
+ let p: String = sqlx::query_scalar!("SELECT password_hash FROM credential WHERE id = $1", id)
.fetch_optional(&state.pool)
.await?
.ok_or(Error::LoginInvalid)?;
- Argon2::default().verify_password(basic.password().as_bytes(), &PasswordHash::new(&p)?)?;
+ Argon2::default().verify_password(password.as_bytes(), &PasswordHash::new(&p)?)?;
- let refresh = RefreshClaims::issue(uuid);
+ let refresh = RefreshClaims::issue(id);
let access = refresh.refresh();
Ok((access, refresh))
}
@@ -48,6 +51,27 @@ pub async fn refresh(claims: RefreshClaims) -> AccessClaims {
claims.refresh()
}
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Account {
+ pub id: Uuid,
+ pub password: String,
+}
+
+#[async_trait]
+impl FromRequestParts<AppState> for Account {
+ type Rejection = Error;
+
+ async fn from_request_parts(parts: &mut Parts, _: &AppState) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(basic)) =
+ parts.extract::<TypedHeader<Authorization<Basic>>>().await?;
+
+ Ok(Self {
+ id: Uuid::try_parse(basic.username())?,
+ password: basic.password().to_string(),
+ })
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -57,7 +81,7 @@ mod tests {
http::{header::AUTHORIZATION, Request, StatusCode},
Router,
};
- use axum_extra::headers::authorization::Credentials;
+ use axum_extra::headers::{authorization::Credentials, Authorization};
use sqlx::PgPool;
use tower::ServiceExt;
diff --git a/src/auth/error.rs b/src/auth/error.rs
index 17cf6d1..3a111ca 100644
--- a/src/auth/error.rs
+++ b/src/auth/error.rs
@@ -1,7 +1,13 @@
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Failed to parse header: {0}")]
- Header(#[from] axum::http::header::InvalidHeaderValue),
+ HeaderValue(#[from] axum::http::header::InvalidHeaderValue),
+
+ #[error("Required header not found: {0}")]
+ HeaderNotFound(axum::http::HeaderName),
+
+ #[error("Failed to parse header: {0} (wrong token type?)")]
+ HeaderRejection(axum_extra::typed_header::TypedHeaderRejection),
#[error("Database error: {0}")]
Sqlx(#[from] sqlx::Error),
@@ -37,6 +43,16 @@ pub enum Error {
UserNotFound,
}
+impl From<axum_extra::typed_header::TypedHeaderRejection> for Error {
+ fn from(value: axum_extra::typed_header::TypedHeaderRejection) -> Self {
+ if value.is_missing() {
+ Self::HeaderNotFound(value.name().clone())
+ } else {
+ Self::HeaderRejection(value)
+ }
+ }
+}
+
impl From<argon2::password_hash::Error> for Error {
fn from(value: argon2::password_hash::Error) -> Self {
match value {
@@ -73,10 +89,14 @@ impl From<uuid::Error> for Error {
impl axum::response::IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
- use axum::http::StatusCode;
+ use axum::http::{header::AUTHORIZATION, StatusCode};
let status = match self {
- Error::JwtFormat(_) | Error::Uuid(_) => StatusCode::UNPROCESSABLE_ENTITY,
+ Self::HeaderNotFound(ref h) if h == AUTHORIZATION => StatusCode::UNAUTHORIZED,
+ Self::HeaderNotFound(_) => StatusCode::BAD_REQUEST,
+ Self::HeaderRejection(_) | Error::JwtFormat(_) | Error::Uuid(_) => {
+ StatusCode::UNPROCESSABLE_ENTITY
+ }
Error::JwtValidation(_)
| Error::LoginInvalid
| Error::UserNotFound