diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/api/account.rs | 11 | ||||
-rw-r--r-- | src/api/users.rs | 53 | ||||
-rw-r--r-- | src/auth/claims.rs | 29 | ||||
-rw-r--r-- | src/auth/error.rs | 3 | ||||
-rw-r--r-- | src/auth/jwt.rs | 8 |
5 files changed, 55 insertions, 49 deletions
diff --git a/src/api/account.rs b/src/api/account.rs index 598d172..bae7c54 100644 --- a/src/api/account.rs +++ b/src/api/account.rs @@ -7,7 +7,7 @@ use axum::{ }; use axum_extra::{ either::Either, - extract::{cookie::Cookie, CookieJar}, + extract::CookieJar, headers::{authorization::Basic, Authorization}, TypedHeader, }; @@ -37,7 +37,7 @@ pub async fn login( } pub async fn logout(claims: AccessClaims, jar: CookieJar) -> Result<CookieJar, Error> { - Ok(jar.remove(Cookie::try_from(claims)?)) + Ok(jar.remove(("token", crate::auth::jwt::JWT.encode(&claims)?))) } #[derive(Debug, Clone, PartialEq, Eq)] @@ -74,7 +74,7 @@ mod tests { body::Body, http::{ header::{AUTHORIZATION, COOKIE, SET_COOKIE}, - HeaderValue, Request, StatusCode, + Request, StatusCode, }, Router, }; @@ -188,7 +188,10 @@ mod tests { let request = Request::builder() .uri("/logout") .method("GET") - .header(COOKIE, HeaderValue::try_from(AccessClaims::issue(USER_ID))?) + .header( + COOKIE, + AccessClaims::issue(USER_ID).as_cookie()?.to_string(), + ) .body(Body::empty())?; let (mut parts, _) = router.oneshot(request).await?.into_parts(); diff --git a/src/api/users.rs b/src/api/users.rs index e73e229..e07bf7e 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -81,13 +81,9 @@ pub async fn create( pub async fn show( Path(uuid): Path<Uuid>, State(pool): State<PgPool>, - AccessClaims { sub, .. }: AccessClaims, + _: AccessClaims, ) -> Result<impl IntoResponse, Error> { - if uuid != sub { - return Err(Error::InvalidToken); - } - - sqlx::query_as!(User, "SELECT * FROM user_ WHERE id = $1 LIMIT 1", sub) + sqlx::query_as!(User, "SELECT * FROM user_ WHERE id = $1 LIMIT 1", uuid) .fetch_optional(&pool) .await? .ok_or_else(|| Error::UserNotFound) @@ -102,7 +98,7 @@ mod tests { body::Body, http::{ header::{CONTENT_TYPE, COOKIE}, - HeaderValue, Request, StatusCode, + Request, StatusCode, }, Router, }; @@ -121,14 +117,17 @@ mod tests { const USER_PASSWORD: &str = "solongandthanksforallthefish"; #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_uuid_ok(pool: PgPool) -> TestResult { + async fn test_get_ok_self(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let request = Request::builder() .uri(format!("/users/{}", USER_ID)) - .header(COOKIE, HeaderValue::try_from(AccessClaims::issue(USER_ID))?) + .header( + COOKIE, + AccessClaims::issue(USER_ID).as_cookie()?.to_string(), + ) .body(Body::empty())?; let response = router.oneshot(request).await?; @@ -147,26 +146,40 @@ mod tests { Ok(()) } - #[sqlx::test] - async fn test_uuid_not_found(pool: PgPool) -> TestResult { + #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] + async fn test_get_ok_other(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); let request = Request::builder() .uri(format!("/users/{}", USER_ID)) - .header(COOKIE, HeaderValue::try_from(AccessClaims::issue(USER_ID))?) + .header( + COOKIE, + AccessClaims::issue(uuid::Uuid::new_v4()) + .as_cookie()? + .to_string(), + ) .body(Body::empty())?; let response = router.oneshot(request).await?; - assert_eq!(StatusCode::NOT_FOUND, response.status()); + assert_eq!(StatusCode::OK, response.status()); + + let body_bytes = response.into_body().collect().await?.to_bytes(); + let User { + id, name, email, .. + } = serde_json::from_slice(&body_bytes)?; + + assert_eq!(USER_ID, id); + assert_eq!(USER_NAME, name); + assert_eq!(USER_EMAIL, email); Ok(()) } #[sqlx::test] - async fn test_unauthorized_invalid_token_signature(pool: PgPool) -> TestResult { + async fn test_get_not_found(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); @@ -175,19 +188,19 @@ mod tests { .uri(format!("/users/{}", USER_ID)) .header( COOKIE, - HeaderValue::try_from(AccessClaims::issue(uuid::Uuid::new_v4()))?, + AccessClaims::issue(USER_ID).as_cookie()?.to_string(), ) .body(Body::empty())?; let response = router.oneshot(request).await?; - assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + assert_eq!(StatusCode::NOT_FOUND, response.status()); Ok(()) } #[sqlx::test] - async fn test_unauthorized_invalid_token_format(pool: PgPool) -> TestResult { + async fn test_get_unauthorized_invalid_token_format(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); @@ -205,7 +218,7 @@ mod tests { } #[sqlx::test] - async fn test_unauthorized_missing_token(pool: PgPool) -> TestResult { + async fn test_get_unauthorized_missing_token(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); @@ -222,7 +235,7 @@ mod tests { } #[sqlx::test] - async fn test_create_created(pool: PgPool) -> TestResult { + async fn test_post_created(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); @@ -253,7 +266,7 @@ mod tests { } #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))] - async fn test_create_conflict(pool: PgPool) -> TestResult { + async fn test_post_conflict(pool: PgPool) -> TestResult { setup_test_env(); let router = Router::new().merge(router()).with_state(AppState { pool }); diff --git a/src/auth/claims.rs b/src/auth/claims.rs index 67c4fbb..652a47f 100644 --- a/src/auth/claims.rs +++ b/src/auth/claims.rs @@ -106,18 +106,10 @@ const ACCESS: i64 = 86400; pub type AccessClaims = Claims<ACCESS>; -impl From<RefreshClaims> for AccessClaims { - fn from(value: RefreshClaims) -> Self { - Claims::issue(value.sub) - } -} - -impl TryFrom<AccessClaims> for Cookie<'_> { - type Error = Error; - - fn try_from(value: AccessClaims) -> Result<Self, Self::Error> { - Ok(Cookie::build(("token", JWT.encode(&value)?)) - .expires(value.exp) +impl AccessClaims { + pub fn as_cookie(&self) -> Result<Cookie, Error> { + Ok(Cookie::build(("token", JWT.encode(&self)?)) + .expires(self.exp) .secure(true) .http_only(true) .path("/api") @@ -125,14 +117,9 @@ impl TryFrom<AccessClaims> for Cookie<'_> { } } -impl TryFrom<AccessClaims> for HeaderValue { - type Error = Error; - - fn try_from(value: AccessClaims) -> Result<Self, Self::Error> { - Cookie::try_from(value)? - .to_string() - .parse() - .map_err(Into::into) +impl From<RefreshClaims> for AccessClaims { + fn from(value: RefreshClaims) -> Self { + Claims::issue(value.sub) } } @@ -150,7 +137,7 @@ impl IntoResponseParts for AccessClaims { mut res: axum::response::ResponseParts, ) -> Result<axum::response::ResponseParts, Self::Error> { res.headers_mut() - .append(SET_COOKIE, HeaderValue::try_from(self)?); + .try_append(SET_COOKIE, self.as_cookie()?.to_string().parse()?)?; Ok(res) } diff --git a/src/auth/error.rs b/src/auth/error.rs index 3a111ca..91aec5c 100644 --- a/src/auth/error.rs +++ b/src/auth/error.rs @@ -9,6 +9,9 @@ pub enum Error { #[error("Failed to parse header: {0} (wrong token type?)")] HeaderRejection(axum_extra::typed_header::TypedHeaderRejection), + #[error("Failed to append header: {0}")] + HeaderMaxSizeReached(#[from] axum::http::header::MaxSizeReached), + #[error("Database error: {0}")] Sqlx(#[from] sqlx::Error), diff --git a/src/auth/jwt.rs b/src/auth/jwt.rs index 0d7b593..f44b7d4 100644 --- a/src/auth/jwt.rs +++ b/src/auth/jwt.rs @@ -4,19 +4,19 @@ use serde::{de::DeserializeOwned, Serialize}; use super::Error; -pub static JWT: Lazy<Jwt> = Lazy::new(|| { +pub static JWT: Lazy<JwtTranscoder> = Lazy::new(|| { let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); - Jwt::new(secret.as_bytes()) + JwtTranscoder::new(secret.as_bytes()) }); -pub struct Jwt { +pub struct JwtTranscoder { encoding: EncodingKey, decoding: DecodingKey, header: jsonwebtoken::Header, validation: jsonwebtoken::Validation, } -impl Jwt { +impl JwtTranscoder { fn new(secret: &[u8]) -> Self { Self { encoding: EncodingKey::from_secret(secret), |