summaryrefslogtreecommitdiffstats
path: root/src/routes
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-10 20:23:14 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-11 23:51:06 -0500
commit8c56000a3090e0843a1f218a00c3503767658e83 (patch)
treebbcbf4ba4d10468ed8a6e891035ffa4646b77a7c /src/routes
parenteb8a597d310d8948d0b5a02911dd2002f00cfb39 (diff)
wip: more work on jwt handling
Diffstat (limited to 'src/routes')
-rw-r--r--src/routes/jwt.rs339
-rw-r--r--src/routes/login.rs131
-rw-r--r--src/routes/register.rs27
-rw-r--r--src/routes/user.rs68
4 files changed, 308 insertions, 257 deletions
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs
index 6a229a3..ccce13e 100644
--- a/src/routes/jwt.rs
+++ b/src/routes/jwt.rs
@@ -1,114 +1,305 @@
-use std::sync::Arc;
-
+use argon2::{Argon2, PasswordHash, PasswordVerifier};
use axum::{
- extract::{Request, State},
- response::IntoResponse,
+ async_trait,
+ extract::{FromRequestParts, State},
+ http::{header::SET_COOKIE, request::Parts, HeaderValue},
+ response::{IntoResponse, IntoResponseParts},
+ RequestPartsExt,
};
use axum_extra::{
extract::{cookie::Cookie, CookieJar},
- headers::{authorization::Bearer, Authorization},
- routing::TypedPath,
+ headers::{
+ authorization::{Basic, Bearer},
+ Authorization,
+ },
+ routing::{RouterExt, TypedPath},
TypedHeader,
};
-use jsonwebtoken::{DecodingKey, Validation};
+use jsonwebtoken::{decode, DecodingKey, EncodingKey};
+use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use uuid::Uuid;
-use crate::{error::AuthError, state::AppState, Error};
+use crate::{error::AuthError, model::UserSchema, state::AppState, Error};
+
+pub fn init_router(state: AppState) -> axum::Router<AppState> {
+ axum::Router::new()
+ .typed_get(Issue::get)
+ .typed_get(Refresh::get)
+ .with_state(state)
+}
+
+static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| {
+ let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
+ JwtEnv::new(secret.as_bytes())
+});
+
+#[derive(Clone)]
+struct JwtEnv {
+ encoding: EncodingKey,
+ decoding: DecodingKey,
+ header: jsonwebtoken::Header,
+ validation: jsonwebtoken::Validation,
+}
+
+impl JwtEnv {
+ fn new(secret: &[u8]) -> Self {
+ Self {
+ encoding: EncodingKey::from_secret(secret),
+ decoding: DecodingKey::from_secret(secret),
+ header: Default::default(),
+ validation: Default::default(),
+ }
+ }
+}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct Claims {
+pub struct Claims<const LIFETIME: i64 = ACCESS> {
pub sub: Uuid,
pub iat: i64,
pub exp: i64,
pub jti: Uuid,
}
-impl Claims {
- const MAX_AGE: i64 = 3600;
+impl<const LIFETIME: i64> Claims<LIFETIME> {
+ pub fn new(uuid: Uuid) -> Self {
+ let now = OffsetDateTime::now_utc().unix_timestamp();
+ Self {
+ sub: uuid,
+ iat: now,
+ exp: now + LIFETIME,
+ jti: uuid::Uuid::new_v4(),
+ }
+ }
- pub fn new(sub: Uuid) -> Self {
- let iat = OffsetDateTime::now_utc().unix_timestamp();
- let exp = iat + Self::MAX_AGE;
- let jti = uuid::Uuid::new_v4();
- Self { sub, iat, exp, jti }
+ pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
+ jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding)
}
+}
- pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> {
- jsonwebtoken::encode(
- &jsonwebtoken::Header::default(),
- self,
- &jsonwebtoken::EncodingKey::from_secret(secret),
- )
+impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
+ type Error = Error;
+
+ fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ Ok(Cookie::build(("token", value.encode()?))
+ .expires(OffsetDateTime::from_unix_timestamp(value.exp)?)
+ .secure(true)
+ .http_only(true)
+ .build())
}
}
-impl From<Uuid> for Claims {
- fn from(value: Uuid) -> Self {
- Self::new(value)
+impl<const L: i64> TryFrom<Claims<L>> for HeaderValue {
+ type Error = Error;
+
+ fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ Cookie::try_from(value)?
+ .encoded()
+ .to_string()
+ .parse()
+ .map_err(Into::into)
}
}
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-struct Session {
- jti: Uuid,
- uuid: Uuid,
+// 1 day in seconds
+const ACCESS: i64 = 86400;
+
+pub type AccessClaims = Claims<ACCESS>;
+
+impl From<RefreshClaims> for AccessClaims {
+ fn from(value: RefreshClaims) -> Self {
+ Claims::new(value.sub)
+ }
+}
+
+impl IntoResponse for AccessClaims {
+ fn into_response(self) -> axum::response::Response {
+ (self, ()).into_response()
+ }
+}
+
+impl IntoResponseParts for AccessClaims {
+ type Error = Error;
+
+ fn into_response_parts(
+ self,
+ mut res: axum::response::ResponseParts,
+ ) -> Result<axum::response::ResponseParts, Self::Error> {
+ res.headers_mut()
+ .append(SET_COOKIE, HeaderValue::try_from(self)?);
+
+ Ok(res)
+ }
+}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for AccessClaims
+where
+ S: Send + Sync,
+{
+ type Rejection = AuthError;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let token = parts
+ .extract::<CookieJar>()
+ .await
+ .map_err(|_| AuthError::JwtNotFound)?
+ .get("token")
+ .ok_or(AuthError::JwtNotFound)?
+ .to_string();
+
+ decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation)
+ .map(|d| d.claims)
+ .map_err(Into::into)
+ }
+}
+
+// 30 days in seconds
+const REFRESH: i64 = 2_592_000;
+
+pub type RefreshClaims = Claims<REFRESH>;
+
+impl RefreshClaims {
+ pub fn refresh(self) -> AccessClaims {
+ self.into()
+ }
+}
+
+//impl IntoResponse for RefreshClaims {
+// fn into_response(self) -> axum::response::Response {
+// (self.refresh(), self).into_response()
+// }
+//}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for RefreshClaims
+where
+ S: Send + Sync,
+{
+ type Rejection = AuthError;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(bearer)) = parts
+ .extract::<TypedHeader<Authorization<Bearer>>>()
+ .await
+ .map_err(|_| AuthError::JwtNotFound)?;
+
+ decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation)
+ .map(|d| d.claims)
+ .map_err(Into::into)
+ }
}
#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/auth/refresh")]
-pub struct Refresh;
+#[typed_path("/issue")]
+pub struct Issue;
-impl Refresh {
- #[tracing::instrument]
- pub async fn post(
+impl Issue {
+ #[tracing::instrument(skip_all)]
+ pub async fn get(
self,
- State(state): State<Arc<AppState>>,
- TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
- cookie_jar: CookieJar,
+ State(state): State<AppState>,
+ TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>,
) -> Result<impl IntoResponse, Error> {
- let Claims { sub, .. } = jsonwebtoken::decode::<Claims>(
- bearer.token(),
- &DecodingKey::from_secret(state.jwt_secret.as_ref()),
- &Validation::default(),
- )?
- .claims;
-
- let claims = Claims::from(sub);
-
- let token = jsonwebtoken::encode(
- &jsonwebtoken::Header::default(),
- &claims,
- &jsonwebtoken::EncodingKey::from_secret(state.jwt_secret.as_ref()),
+ let UserSchema {
+ uuid,
+ password_hash,
+ ..
+ } = sqlx::query_as!(
+ UserSchema,
+ "SELECT * FROM users WHERE email = $1 LIMIT 1",
+ basic.username().to_ascii_lowercase()
+ )
+ .fetch_optional(&state.pool)
+ .await?
+ .ok_or(AuthError::LoginInvalid)?;
+
+ Argon2::default().verify_password(
+ basic.password().as_bytes(),
+ &PasswordHash::new(&password_hash)?,
)?;
- let cookie = Cookie::build(("token", token))
- .expires(OffsetDateTime::from_unix_timestamp(claims.exp)?)
- .secure(true)
- .http_only(true);
+ let claims = Claims::<REFRESH>::new(uuid);
- Ok(cookie_jar.add(cookie))
+ Ok((claims.refresh(), claims.encode()?))
}
}
-pub async fn authenticate(
- State(state): State<Arc<AppState>>,
- cookie_jar: CookieJar,
- mut req: Request,
-) -> Result<Request, AuthError> {
- let token = cookie_jar
- .get("token")
- .ok_or(AuthError::JwtNotFound)?
- .to_string();
-
- let claims = jsonwebtoken::decode::<Claims>(
- &token,
- &DecodingKey::from_secret(state.jwt_secret.as_ref()),
- &Validation::default(),
- )?
- .claims;
-
- req.extensions_mut().insert(claims);
- Ok(req)
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/refresh")]
+pub struct Refresh;
+
+impl Refresh {
+ #[tracing::instrument(skip_all)]
+ pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse {
+ claims.refresh()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use axum::{
+ body::Body,
+ http::{header::AUTHORIZATION, Request, StatusCode},
+ };
+ use axum_extra::headers::authorization::Credentials;
+ use sqlx::PgPool;
+ use tower::ServiceExt;
+
+ use crate::{
+ init_router,
+ tests::{setup_test_env, TestResult},
+ };
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
+ setup_test_env();
+
+ let state = AppState { pool };
+ let router = init_router(state.clone());
+
+ let auth = Authorization::basic("adent@earth.sol", "hunter2");
+ tracing::debug!(?auth, "Auth");
+
+ let request = Request::builder()
+ .uri("/api/auth/issue")
+ .method("GET")
+ .header(AUTHORIZATION, auth.0.encode())
+ .body(Body::empty())?;
+
+ let response = router.oneshot(dbg!(request)).await?;
+
+ tracing::error!(?response);
+
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_login_ok(pool: PgPool) -> TestResult {
+ setup_test_env();
+
+ let state = AppState { pool };
+ let router = init_router(state.clone());
+
+ let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
+
+ let request = Request::builder()
+ .uri("/api/auth/issue")
+ .method("GET")
+ .header(AUTHORIZATION, auth.0.encode())
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
+ tracing::error!(?response);
+
+ assert_eq!(StatusCode::OK, response.status());
+
+ Ok(())
+ }
}
diff --git a/src/routes/login.rs b/src/routes/login.rs
deleted file mode 100644
index 0e1e825..0000000
--- a/src/routes/login.rs
+++ /dev/null
@@ -1,131 +0,0 @@
-use std::sync::Arc;
-
-use argon2::{Argon2, PasswordHash, PasswordVerifier};
-use axum::{extract::State, response::IntoResponse, Json};
-use axum_extra::{headers::Authorization, routing::TypedPath, TypedHeader};
-use serde::Deserialize;
-
-use crate::{
- error::AuthError,
- model::{LoginSchema, UserSchema},
- state::AppState,
- Error,
-};
-
-use super::jwt::Claims;
-
-#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/login")]
-pub struct Login;
-
-impl Login {
- #[tracing::instrument(skip(state, password))]
- pub async fn post(
- self,
- State(state): State<Arc<AppState>>,
- Json(LoginSchema { email, password }): Json<LoginSchema>,
- ) -> Result<impl IntoResponse, Error> {
- let UserSchema {
- uuid,
- password_hash,
- ..
- } = sqlx::query_as!(
- UserSchema,
- "SELECT * FROM users WHERE email = $1 LIMIT 1",
- email.to_ascii_lowercase()
- )
- .fetch_optional(&state.pool)
- .await?
- .ok_or(AuthError::LoginInvalid)?;
-
- Argon2::default()
- .verify_password(password.as_bytes(), &PasswordHash::new(&password_hash)?)?;
-
- let token = Claims::from(uuid).encode(state.jwt_secret.as_ref())?;
-
- Authorization::bearer(&token)
- .map(TypedHeader)
- .map_err(Into::into)
- }
-}
-
-#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/logout")]
-pub struct Logout;
-
-impl Logout {
- #[tracing::instrument]
- pub async fn get(self) -> impl IntoResponse {
- todo!("Invalidate jwt somehow...");
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use axum::{
- body::Body,
- http::{header, Request, StatusCode},
- };
- use sqlx::PgPool;
- use tower::ServiceExt;
-
- use crate::init_router;
-
- const JWT_SECRET: &str = "test-jwt-secret-token";
-
- type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
-
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_login_unauthorized(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
-
- let user = LoginSchema {
- email: "adent@earth.sol".to_string(),
- password: "hunter2".to_string(),
- };
-
- let request = Request::builder()
- .uri("/api/login")
- .method("POST")
- .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
- .body(Body::from(serde_json::to_vec(&user)?))?;
-
- let response = router.oneshot(request).await?;
-
- assert_eq!(StatusCode::UNAUTHORIZED, response.status());
-
- Ok(())
- }
-
- #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
- async fn test_login_ok(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
-
- let user = LoginSchema {
- email: "adent@earth.sol".to_string(),
- password: "solongandthanksforallthefish".to_string(),
- };
-
- let request = Request::builder()
- .uri("/api/login")
- .method("POST")
- .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
- .body(Body::from(serde_json::to_vec(&user)?))?;
-
- let response = router.oneshot(request).await?;
-
- assert_eq!(StatusCode::OK, response.status());
-
- Ok(())
- }
-}
diff --git a/src/routes/register.rs b/src/routes/register.rs
index 286e70f..75819b0 100644
--- a/src/routes/register.rs
+++ b/src/routes/register.rs
@@ -1,4 +1,4 @@
-use std::{str::FromStr, sync::Arc};
+use std::str::FromStr;
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
@@ -25,7 +25,7 @@ impl Register {
#[tracing::instrument(skip(password))]
pub async fn post(
self,
- State(state): State<Arc<AppState>>,
+ State(state): State<AppState>,
Json(RegisterSchema {
name,
email,
@@ -73,18 +73,16 @@ mod tests {
use sqlx::PgPool;
use tower::ServiceExt;
- use crate::init_router;
-
- const JWT_SECRET: &str = "test-jwt-secret-token";
-
- type TestResult<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
+ use crate::{
+ init_router,
+ tests::{setup_test_env, TestResult},
+ };
#[sqlx::test]
async fn test_register_created(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ setup_test_env();
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = RegisterSchema {
@@ -114,10 +112,9 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_register_conflict(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ setup_test_env();
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = RegisterSchema {
diff --git a/src/routes/user.rs b/src/routes/user.rs
index 73eef04..31cd5cb 100644
--- a/src/routes/user.rs
+++ b/src/routes/user.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use axum::{extract::State, response::IntoResponse, Extension, Json};
use axum_extra::routing::TypedPath;
use serde::Deserialize;
@@ -17,7 +15,7 @@ pub struct UserUuid {
impl UserUuid {
/// Get a user with a specific `uuid`
#[tracing::instrument]
- pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse {
+ pub async fn get(self, State(state): State<AppState>) -> impl IntoResponse {
sqlx::query_as!(
UserSchema,
"SELECT * FROM users WHERE uuid = $1 LIMIT 1",
@@ -38,7 +36,7 @@ impl User {
#[tracing::instrument]
pub async fn get(
self,
- State(state): State<Arc<AppState>>,
+ State(state): State<AppState>,
Extension(Claims { sub, .. }): Extension<Claims>,
) -> Result<impl IntoResponse, Error> {
sqlx::query_as!(
@@ -59,14 +57,14 @@ mod tests {
use axum::{
body::Body,
- http::{header::AUTHORIZATION, Request, StatusCode},
+ http::{header::COOKIE, HeaderValue, Request, StatusCode},
};
use http_body_util::BodyExt;
use sqlx::PgPool;
use tower::ServiceExt;
- use crate::{init_router, model::UserSchema};
+ use crate::{init_router, model::UserSchema, routes::jwt::AccessClaims};
const JWT_SECRET: &str = "test-jwt-secret-token";
const UUID: uuid::Uuid = uuid::uuid!("4c14f795-86f0-4361-a02f-0edb966fb145");
@@ -75,10 +73,9 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_user_uuid_ok(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = UserSchema {
@@ -110,10 +107,9 @@ mod tests {
#[sqlx::test]
async fn test_user_uuid_not_found(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let user = UserSchema {
@@ -136,17 +132,17 @@ mod tests {
#[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
async fn test_user_ok(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
- let token = Claims::from(UUID).encode(JWT_SECRET.as_ref())?;
+ let state = AppState { pool };
+ let router = init_router(state.clone());
let request = Request::builder()
.uri("/api/user")
- .header(AUTHORIZATION, format!("Bearer {token}"))
+ .header(
+ COOKIE,
+ HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?,
+ )
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -167,17 +163,17 @@ mod tests {
#[sqlx::test]
async fn test_user_unauthorized_bad_token(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
- let router = init_router(state.clone());
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
- let token = Claims::from(UUID).encode("BAD_SECRET".as_ref())?;
+ let state = AppState { pool };
+ let router = init_router(state.clone());
let request = Request::builder()
.uri("/api/user")
- .header(AUTHORIZATION, format!("Bearer {token}"))
+ .header(
+ COOKIE,
+ HeaderValue::try_from(AccessClaims::new(uuid::Uuid::new_v4()))?,
+ )
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -189,15 +185,14 @@ mod tests {
#[sqlx::test]
async fn test_user_unauthorized_invalid_token(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let request = Request::builder()
.uri("/api/user")
- .header(AUTHORIZATION, "Bearer invalidtoken")
+ .header(COOKIE, "token=sadfasdfsdfs")
.body(Body::empty())?;
let response = router.oneshot(request).await?;
@@ -209,10 +204,9 @@ mod tests {
#[sqlx::test]
async fn test_user_unauthorized_missing_token(pool: PgPool) -> TestResult {
- let state = Arc::new(AppState {
- pool,
- jwt_secret: JWT_SECRET.to_string(),
- });
+ std::env::set_var("JWT_SECRET", JWT_SECRET);
+
+ let state = AppState { pool };
let router = init_router(state.clone());
let request = Request::builder().uri("/api/user").body(Body::empty())?;