diff options
-rw-r--r-- | src/error.rs | 59 | ||||
-rw-r--r-- | src/jwt.rs | 51 | ||||
-rw-r--r-- | src/lib.rs | 1 | ||||
-rw-r--r-- | src/model.rs | 17 | ||||
-rw-r--r-- | src/routes.rs | 8 | ||||
-rw-r--r-- | src/routes/login.rs | 80 | ||||
-rw-r--r-- | src/routes/register.rs | 6 | ||||
-rw-r--r-- | src/routes/user.rs | 151 | ||||
-rw-r--r-- | src/state.rs | 24 |
9 files changed, 272 insertions, 125 deletions
diff --git a/src/error.rs b/src/error.rs index 6a32438..6414a13 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,8 +26,11 @@ pub enum Error { #[error("Json error: {0}")] Json(#[from] serde_json::Error), - #[error("JWT error: {0}")] - JWT(#[from] jsonwebtoken::errors::Error), + #[error("JSON web token error: {0}")] + Jwt(#[from] jsonwebtoken::errors::Error), + + #[error("Token error: {0}")] + Token(#[from] axum_extra::headers::authorization::InvalidBearerToken), #[error("Database error: {0}")] Sqlx(#[from] sqlx::Error), @@ -48,7 +51,7 @@ pub enum Error { EmailInvalid(#[from] email_address::Error), #[error("Invalid email or password")] - LoginInvalid, + Authorization(#[from] AuthError), #[error("{0}")] Other(String), @@ -57,35 +60,55 @@ pub enum Error { impl From<argon2::password_hash::Error> for Error { fn from(value: argon2::password_hash::Error) -> Self { match value { - argon2::password_hash::Error::Password => Self::LoginInvalid, + argon2::password_hash::Error::Password => Self::Authorization(AuthError::LoginInvalid), _ => Self::PasswordHash(value), } } } -impl From<&Error> for StatusCode { - fn from(value: &Error) -> Self { - match value { - Error::UserNotFound => StatusCode::NOT_FOUND, - Error::EmailExists => StatusCode::CONFLICT, - Error::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY, - Error::LoginInvalid => StatusCode::UNAUTHORIZED, - _ => StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - impl axum::response::IntoResponse for Error { fn into_response(self) -> axum::response::Response { // TODO: implement [rfc7807](https://www.rfc-editor.org/rfc/rfc7807.html) + let status = match &self { + Self::UserNotFound => StatusCode::NOT_FOUND, + Self::EmailExists => StatusCode::CONFLICT, + Self::EmailInvalid(_) => StatusCode::UNPROCESSABLE_ENTITY, + Self::Authorization(_) => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + ( - StatusCode::from(&self), + status, Json(json!({ - "status": StatusCode::from(&self).to_string(), + "status": status.to_string(), "detail": self.to_string(), })), ) .into_response() } } + +#[derive(thiserror::Error, Debug)] +pub enum AuthError { + #[error("Invalid email or password")] + LoginInvalid, + + #[error("Authorization token not found")] + JwtNotFound, + + #[error("The user belonging to this token no longer exists")] + UserNotFound, + + #[error("Invalid authorization token")] + JwtValidation(#[from] jsonwebtoken::errors::Error), + + #[error("Jwk not found")] + JwkNotFound, +} + +impl axum::response::IntoResponse for AuthError { + fn into_response(self) -> axum::response::Response { + StatusCode::UNAUTHORIZED.into_response() + } +} diff --git a/src/jwt.rs b/src/jwt.rs new file mode 100644 index 0000000..6382a01 --- /dev/null +++ b/src/jwt.rs @@ -0,0 +1,51 @@ +use std::sync::Arc; + +use axum::extract::{Request, State}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + TypedHeader, +}; +use jsonwebtoken::{DecodingKey, Validation}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::{error::AuthError, state::AppState}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct Claims { + pub sub: Uuid, + pub iat: i64, + pub exp: i64, +} + +impl Claims { + pub fn new(sub: Uuid, max_age: time::Duration) -> Self { + let iat = time::OffsetDateTime::now_utc().unix_timestamp(); + let exp = iat + max_age.whole_seconds(); + Self { sub, iat, exp } + } + + pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> { + jsonwebtoken::encode( + &jsonwebtoken::Header::default(), + self, + &jsonwebtoken::EncodingKey::from_secret(secret), + ) + } +} + +pub async fn authenticate( + State(state): State<Arc<AppState>>, + TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, + mut req: Request, +) -> Result<Request, AuthError> { + let claims = jsonwebtoken::decode::<Claims>( + bearer.token(), + &DecodingKey::from_secret(state.jwt_secret.as_ref()), + &Validation::default(), + )? + .claims; + + req.extensions_mut().insert(claims); + Ok(req) +} @@ -2,6 +2,7 @@ pub use error::{Error, Result}; pub use routes::init_router; pub mod error; +pub mod jwt; pub mod model; pub mod routes; pub mod state; diff --git a/src/model.rs b/src/model.rs index 395cdd1..655456e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -9,7 +9,7 @@ use crate::Error; #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)] #[serde(rename_all = "camelCase")] -pub struct User { +pub struct UserSchema { pub uuid: Uuid, pub name: String, pub email: String, @@ -19,21 +19,6 @@ pub struct User { pub updated_at: Option<OffsetDateTime>, } -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct TokenClaims { - pub sub: Uuid, - pub exp: i64, -} - -impl TokenClaims { - pub fn new(sub: Uuid, max_age: time::Duration) -> Self { - Self { - sub, - exp: (time::OffsetDateTime::now_utc() + max_age).unix_timestamp(), - } - } -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RegisterSchema { pub name: String, diff --git a/src/routes.rs b/src/routes.rs index e2f5587..165dfb6 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -2,11 +2,12 @@ use std::sync::Arc; use axum::{ http::{StatusCode, Uri}, + middleware::map_request_with_state, response::IntoResponse, }; use axum_extra::routing::RouterExt; -use crate::state::AppState; +use crate::{jwt::authenticate, state::AppState}; mod healthcheck; mod login; @@ -16,12 +17,13 @@ mod user; #[tracing::instrument] pub fn init_router(state: Arc<AppState>) -> axum::Router { axum::Router::new() - // .route("/api/user", get(get_user)) + .typed_get(user::User::get) + .typed_get(login::Logout::get) + .route_layer(map_request_with_state(state.clone(), authenticate)) .typed_get(healthcheck::HealthCheck::get) .typed_get(user::UserUuid::get) .typed_post(register::Register::post) .typed_post(login::Login::post) - .typed_get(login::Logout::get) .fallback(fallback) .with_state(state) } diff --git a/src/routes/login.rs b/src/routes/login.rs index a580873..67f8422 100644 --- a/src/routes/login.rs +++ b/src/routes/login.rs @@ -1,17 +1,14 @@ use std::sync::Arc; use argon2::{Argon2, PasswordHash, PasswordVerifier}; -use axum::{extract::State, http::header::SET_COOKIE, response::IntoResponse, Json}; -use axum_extra::{ - extract::cookie::{Cookie, SameSite}, - routing::TypedPath, -}; -use jsonwebtoken::{EncodingKey, Header}; +use axum::{extract::State, response::IntoResponse, Json}; +use axum_extra::{headers::Authorization, routing::TypedPath, TypedHeader}; use serde::Deserialize; -use serde_json::json; use crate::{ - model::{LoginSchema, TokenClaims, User}, + error::AuthError, + jwt::Claims, + model::{LoginSchema, UserSchema}, state::AppState, Error, }; @@ -27,42 +24,27 @@ impl Login { State(state): State<Arc<AppState>>, Json(LoginSchema { email, password }): Json<LoginSchema>, ) -> Result<impl IntoResponse, Error> { - let User { + let UserSchema { uuid, password_hash, .. } = sqlx::query_as!( - User, + UserSchema, "SELECT * FROM users WHERE email = $1", email.to_ascii_lowercase() ) .fetch_optional(&state.pool) .await? - .ok_or(Error::LoginInvalid)?; + .ok_or(AuthError::LoginInvalid)?; Argon2::default() .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?; - let token = jsonwebtoken::encode( - &Header::default(), - &TokenClaims::new(uuid, state.jwt_max_age), - &EncodingKey::from_secret(state.jwt_secret.as_ref()), - )?; - - let cookie = Cookie::build(("token", token.to_owned())) - .path("/") - .max_age(state.jwt_max_age) - .same_site(SameSite::Lax) - .http_only(true) - .build(); - - let mut response = Json(token).into_response(); + let token = Claims::new(uuid, state.jwt_max_age).encode(state.jwt_secret.as_ref())?; - response - .headers_mut() - .insert(SET_COOKIE, cookie.to_string().parse()?); - - Ok(response) + Authorization::bearer(&token) + .map(TypedHeader) + .map_err(Into::into) } } @@ -72,21 +54,8 @@ pub struct Logout; impl Logout { #[tracing::instrument] - pub async fn get(self) -> Result<impl IntoResponse, Error> { - let cookie = Cookie::build(("token", "")) - .path("/") - .max_age(time::Duration::hours(-1)) - .same_site(SameSite::Lax) - .http_only(true) - .build(); - - let mut response = Json(json!({"status": "success"})).into_response(); - - response - .headers_mut() - .insert(SET_COOKIE, cookie.to_string().parse()?); - - Ok(response) + pub async fn get(self) -> impl IntoResponse { + todo!("Invalidate jwt somehow..."); } } @@ -161,25 +130,4 @@ mod tests { Ok(()) } - - #[sqlx::test] - async fn test_logout(pool: PgPool) -> TestResult { - let state = Arc::new(AppState { - pool, - jwt_secret: JWT_SECRET.to_string(), - jwt_max_age: JWT_MAX_AGE, - }); - let router = init_router(state.clone()); - - let request = Request::builder() - .uri("/api/logout") - .method("GET") - .body(Body::empty())?; - - let response = router.oneshot(request).await?; - - assert_eq!(StatusCode::OK, response.status()); - - Ok(()) - } } diff --git a/src/routes/register.rs b/src/routes/register.rs index 9a4f007..d2a570c 100644 --- a/src/routes/register.rs +++ b/src/routes/register.rs @@ -9,7 +9,7 @@ use axum_extra::routing::TypedPath; use serde::Deserialize; use crate::{ - model::{RegisterSchema, User}, + model::{RegisterSchema, UserSchema}, state::AppState, Error, }; @@ -45,7 +45,7 @@ impl Register { let password_hash = Argon2::default().hash_password(password.as_bytes(), &salt)?; let user = sqlx::query_as!( - User, + UserSchema, "INSERT INTO users (name,email,password_hash) VALUES ($1, $2, $3) RETURNING *", name, email.to_ascii_lowercase(), @@ -103,7 +103,7 @@ mod tests { assert_eq!(StatusCode::CREATED, response.status()); let body_bytes = response.into_body().collect().await?.to_bytes(); - let User { name, email, .. } = serde_json::from_slice(&body_bytes)?; + let UserSchema { name, email, .. } = serde_json::from_slice(&body_bytes)?; assert_eq!(user.name, name); assert_eq!(user.email, email); diff --git a/src/routes/user.rs b/src/routes/user.rs index d23f66b..e6e5c3d 100644 --- a/src/routes/user.rs +++ b/src/routes/user.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Extension, Json}; use axum_extra::routing::TypedPath; use serde::Deserialize; -use crate::{model::User, state::AppState, Error}; +use crate::{jwt::Claims, model::UserSchema, state::AppState, Error}; #[derive(Debug, Deserialize, TypedPath)] #[typed_path("/api/user/:uuid")] @@ -16,7 +16,26 @@ impl UserUuid { /// Get a user with a specific `uuid` #[tracing::instrument] pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse { - sqlx::query_as!(User, "SELECT * FROM users WHERE uuid = $1", self.uuid) + sqlx::query_as!(UserSchema, "SELECT * FROM users WHERE uuid = $1", self.uuid) + .fetch_optional(&state.pool) + .await? + .ok_or_else(|| Error::UserNotFound) + .map(Json) + } +} + +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/api/user")] +pub struct User; + +impl User { + #[tracing::instrument] + pub async fn get( + self, + State(state): State<Arc<AppState>>, + Extension(Claims { sub, iat, exp }): Extension<Claims>, + ) -> Result<impl IntoResponse, Error> { + sqlx::query_as!(UserSchema, "SELECT * FROM users WHERE uuid = $1", sub) .fetch_optional(&state.pool) .await? .ok_or_else(|| Error::UserNotFound) @@ -30,21 +49,23 @@ mod tests { use axum::{ body::Body, - http::{Request, StatusCode}, + http::{header::AUTHORIZATION, Request, StatusCode}, }; + use http_body_util::BodyExt; use sqlx::PgPool; use tower::ServiceExt; - use crate::init_router; + use crate::{init_router, model::UserSchema}; const JWT_SECRET: &str = "test-jwt-secret-token"; const JWT_MAX_AGE: time::Duration = time::Duration::HOUR; + const UUID: uuid::Uuid = uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"); type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>; - #[sqlx::test] - async fn test_user_not_found(pool: PgPool) -> TestResult { + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_user_uuid_ok(pool: PgPool) -> TestResult { let state = Arc::new(AppState { pool, jwt_secret: JWT_SECRET.to_string(), @@ -52,8 +73,8 @@ mod tests { }); let router = init_router(state.clone()); - let user = User { - uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"), + let user = UserSchema { + uuid: UUID, name: "Arthur Dent".to_string(), email: "adent@earth.sol".to_string(), ..Default::default() @@ -65,13 +86,22 @@ mod tests { let response = router.oneshot(request).await?; - assert_eq!(StatusCode::NOT_FOUND, response.status()); + assert_eq!(StatusCode::OK, response.status()); + + let body_bytes = response.into_body().collect().await?.to_bytes(); + let UserSchema { + uuid, name, email, .. + } = serde_json::from_slice(&body_bytes)?; + + assert_eq!(user.uuid, uuid); + assert_eq!(user.name, name); + assert_eq!(user.email, email); Ok(()) } - #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_user_ok(pool: PgPool) -> TestResult { + #[sqlx::test] + async fn test_user_uuid_not_found(pool: PgPool) -> TestResult { let state = Arc::new(AppState { pool, jwt_secret: JWT_SECRET.to_string(), @@ -79,8 +109,8 @@ mod tests { }); let router = init_router(state.clone()); - let user = User { - uuid: uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145"), + let user = UserSchema { + uuid: UUID, name: "Arthur Dent".to_string(), email: "adent@earth.sol".to_string(), ..Default::default() @@ -92,16 +122,101 @@ mod tests { let response = router.oneshot(request).await?; + assert_eq!(StatusCode::NOT_FOUND, response.status()); + + Ok(()) + } + + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_user_ok(pool: PgPool) -> TestResult { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let token = Claims::new(UUID, JWT_MAX_AGE).encode(JWT_SECRET.as_ref())?; + + let request = Request::builder() + .uri("/api/user") + .header(AUTHORIZATION, format!("Bearer {token}")) + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + assert_eq!(StatusCode::OK, response.status()); let body_bytes = response.into_body().collect().await?.to_bytes(); - let User { + let UserSchema { uuid, name, email, .. } = serde_json::from_slice(&body_bytes)?; - assert_eq!(user.uuid, uuid); - assert_eq!(user.name, name); - assert_eq!(user.email, email); + assert_eq!(UUID, uuid); + assert_eq!("Arthur Dent", name); + assert_eq!("adent@earth.sol", email); + + Ok(()) + } + + #[sqlx::test] + async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let token = Claims::new(UUID, JWT_MAX_AGE).encode("BAD_SECRET".as_ref())?; + + let request = Request::builder() + .uri("/api/user") + .header(AUTHORIZATION, format!("Bearer {token}")) + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[sqlx::test] + async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let request = Request::builder() + .uri("/api/user") + .header(AUTHORIZATION, "Bearer invalidtoken") + .body(Body::empty())?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + Ok(()) + } + + #[sqlx::test] + async fn test_user_unauthorized_missing_token(pool: PgPool) -> TestResult { + let state = Arc::new(AppState { + pool, + jwt_secret: JWT_SECRET.to_string(), + jwt_max_age: JWT_MAX_AGE, + }); + let router = init_router(state.clone()); + + let request = Request::builder().uri("/api/user").body(Body::empty())?; + + let response = router.oneshot(request).await?; + + assert_eq!(StatusCode::BAD_REQUEST, response.status()); Ok(()) } diff --git a/src/state.rs b/src/state.rs index 508aaa4..4531a42 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,15 @@ +use std::fmt::Debug; + +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts}, + http::request::Parts, +}; use sqlx::{Pool, Postgres}; -#[derive(Debug)] +use crate::Error; + +#[derive(Debug, Clone)] pub struct AppState { pub pool: Pool<Postgres>, pub jwt_secret: String, @@ -16,3 +25,16 @@ impl AppState { } } } + +#[async_trait] +impl<S> FromRequestParts<S> for AppState +where + Self: FromRef<S>, + S: Send + Sync + Debug, +{ + type Rejection = Error; + + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { + Ok(Self::from_ref(state)) + } +} |