summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/routes/jwt.rs122
-rw-r--r--src/routes/user.rs28
2 files changed, 83 insertions, 67 deletions
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs
index ccce13e..902b494 100644
--- a/src/routes/jwt.rs
+++ b/src/routes/jwt.rs
@@ -15,9 +15,9 @@ use axum_extra::{
routing::{RouterExt, TypedPath},
TypedHeader,
};
-use jsonwebtoken::{decode, DecodingKey, EncodingKey};
+use jsonwebtoken::{DecodingKey, EncodingKey, TokenData};
use once_cell::sync::Lazy;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use time::OffsetDateTime;
use uuid::Uuid;
@@ -35,7 +35,6 @@ static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| {
JwtEnv::new(secret.as_bytes())
});
-#[derive(Clone)]
struct JwtEnv {
encoding: EncodingKey,
decoding: DecodingKey,
@@ -52,9 +51,23 @@ impl JwtEnv {
validation: Default::default(),
}
}
+
+ pub fn encode<T>(&self, claims: &T) -> Result<String, AuthError>
+ where
+ T: Serialize,
+ {
+ jsonwebtoken::encode(&self.header, claims, &self.encoding).map_err(Into::into)
+ }
+
+ pub fn decode<T>(&self, token: &str) -> Result<TokenData<T>, AuthError>
+ where
+ T: DeserializeOwned,
+ {
+ jsonwebtoken::decode(token, &self.decoding, &self.validation).map_err(Into::into)
+ }
}
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Claims<const LIFETIME: i64 = ACCESS> {
pub sub: Uuid,
pub iat: i64,
@@ -72,17 +85,24 @@ impl<const LIFETIME: i64> Claims<LIFETIME> {
jti: uuid::Uuid::new_v4(),
}
}
+}
- pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
- jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding)
+// 1 day in seconds
+const ACCESS: i64 = 86400;
+
+pub type AccessClaims = Claims<ACCESS>;
+
+impl From<RefreshClaims> for AccessClaims {
+ fn from(value: RefreshClaims) -> Self {
+ Claims::new(value.sub)
}
}
-impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
+impl TryFrom<AccessClaims> for Cookie<'_> {
type Error = Error;
- fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
- Ok(Cookie::build(("token", value.encode()?))
+ fn try_from(value: AccessClaims) -> Result<Self, Self::Error> {
+ Ok(Cookie::build(("token", JWT_ENV.encode(&value)?))
.expires(OffsetDateTime::from_unix_timestamp(value.exp)?)
.secure(true)
.http_only(true)
@@ -90,29 +110,17 @@ impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
}
}
-impl<const L: i64> TryFrom<Claims<L>> for HeaderValue {
+impl TryFrom<AccessClaims> for HeaderValue {
type Error = Error;
- fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ fn try_from(value: AccessClaims) -> Result<Self, Self::Error> {
Cookie::try_from(value)?
- .encoded()
.to_string()
.parse()
.map_err(Into::into)
}
}
-// 1 day in seconds
-const ACCESS: i64 = 86400;
-
-pub type AccessClaims = Claims<ACCESS>;
-
-impl From<RefreshClaims> for AccessClaims {
- fn from(value: RefreshClaims) -> Self {
- Claims::new(value.sub)
- }
-}
-
impl IntoResponse for AccessClaims {
fn into_response(self) -> axum::response::Response {
(self, ()).into_response()
@@ -140,18 +148,15 @@ where
{
type Rejection = AuthError;
- async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
- let token = parts
+ async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
+ let jar = parts
.extract::<CookieJar>()
.await
- .map_err(|_| AuthError::JwtNotFound)?
- .get("token")
- .ok_or(AuthError::JwtNotFound)?
- .to_string();
+ .expect("Infallable result was in fact, fallable");
- decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation)
- .map(|d| d.claims)
- .map_err(Into::into)
+ JWT_ENV
+ .decode(jar.get("token").ok_or(AuthError::JwtNotFound)?.value())
+ .map(|t| t.claims)
}
}
@@ -166,11 +171,14 @@ impl RefreshClaims {
}
}
-//impl IntoResponse for RefreshClaims {
-// fn into_response(self) -> axum::response::Response {
-// (self.refresh(), self).into_response()
-// }
-//}
+impl IntoResponse for RefreshClaims {
+ fn into_response(self) -> axum::response::Response {
+ match JWT_ENV.encode(&self) {
+ Ok(token) => token.into_response(),
+ Err(err) => Error::from(err).into_response(),
+ }
+ }
+}
#[async_trait]
impl<S> FromRequestParts<S> for RefreshClaims
@@ -179,15 +187,13 @@ where
{
type Rejection = AuthError;
- async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::JwtNotFound)?;
- decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation)
- .map(|d| d.claims)
- .map_err(Into::into)
+ Ok(JWT_ENV.decode(bearer.token())?.claims)
}
}
@@ -196,7 +202,6 @@ where
pub struct Issue;
impl Issue {
- #[tracing::instrument(skip_all)]
pub async fn get(
self,
State(state): State<AppState>,
@@ -222,7 +227,7 @@ impl Issue {
let claims = Claims::<REFRESH>::new(uuid);
- Ok((claims.refresh(), claims.encode()?))
+ Ok((claims.refresh(), claims))
}
}
@@ -231,7 +236,6 @@ impl Issue {
pub struct Refresh;
impl Refresh {
- #[tracing::instrument(skip_all)]
pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse {
claims.refresh()
}
@@ -254,15 +258,25 @@ mod tests {
tests::{setup_test_env, TestResult},
};
+ #[test]
+ fn test_jwt_encode_decode() -> TestResult {
+ setup_test_env();
+
+ let claims = AccessClaims::new(uuid::Uuid::new_v4());
+ let token = JWT_ENV.encode(&claims)?;
+ let decoded = JWT_ENV.decode(&token)?.claims;
+ assert_eq!(claims, decoded);
+ Ok(())
+ }
+
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
+ async fn test_issue_ok(pool: PgPool) -> TestResult {
setup_test_env();
let state = AppState { pool };
let router = init_router(state.clone());
- let auth = Authorization::basic("adent@earth.sol", "hunter2");
- tracing::debug!(?auth, "Auth");
+ let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
let request = Request::builder()
.uri("/api/auth/issue")
@@ -270,23 +284,23 @@ mod tests {
.header(AUTHORIZATION, auth.0.encode())
.body(Body::empty())?;
- let response = router.oneshot(dbg!(request)).await?;
-
- tracing::error!(?response);
+ let response = router.oneshot(request).await?;
+ println!("{response:?}");
- assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+ assert_eq!(StatusCode::OK, response.status());
Ok(())
}
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_login_ok(pool: PgPool) -> TestResult {
+ async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
setup_test_env();
let state = AppState { pool };
let router = init_router(state.clone());
- let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
+ let auth = Authorization::basic("adent@earth.sol", "hunter2");
+ tracing::debug!(?auth, "Auth");
let request = Request::builder()
.uri("/api/auth/issue")
@@ -298,7 +312,7 @@ mod tests {
tracing::error!(?response);
- assert_eq!(StatusCode::OK, response.status());
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
Ok(())
}
diff --git a/src/routes/user.rs b/src/routes/user.rs
index 31cd5cb..d6dd0da 100644
--- a/src/routes/user.rs
+++ b/src/routes/user.rs
@@ -1,10 +1,10 @@
-use axum::{extract::State, response::IntoResponse, Extension, Json};
+use axum::{extract::State, response::IntoResponse, Json};
use axum_extra::routing::TypedPath;
use serde::Deserialize;
use crate::{model::UserSchema, state::AppState, Error};
-use super::jwt::Claims;
+use super::jwt::AccessClaims;
#[derive(Debug, Deserialize, TypedPath)]
#[typed_path("/api/user/:uuid")]
@@ -37,7 +37,7 @@ impl User {
pub async fn get(
self,
State(state): State<AppState>,
- Extension(Claims { sub, .. }): Extension<Claims>,
+ AccessClaims { sub, .. }: AccessClaims,
) -> Result<impl IntoResponse, Error> {
sqlx::query_as!(
UserSchema,
@@ -73,8 +73,6 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_user_uuid_ok(pool: PgPool) -> TestResult {
- std::env::set_var("JWT_SECRET", JWT_SECRET);
-
let state = AppState { pool };
let router = init_router(state.clone());
@@ -130,19 +128,23 @@ mod tests {
Ok(())
}
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ #[test_log::test(sqlx::test(fixtures(path = "../../fixtures", scripts("users"))))]
async fn test_user_ok(pool: PgPool) -> TestResult {
std::env::set_var("JWT_SECRET", JWT_SECRET);
let state = AppState { pool };
let router = init_router(state.clone());
+ let user = UserSchema {
+ uuid: UUID,
+ name: "Arthur Dent".to_string(),
+ email: "adent@earth.sol".to_string(),
+ ..Default::default()
+ };
+
let request = Request::builder()
.uri("/api/user")
- .header(
- COOKIE,
- HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?,
- )
+ .header(COOKIE, HeaderValue::try_from(AccessClaims::new(user.uuid))?)
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -162,7 +164,7 @@ mod tests {
}
#[sqlx::test]
- async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult {
+ async fn test_user_unauthorized_invalid_token_signature(pool: PgPool) -> TestResult {
std::env::set_var("JWT_SECRET", JWT_SECRET);
let state = AppState { pool };
@@ -184,7 +186,7 @@ mod tests {
}
#[sqlx::test]
- async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult {
+ async fn test_user_unauthorized_invalid_token_format(pool: PgPool) -> TestResult {
std::env::set_var("JWT_SECRET", JWT_SECRET);
let state = AppState { pool };
@@ -213,7 +215,7 @@ mod tests {
let response = router.oneshot(request).await?;
- assert_eq!(StatusCode::BAD_REQUEST, response.status());
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
Ok(())
}