summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-06 02:09:26 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-06 02:10:32 -0500
commit9822bc18bb0cb5e13104376ecefc6ec99d93b016 (patch)
tree0137399c270a69a2c09e7c0e4c41f64cad5dbbfe /src
parentf7dd456941dbc5f926a04935d3aaaa198741608e (diff)
feat: impl jwt auth middleware and user route
Diffstat (limited to 'src')
-rw-r--r--src/error.rs59
-rw-r--r--src/jwt.rs51
-rw-r--r--src/lib.rs1
-rw-r--r--src/model.rs17
-rw-r--r--src/routes.rs8
-rw-r--r--src/routes/login.rs80
-rw-r--r--src/routes/register.rs6
-rw-r--r--src/routes/user.rs151
-rw-r--r--src/state.rs24
9 files changed, 272 insertions, 125 deletions
diff --git a/src/error.rs b/src/error.rs
index 6a32438..6414a13 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -26,8 +26,11 @@ pub enum Error {
#[error("Json error: {0}")]
Json(#[from] serde_json::Error),
- #[error("JWT error: {0}")]
- JWT(#[from] jsonwebtoken::errors::Error),
+ #[error("JSON web token error: {0}")]
+ Jwt(#[from] jsonwebtoken::errors::Error),
+
+ #[error("Token error: {0}")]
+ Token(#[from] axum_extra::headers::authorization::InvalidBearerToken),
#[error("Database error: {0}")]
Sqlx(#[from] sqlx::Error),
@@ -48,7 +51,7 @@ pub enum Error {
EmailInvalid(#[from] email_address::Error),
#[error("Invalid email or password")]
- LoginInvalid,
+ Authorization(#[from] AuthError),
#[error("{0}")]
Other(String),
@@ -57,35 +60,55 @@ pub enum Error {
impl From<argon2::password_hash::Error> for Error {
fn from(value: argon2::password_hash::Error) -> Self {
match value {
- argon2::password_hash::Error::Password => Self::LoginInvalid,
+ argon2::password_hash::Error::Password => Self::Authorization(AuthError::LoginInvalid),
_ => Self::PasswordHash(value),
}
}
}
-impl From<&Error> for StatusCode {
- fn from(value: &Error) -> Self {
- match value {
- Error::UserNotFound => StatusCode::NOT_FOUND,
- Error::EmailExists => StatusCode::CONFLICT,
- Error::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY,
- Error::LoginInvalid => StatusCode::UNAUTHORIZED,
- _ => StatusCode::INTERNAL_SERVER_ERROR,
- }
- }
-}
-
impl axum::response::IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
// TODO: implement [rfc7807](https://www.rfc-editor.org/rfc/rfc7807.html)
+ let status = match &self {
+ Self::UserNotFound => StatusCode::NOT_FOUND,
+ Self::EmailExists => StatusCode::CONFLICT,
+ Self::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY,
+ Self::Authorization(_) => StatusCode::UNAUTHORIZED,
+ _ => StatusCode::INTERNAL_SERVER_ERROR,
+ };
+
(
- StatusCode::from(&self),
+ status,
Json(json!({
- "status": StatusCode::from(&self).to_string(),
+ "status": status.to_string(),
"detail": self.to_string(),
})),
)
.into_response()
}
}
+
+#[derive(thiserror::Error, Debug)]
+pub enum AuthError {
+ #[error("Invalid email or password")]
+ LoginInvalid,
+
+ #[error("Authorization token not found")]
+ JwtNotFound,
+
+ #[error("The user belonging to this token no longer exists")]
+ UserNotFound,
+
+ #[error("Invalid authorization token")]
+ JwtValidation(#[from] jsonwebtoken::errors::Error),
+
+ #[error("Jwk not found")]
+ JwkNotFound,
+}
+
+impl axum::response::IntoResponse for AuthError {
+ fn into_response(self) -> axum::response::Response {
+ StatusCode::UNAUTHORIZED.into_response()
+ }
+}
diff --git a/src/jwt.rs b/src/jwt.rs
new file mode 100644
index 0000000..6382a01
--- /dev/null
+++ b/src/jwt.rs
@@ -0,0 +1,51 @@
+use std::sync::Arc;
+
+use axum::extract::{Request, State};
+use axum_extra::{
+ headers::{authorization::Bearer, Authorization},
+ TypedHeader,
+};
+use jsonwebtoken::{DecodingKey, Validation};
+use serde::{Deserialize, Serialize};
+use uuid::Uuid;
+
+use crate::{error::AuthError, state::AppState};
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+pub struct Claims {
+ pub sub: Uuid,
+ pub iat: i64,
+ pub exp: i64,
+}
+
+impl Claims {
+ pub fn new(sub: Uuid, max_age: time::Duration) -> Self {
+ let iat = time::OffsetDateTime::now_utc().unix_timestamp();
+ let exp = iat + max_age.whole_seconds();
+ Self { sub, iat, exp }
+ }
+
+ pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> {
+ jsonwebtoken::encode(
+ &jsonwebtoken::Header::default(),
+ self,
+ &jsonwebtoken::EncodingKey::from_secret(secret),
+ )
+ }
+}
+
+pub async fn authenticate(
+ State(state): State<Arc<AppState>>,
+ TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
+ mut req: Request,
+) -> Result<Request, AuthError> {
+ let claims = jsonwebtoken::decode::<Claims>(
+ bearer.token(),
+ &DecodingKey::from_secret(state.jwt_secret.as_ref()),
+ &Validation::default(),
+ )?
+ .claims;
+
+ req.extensions_mut().insert(claims);
+ Ok(req)
+}
diff --git a/src/lib.rs b/src/lib.rs
index e7502f9..85a4577 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,7 @@ pub use error::{Error, Result};
pub use routes::init_router;
pub mod error;
+pub mod jwt;
pub mod model;
pub mod routes;
pub mod state;
diff --git a/src/model.rs b/src/model.rs
index 395cdd1..655456e 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -9,7 +9,7 @@ use crate::Error;
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)]
#[serde(rename_all = "camelCase")]
-pub struct User {
+pub struct UserSchema {
pub uuid: Uuid,
pub name: String,
pub email: String,
@@ -19,21 +19,6 @@ pub struct User {
pub updated_at: Option<OffsetDateTime>,
}
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct TokenClaims {
- pub sub: Uuid,
- pub exp: i64,
-}
-
-impl TokenClaims {
- pub fn new(sub: Uuid, max_age: time::Duration) -> Self {
- Self {
- sub,
- exp: (time::OffsetDateTime::now_utc() + max_age).unix_timestamp(),
- }
- }
-}
-
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RegisterSchema {
pub name: String,
diff --git a/src/routes.rs b/src/routes.rs
index e2f5587..165dfb6 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -2,11 +2,12 @@ use std::sync::Arc;
use axum::{
http::{StatusCode, Uri},
+ middleware::map_request_with_state,
response::IntoResponse,
};
use axum_extra::routing::RouterExt;
-use crate::state::AppState;
+use crate::{jwt::authenticate, state::AppState};
mod healthcheck;
mod login;
@@ -16,12 +17,13 @@ mod user;
#[tracing::instrument]
pub fn init_router(state: Arc<AppState>) -> axum::Router {
axum::Router::new()
- // .route("/api/user", get(get_user))
+ .typed_get(user::User::get)
+ .typed_get(login::Logout::get)
+ .route_layer(map_request_with_state(state.clone(), authenticate))
.typed_get(healthcheck::HealthCheck::get)
.typed_get(user::UserUuid::get)
.typed_post(register::Register::post)
.typed_post(login::Login::post)
- .typed_get(login::Logout::get)
.fallback(fallback)
.with_state(state)
}
diff --git a/src/routes/login.rs b/src/routes/login.rs
index a580873..67f8422 100644
--- a/src/routes/login.rs
+++ b/src/routes/login.rs
@@ -1,17 +1,14 @@
use std::sync::Arc;
use argon2::{Argon2, PasswordHash, PasswordVerifier};
-use axum::{extract::State, http::header::SET_COOKIE, response::IntoResponse, Json};
-use axum_extra::{
- extract::cookie::{Cookie, SameSite},
- routing::TypedPath,
-};
-use jsonwebtoken::{EncodingKey, Header};
+use axum::{extract::State, response::IntoResponse, Json};
+use axum_extra::{headers::Authorization, routing::TypedPath, TypedHeader};
use serde::Deserialize;
-use serde_json::json;
use crate::{
- model::{LoginSchema, TokenClaims, User},
+ error::AuthError,
+ jwt::Claims,
+ model::{LoginSchema, UserSchema},
state::AppState,
Error,
};
@@ -27,42 +24,27 @@ impl Login {
State(state): State<Arc<AppState>>,
Json(LoginSchema { email, password }): Json<LoginSchema>,
) -> Result<impl IntoResponse, Error> {
- let User {
+ let UserSchema {
uuid,
password_hash,
..
} = sqlx::query_as!(
- User,
+ UserSchema,
"SELECT * FROM users WHERE email = $1",
email.to_ascii_lowercase()
)
.fetch_optional(&state.pool)
.await?
- .ok_or(Error::LoginInvalid)?;
+ .ok_or(AuthError::LoginInvalid)?;
Argon2::default()
.verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?;
- let token = jsonwebtoken::encode(
- &Header::default(),
- &TokenClaims::new(uuid, state.jwt_max_age),
- &EncodingKey::from_secret(state.jwt_secret.as_ref()),
- )?;
-
- let cookie = Cookie::build(("token", token.to_owned()))
- .path("/")
- .max_age(state.jwt_max_age)
- .same_site(SameSite::Lax)
- .http_only(true)
- .build();
-
- let mut response = Json(token).into_response();
+ let token = Claims::new(uuid, state.jwt_max_age).encode(state.jwt_secret.as_ref())?;
- response
- .headers_mut()
- .insert(SET_COOKIE, cookie.to_string().parse()?);
-
- Ok(response)
+ Authorization::bearer(&token)
+ .map(TypedHeader)
+ .map_err(Into::into)
}
}
@@ -72,21 +54,8 @@ pub struct Logout;
impl Logout {
#[tracing::instrument]
- pub async fn get(self) -> Result<impl IntoResponse, Error> {
- let cookie = Cookie::build(("token", ""))
- .path("/")
- .max_age(time::Duration::hours(-1))
- .same_site(SameSite::Lax)
- .http_only(true)
- .build();
-
- let mut response = Json(json!({"status": "success"})).into_response();
-
- response
- .headers_mut()
- .insert(SET_COOKIE, cookie.to_string().parse()?);
-
- Ok(response)
+ pub async fn get(self) -> impl IntoResponse {
+ todo!("Invalidate jwt somehow...");
}
}
@@ -161,25 +130,4 @@ mod tests {
Ok(())
}
-
- #[sqlx::test]
- async fn test_logout(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- jwt_max_age: JWT_MAX_AGE,
- });
- let router = init_router(state.clone());
-
- let request = Request::builder()
- .uri("/api/logout")
- .method("GET")
- .body(Body::empty())?;
-
- let response = router.oneshot(request).await?;
-
- assert_eq!(StatusCode::OK, response.status());
-
- Ok(())
- }
}
diff --git a/src/routes/register.rs b/src/routes/register.rs
index 9a4f007..d2a570c 100644
--- a/src/routes/register.rs
+++ b/src/routes/register.rs
@@ -9,7 +9,7 @@ use axum_extra::routing::TypedPath;
use serde::Deserialize;
use crate::{
- model::{RegisterSchema, User},
+ model::{RegisterSchema, UserSchema},
state::AppState,
Error,
};
@@ -45,7 +45,7 @@ impl Register {
let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?;
let user = sqlx::query_as!(
- User,
+ UserSchema,
"INSERT INTO users (name,email,password_hash) VALUES ($1, $2, $3) RETURNING *",
name,
email.to_ascii_lowercase(),
@@ -103,7 +103,7 @@ mod tests {
assert_eq!(StatusCode::CREATED, response.status());
let body_bytes = response.into_body().collect().await?.to_bytes();
- let User { name, email, .. } = serde_json::from_slice(&body_bytes)?;
+ let UserSchema { name, email, .. } = serde_json::from_slice(&body_bytes)?;
assert_eq!(user.name, name);
assert_eq!(user.email, email);
diff --git a/src/routes/user.rs b/src/routes/user.rs
index d23f66b..e6e5c3d 100644
--- a/src/routes/user.rs
+++ b/src/routes/user.rs
@@ -1,10 +1,10 @@
use std::sync::Arc;
-use axum::{extract::State, response::IntoResponse, Json};
+use axum::{extract::State, response::IntoResponse, Extension, Json};
use axum_extra::routing::TypedPath;
use serde::Deserialize;
-use crate::{model::User, state::AppState, Error};
+use crate::{jwt::Claims, model::UserSchema, state::AppState, Error};
#[derive(Debug, Deserialize, TypedPath)]
#[typed_path("/api/user/:uuid")]
@@ -16,7 +16,26 @@ impl UserUuid {
/// Get a user with a specific `uuid`
#[tracing::instrument]
pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse {
- sqlx::query_as!(User, "SELECT * FROM users WHERE uuid = $1", self.uuid)
+ sqlx::query_as!(UserSchema, "SELECT * FROM users WHERE uuid = $1", self.uuid)
+ .fetch_optional(&state.pool)
+ .await?
+ .ok_or_else(|| Error::UserNotFound)
+ .map(Json)
+ }
+}
+
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/api/user")]
+pub struct User;
+
+impl User {
+ #[tracing::instrument]
+ pub async fn get(
+ self,
+ State(state): State<Arc<AppState>>,
+ Extension(Claims { sub, iat, exp }): Extension<Claims>,
+ ) -> Result<impl IntoResponse, Error> {
+ sqlx::query_as!(UserSchema, "SELECT * FROM users WHERE uuid = $1", sub)
.fetch_optional(&state.pool)
.await?
.ok_or_else(|| Error::UserNotFound)
@@ -30,21 +49,23 @@ mod tests {
use axum::{
body::Body,
- http::{Request, StatusCode},
+ http::{header::AUTHORIZATION, Request, StatusCode},
};
+
use http_body_util::BodyExt;
use sqlx::PgPool;
use tower::ServiceExt;
- use crate::init_router;
+ use crate::{init_router, model::UserSchema};
const JWT_SECRET: &str = "test-jwt-secret-token";
const JWT_MAX_AGE: time::Duration = time::Duration::HOUR;
+ const UUID: uuid::Uuid = uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145");
type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
- #[sqlx::test]
- async fn test_user_not_found(pool: PgPool) -> TestResult {
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_user_uuid_ok(pool: PgPool) -> TestResult {
let state = Arc::new(AppState {
pool,
jwt_secret: JWT_SECRET.to_string(),
@@ -52,8 +73,8 @@ mod tests {
});
let router = init_router(state.clone());
- let user = User {
- uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"),
+ let user = UserSchema {
+ uuid: UUID,
name: "Arthur Dent".to_string(),
email: "adent@earth.sol".to_string(),
..Default::default()
@@ -65,13 +86,22 @@ mod tests {
let response = router.oneshot(request).await?;
- assert_eq!(StatusCode::NOT_FOUND, response.status());
+ assert_eq!(StatusCode::OK, response.status());
+
+ let body_bytes = response.into_body().collect().await?.to_bytes();
+ let UserSchema {
+ uuid, name, email, ..
+ } = serde_json::from_slice(&body_bytes)?;
+
+ assert_eq!(user.uuid, uuid);
+ assert_eq!(user.name, name);
+ assert_eq!(user.email, email);
Ok(())
}
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_user_ok(pool: PgPool) -> TestResult {
+ #[sqlx::test]
+ async fn test_user_uuid_not_found(pool: PgPool) -> TestResult {
let state = Arc::new(AppState {
pool,
jwt_secret: JWT_SECRET.to_string(),
@@ -79,8 +109,8 @@ mod tests {
});
let router = init_router(state.clone());
- let user = User {
- uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"),
+ let user = UserSchema {
+ uuid: UUID,
name: "Arthur Dent".to_string(),
email: "adent@earth.sol".to_string(),
..Default::default()
@@ -92,16 +122,101 @@ mod tests {
let response = router.oneshot(request).await?;
+ assert_eq!(StatusCode::NOT_FOUND, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_user_ok(pool: PgPool) -> TestResult {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let token = Claims::new(UUID, JWT_MAX_AGE).encode(JWT_SECRET.as_ref())?;
+
+ let request = Request::builder()
+ .uri("/api/user")
+ .header(AUTHORIZATION, format!("Bearer {token}"))
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
assert_eq!(StatusCode::OK, response.status());
let body_bytes = response.into_body().collect().await?.to_bytes();
- let User {
+ let UserSchema {
uuid, name, email, ..
} = serde_json::from_slice(&body_bytes)?;
- assert_eq!(user.uuid, uuid);
- assert_eq!(user.name, name);
- assert_eq!(user.email, email);
+ assert_eq!(UUID, uuid);
+ assert_eq!("Arthur Dent", name);
+ assert_eq!("adent@earth.sol", email);
+
+ Ok(())
+ }
+
+ #[sqlx::test]
+ async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let token = Claims::new(UUID, JWT_MAX_AGE).encode("BAD_SECRET".as_ref())?;
+
+ let request = Request::builder()
+ .uri("/api/user")
+ .header(AUTHORIZATION, format!("Bearer {token}"))
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test]
+ async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let request = Request::builder()
+ .uri("/api/user")
+ .header(AUTHORIZATION, "Bearer invalidtoken")
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test]
+ async fn test_user_unauthorized_missing_token(pool: PgPool) -> TestResult {
+ let state = Arc::new(AppState {
+ pool,
+ jwt_secret: JWT_SECRET.to_string(),
+ jwt_max_age: JWT_MAX_AGE,
+ });
+ let router = init_router(state.clone());
+
+ let request = Request::builder().uri("/api/user").body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
+ assert_eq!(StatusCode::BAD_REQUEST, response.status());
Ok(())
}
diff --git a/src/state.rs b/src/state.rs
index 508aaa4..4531a42 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,6 +1,15 @@
+use std::fmt::Debug;
+
+use axum::{
+ async_trait,
+ extract::{FromRef, FromRequestParts},
+ http::request::Parts,
+};
use sqlx::{Pool, Postgres};
-#[derive(Debug)]
+use crate::Error;
+
+#[derive(Debug, Clone)]
pub struct AppState {
pub pool: Pool<Postgres>,
pub jwt_secret: String,
@@ -16,3 +25,16 @@ impl AppState {
}
}
}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for AppState
+where
+ Self: FromRef<S>,
+ S: Send + Sync + Debug,
+{
+ type Rejection = Error;
+
+ async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+ Ok(Self::from_ref(state))
+ }
+}