summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock42
-rw-r--r--Cargo.toml1
-rw-r--r--src/routes/jwt.rs122
-rw-r--r--src/routes/user.rs28
4 files changed, 126 insertions, 67 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 7cf36b7..d649cea 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -392,6 +392,25 @@ dependencies = [
]
[[package]]
+name = "env_filter"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea"
+dependencies = [
+ "log",
+]
+
+[[package]]
+name = "env_logger"
+version = "0.11.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9"
+dependencies = [
+ "env_filter",
+ "log",
+]
+
+[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1831,6 +1850,28 @@ dependencies = [
]
[[package]]
+name = "test-log"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b319995299c65d522680decf80f2c108d85b861d81dfe340a10d16cee29d9e6"
+dependencies = [
+ "env_logger",
+ "test-log-macros",
+ "tracing-subscriber",
+]
+
+[[package]]
+name = "test-log-macros"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c8f546451eaa38373f549093fe9fd05e7d2bade739e2ddf834b9968621d60107"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.52",
+]
+
+[[package]]
name = "thiserror"
version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2159,6 +2200,7 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
+ "test-log",
"thiserror",
"time",
"tokio",
diff --git a/Cargo.toml b/Cargo.toml
index 5414596..4713970 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,3 +31,4 @@ pgtemp = "0.2.1"
tower = { version = "0.4.13", features = ["util"] }
mime = "0.3.17"
http-body-util = "0.1.1"
+test-log = { version = "0.2.15", features = ["trace"] }
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs
index ccce13e..902b494 100644
--- a/src/routes/jwt.rs
+++ b/src/routes/jwt.rs
@@ -15,9 +15,9 @@ use axum_extra::{
routing::{RouterExt, TypedPath},
TypedHeader,
};
-use jsonwebtoken::{decode, DecodingKey, EncodingKey};
+use jsonwebtoken::{DecodingKey, EncodingKey, TokenData};
use once_cell::sync::Lazy;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use time::OffsetDateTime;
use uuid::Uuid;
@@ -35,7 +35,6 @@ static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| {
JwtEnv::new(secret.as_bytes())
});
-#[derive(Clone)]
struct JwtEnv {
encoding: EncodingKey,
decoding: DecodingKey,
@@ -52,9 +51,23 @@ impl JwtEnv {
validation: Default::default(),
}
}
+
+ pub fn encode<T>(&self, claims: &T) -> Result<String, AuthError>
+ where
+ T: Serialize,
+ {
+ jsonwebtoken::encode(&self.header, claims, &self.encoding).map_err(Into::into)
+ }
+
+ pub fn decode<T>(&self, token: &str) -> Result<TokenData<T>, AuthError>
+ where
+ T: DeserializeOwned,
+ {
+ jsonwebtoken::decode(token, &self.decoding, &self.validation).map_err(Into::into)
+ }
}
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Claims<const LIFETIME: i64 = ACCESS> {
pub sub: Uuid,
pub iat: i64,
@@ -72,17 +85,24 @@ impl<const LIFETIME: i64> Claims<LIFETIME> {
jti: uuid::Uuid::new_v4(),
}
}
+}
- pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
- jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding)
+// 1 day in seconds
+const ACCESS: i64 = 86400;
+
+pub type AccessClaims = Claims<ACCESS>;
+
+impl From<RefreshClaims> for AccessClaims {
+ fn from(value: RefreshClaims) -> Self {
+ Claims::new(value.sub)
}
}
-impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
+impl TryFrom<AccessClaims> for Cookie<'_> {
type Error = Error;
- fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
- Ok(Cookie::build(("token", value.encode()?))
+ fn try_from(value: AccessClaims) -> Result<Self, Self::Error> {
+ Ok(Cookie::build(("token", JWT_ENV.encode(&value)?))
.expires(OffsetDateTime::from_unix_timestamp(value.exp)?)
.secure(true)
.http_only(true)
@@ -90,29 +110,17 @@ impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
}
}
-impl<const L: i64> TryFrom<Claims<L>> for HeaderValue {
+impl TryFrom<AccessClaims> for HeaderValue {
type Error = Error;
- fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ fn try_from(value: AccessClaims) -> Result<Self, Self::Error> {
Cookie::try_from(value)?
- .encoded()
.to_string()
.parse()
.map_err(Into::into)
}
}
-// 1 day in seconds
-const ACCESS: i64 = 86400;
-
-pub type AccessClaims = Claims<ACCESS>;
-
-impl From<RefreshClaims> for AccessClaims {
- fn from(value: RefreshClaims) -> Self {
- Claims::new(value.sub)
- }
-}
-
impl IntoResponse for AccessClaims {
fn into_response(self) -> axum::response::Response {
(self, ()).into_response()
@@ -140,18 +148,15 @@ where
{
type Rejection = AuthError;
- async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
- let token = parts
+ async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
+ let jar = parts
.extract::<CookieJar>()
.await
- .map_err(|_| AuthError::JwtNotFound)?
- .get("token")
- .ok_or(AuthError::JwtNotFound)?
- .to_string();
+ .expect("Infallable result was in fact, fallable");
- decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation)
- .map(|d| d.claims)
- .map_err(Into::into)
+ JWT_ENV
+ .decode(jar.get("token").ok_or(AuthError::JwtNotFound)?.value())
+ .map(|t| t.claims)
}
}
@@ -166,11 +171,14 @@ impl RefreshClaims {
}
}
-//impl IntoResponse for RefreshClaims {
-// fn into_response(self) -> axum::response::Response {
-// (self.refresh(), self).into_response()
-// }
-//}
+impl IntoResponse for RefreshClaims {
+ fn into_response(self) -> axum::response::Response {
+ match JWT_ENV.encode(&self) {
+ Ok(token) => token.into_response(),
+ Err(err) => Error::from(err).into_response(),
+ }
+ }
+}
#[async_trait]
impl<S> FromRequestParts<S> for RefreshClaims
@@ -179,15 +187,13 @@ where
{
type Rejection = AuthError;
- async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::JwtNotFound)?;
- decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation)
- .map(|d| d.claims)
- .map_err(Into::into)
+ Ok(JWT_ENV.decode(bearer.token())?.claims)
}
}
@@ -196,7 +202,6 @@ where
pub struct Issue;
impl Issue {
- #[tracing::instrument(skip_all)]
pub async fn get(
self,
State(state): State<AppState>,
@@ -222,7 +227,7 @@ impl Issue {
let claims = Claims::<REFRESH>::new(uuid);
- Ok((claims.refresh(), claims.encode()?))
+ Ok((claims.refresh(), claims))
}
}
@@ -231,7 +236,6 @@ impl Issue {
pub struct Refresh;
impl Refresh {
- #[tracing::instrument(skip_all)]
pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse {
claims.refresh()
}
@@ -254,15 +258,25 @@ mod tests {
tests::{setup_test_env, TestResult},
};
+ #[test]
+ fn test_jwt_encode_decode() -> TestResult {
+ setup_test_env();
+
+ let claims = AccessClaims::new(uuid::Uuid::new_v4());
+ let token = JWT_ENV.encode(&claims)?;
+ let decoded = JWT_ENV.decode(&token)?.claims;
+ assert_eq!(claims, decoded);
+ Ok(())
+ }
+
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
+ async fn test_issue_ok(pool: PgPool) -> TestResult {
setup_test_env();
let state = AppState { pool };
let router = init_router(state.clone());
- let auth = Authorization::basic("adent@earth.sol", "hunter2");
- tracing::debug!(?auth, "Auth");
+ let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
let request = Request::builder()
.uri("/api/auth/issue")
@@ -270,23 +284,23 @@ mod tests {
.header(AUTHORIZATION, auth.0.encode())
.body(Body::empty())?;
- let response = router.oneshot(dbg!(request)).await?;
-
- tracing::error!(?response);
+ let response = router.oneshot(request).await?;
+ println!("{response:?}");
- assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+ assert_eq!(StatusCode::OK, response.status());
Ok(())
}
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_login_ok(pool: PgPool) -> TestResult {
+ async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
setup_test_env();
let state = AppState { pool };
let router = init_router(state.clone());
- let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
+ let auth = Authorization::basic("adent@earth.sol", "hunter2");
+ tracing::debug!(?auth, "Auth");
let request = Request::builder()
.uri("/api/auth/issue")
@@ -298,7 +312,7 @@ mod tests {
tracing::error!(?response);
- assert_eq!(StatusCode::OK, response.status());
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
Ok(())
}
diff --git a/src/routes/user.rs b/src/routes/user.rs
index 31cd5cb..d6dd0da 100644
--- a/src/routes/user.rs
+++ b/src/routes/user.rs
@@ -1,10 +1,10 @@
-use axum::{extract::State, response::IntoResponse, Extension, Json};
+use axum::{extract::State, response::IntoResponse, Json};
use axum_extra::routing::TypedPath;
use serde::Deserialize;
use crate::{model::UserSchema, state::AppState, Error};
-use super::jwt::Claims;
+use super::jwt::AccessClaims;
#[derive(Debug, Deserialize, TypedPath)]
#[typed_path("/api/user/:uuid")]
@@ -37,7 +37,7 @@ impl User {
pub async fn get(
self,
State(state): State<AppState>,
- Extension(Claims { sub, .. }): Extension<Claims>,
+ AccessClaims { sub, .. }: AccessClaims,
) -> Result<impl IntoResponse, Error> {
sqlx::query_as!(
UserSchema,
@@ -73,8 +73,6 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_user_uuid_ok(pool: PgPool) -> TestResult {
- std::env::set_var("JWT_SECRET", JWT_SECRET);
-
let state = AppState { pool };
let router = init_router(state.clone());
@@ -130,19 +128,23 @@ mod tests {
Ok(())
}
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ #[test_log::test(sqlx::test(fixtures(path = "../../fixtures", scripts("users"))))]
async fn test_user_ok(pool: PgPool) -> TestResult {
std::env::set_var("JWT_SECRET", JWT_SECRET);
let state = AppState { pool };
let router = init_router(state.clone());
+ let user = UserSchema {
+ uuid: UUID,
+ name: "Arthur Dent".to_string(),
+ email: "adent@earth.sol".to_string(),
+ ..Default::default()
+ };
+
let request = Request::builder()
.uri("/api/user")
- .header(
- COOKIE,
- HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?,
- )
+ .header(COOKIE, HeaderValue::try_from(AccessClaims::new(user.uuid))?)
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -162,7 +164,7 @@ mod tests {
}
#[sqlx::test]
- async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult {
+ async fn test_user_unauthorized_invalid_token_signature(pool: PgPool) -> TestResult {
std::env::set_var("JWT_SECRET", JWT_SECRET);
let state = AppState { pool };
@@ -184,7 +186,7 @@ mod tests {
}
#[sqlx::test]
- async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult {
+ async fn test_user_unauthorized_invalid_token_format(pool: PgPool) -> TestResult {
std::env::set_var("JWT_SECRET", JWT_SECRET);
let state = AppState { pool };
@@ -213,7 +215,7 @@ mod tests {
let response = router.oneshot(request).await?;
- assert_eq!(StatusCode::BAD_REQUEST, response.status());
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
Ok(())
}