summaryrefslogtreecommitdiffstats
path: root/src/routes/jwt.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/routes/jwt.rs')
-rw-r--r--src/routes/jwt.rs339
1 files changed, 265 insertions, 74 deletions
diff --git a/src/routes/jwt.rs b/src/routes/jwt.rs
index 6a229a3..ccce13e 100644
--- a/src/routes/jwt.rs
+++ b/src/routes/jwt.rs
@@ -1,114 +1,305 @@
-use std::sync::Arc;
-
+use argon2::{Argon2, PasswordHash, PasswordVerifier};
use axum::{
- extract::{Request, State},
- response::IntoResponse,
+ async_trait,
+ extract::{FromRequestParts, State},
+ http::{header::SET_COOKIE, request::Parts, HeaderValue},
+ response::{IntoResponse, IntoResponseParts},
+ RequestPartsExt,
};
use axum_extra::{
extract::{cookie::Cookie, CookieJar},
- headers::{authorization::Bearer, Authorization},
- routing::TypedPath,
+ headers::{
+ authorization::{Basic, Bearer},
+ Authorization,
+ },
+ routing::{RouterExt, TypedPath},
TypedHeader,
};
-use jsonwebtoken::{DecodingKey, Validation};
+use jsonwebtoken::{decode, DecodingKey, EncodingKey};
+use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use uuid::Uuid;
-use crate::{error::AuthError, state::AppState, Error};
+use crate::{error::AuthError, model::UserSchema, state::AppState, Error};
+
+pub fn init_router(state: AppState) -> axum::Router<AppState> {
+ axum::Router::new()
+ .typed_get(Issue::get)
+ .typed_get(Refresh::get)
+ .with_state(state)
+}
+
+static JWT_ENV: Lazy<JwtEnv> = Lazy::new(|| {
+ let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
+ JwtEnv::new(secret.as_bytes())
+});
+
+#[derive(Clone)]
+struct JwtEnv {
+ encoding: EncodingKey,
+ decoding: DecodingKey,
+ header: jsonwebtoken::Header,
+ validation: jsonwebtoken::Validation,
+}
+
+impl JwtEnv {
+ fn new(secret: &[u8]) -> Self {
+ Self {
+ encoding: EncodingKey::from_secret(secret),
+ decoding: DecodingKey::from_secret(secret),
+ header: Default::default(),
+ validation: Default::default(),
+ }
+ }
+}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct Claims {
+pub struct Claims<const LIFETIME: i64 = ACCESS> {
pub sub: Uuid,
pub iat: i64,
pub exp: i64,
pub jti: Uuid,
}
-impl Claims {
- const MAX_AGE: i64 = 3600;
+impl<const LIFETIME: i64> Claims<LIFETIME> {
+ pub fn new(uuid: Uuid) -> Self {
+ let now = OffsetDateTime::now_utc().unix_timestamp();
+ Self {
+ sub: uuid,
+ iat: now,
+ exp: now + LIFETIME,
+ jti: uuid::Uuid::new_v4(),
+ }
+ }
- pub fn new(sub: Uuid) -> Self {
- let iat = OffsetDateTime::now_utc().unix_timestamp();
- let exp = iat + Self::MAX_AGE;
- let jti = uuid::Uuid::new_v4();
- Self { sub, iat, exp, jti }
+ pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
+ jsonwebtoken::encode(&JWT_ENV.header, self, &JWT_ENV.encoding)
}
+}
- pub fn encode(&self, secret: &[u8]) -> Result<String, jsonwebtoken::errors::Error> {
- jsonwebtoken::encode(
- &jsonwebtoken::Header::default(),
- self,
- &jsonwebtoken::EncodingKey::from_secret(secret),
- )
+impl<const L: i64> TryFrom<Claims<L>> for Cookie<'_> {
+ type Error = Error;
+
+ fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ Ok(Cookie::build(("token", value.encode()?))
+ .expires(OffsetDateTime::from_unix_timestamp(value.exp)?)
+ .secure(true)
+ .http_only(true)
+ .build())
}
}
-impl From<Uuid> for Claims {
- fn from(value: Uuid) -> Self {
- Self::new(value)
+impl<const L: i64> TryFrom<Claims<L>> for HeaderValue {
+ type Error = Error;
+
+ fn try_from(value: Claims<L>) -> Result<Self, Self::Error> {
+ Cookie::try_from(value)?
+ .encoded()
+ .to_string()
+ .parse()
+ .map_err(Into::into)
}
}
-#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-struct Session {
- jti: Uuid,
- uuid: Uuid,
+// 1 day in seconds
+const ACCESS: i64 = 86400;
+
+pub type AccessClaims = Claims<ACCESS>;
+
+impl From<RefreshClaims> for AccessClaims {
+ fn from(value: RefreshClaims) -> Self {
+ Claims::new(value.sub)
+ }
+}
+
+impl IntoResponse for AccessClaims {
+ fn into_response(self) -> axum::response::Response {
+ (self, ()).into_response()
+ }
+}
+
+impl IntoResponseParts for AccessClaims {
+ type Error = Error;
+
+ fn into_response_parts(
+ self,
+ mut res: axum::response::ResponseParts,
+ ) -> Result<axum::response::ResponseParts, Self::Error> {
+ res.headers_mut()
+ .append(SET_COOKIE, HeaderValue::try_from(self)?);
+
+ Ok(res)
+ }
+}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for AccessClaims
+where
+ S: Send + Sync,
+{
+ type Rejection = AuthError;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let token = parts
+ .extract::<CookieJar>()
+ .await
+ .map_err(|_| AuthError::JwtNotFound)?
+ .get("token")
+ .ok_or(AuthError::JwtNotFound)?
+ .to_string();
+
+ decode(&token, &JWT_ENV.decoding, &JWT_ENV.validation)
+ .map(|d| d.claims)
+ .map_err(Into::into)
+ }
+}
+
+// 30 days in seconds
+const REFRESH: i64 = 2_592_000;
+
+pub type RefreshClaims = Claims<REFRESH>;
+
+impl RefreshClaims {
+ pub fn refresh(self) -> AccessClaims {
+ self.into()
+ }
+}
+
+//impl IntoResponse for RefreshClaims {
+// fn into_response(self) -> axum::response::Response {
+// (self.refresh(), self).into_response()
+// }
+//}
+
+#[async_trait]
+impl<S> FromRequestParts<S> for RefreshClaims
+where
+ S: Send + Sync,
+{
+ type Rejection = AuthError;
+
+ async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(bearer)) = parts
+ .extract::<TypedHeader<Authorization<Bearer>>>()
+ .await
+ .map_err(|_| AuthError::JwtNotFound)?;
+
+ decode(bearer.token(), &JWT_ENV.decoding, &JWT_ENV.validation)
+ .map(|d| d.claims)
+ .map_err(Into::into)
+ }
}
#[derive(Debug, Deserialize, TypedPath)]
-#[typed_path("/api/auth/refresh")]
-pub struct Refresh;
+#[typed_path("/issue")]
+pub struct Issue;
-impl Refresh {
- #[tracing::instrument]
- pub async fn post(
+impl Issue {
+ #[tracing::instrument(skip_all)]
+ pub async fn get(
self,
- State(state): State<Arc<AppState>>,
- TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
- cookie_jar: CookieJar,
+ State(state): State<AppState>,
+ TypedHeader(Authorization(basic)): TypedHeader<Authorization<Basic>>,
) -> Result<impl IntoResponse, Error> {
- let Claims { sub, .. } = jsonwebtoken::decode::<Claims>(
- bearer.token(),
- &DecodingKey::from_secret(state.jwt_secret.as_ref()),
- &Validation::default(),
- )?
- .claims;
-
- let claims = Claims::from(sub);
-
- let token = jsonwebtoken::encode(
- &jsonwebtoken::Header::default(),
- &claims,
- &jsonwebtoken::EncodingKey::from_secret(state.jwt_secret.as_ref()),
+ let UserSchema {
+ uuid,
+ password_hash,
+ ..
+ } = sqlx::query_as!(
+ UserSchema,
+ "SELECT * FROM users WHERE email = $1 LIMIT 1",
+ basic.username().to_ascii_lowercase()
+ )
+ .fetch_optional(&state.pool)
+ .await?
+ .ok_or(AuthError::LoginInvalid)?;
+
+ Argon2::default().verify_password(
+ basic.password().as_bytes(),
+ &PasswordHash::new(&password_hash)?,
)?;
- let cookie = Cookie::build(("token", token))
- .expires(OffsetDateTime::from_unix_timestamp(claims.exp)?)
- .secure(true)
- .http_only(true);
+ let claims = Claims::<REFRESH>::new(uuid);
- Ok(cookie_jar.add(cookie))
+ Ok((claims.refresh(), claims.encode()?))
}
}
-pub async fn authenticate(
- State(state): State<Arc<AppState>>,
- cookie_jar: CookieJar,
- mut req: Request,
-) -> Result<Request, AuthError> {
- let token = cookie_jar
- .get("token")
- .ok_or(AuthError::JwtNotFound)?
- .to_string();
-
- let claims = jsonwebtoken::decode::<Claims>(
- &token,
- &DecodingKey::from_secret(state.jwt_secret.as_ref()),
- &Validation::default(),
- )?
- .claims;
-
- req.extensions_mut().insert(claims);
- Ok(req)
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/refresh")]
+pub struct Refresh;
+
+impl Refresh {
+ #[tracing::instrument(skip_all)]
+ pub async fn get(self, claims: RefreshClaims) -> impl IntoResponse {
+ claims.refresh()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use axum::{
+ body::Body,
+ http::{header::AUTHORIZATION, Request, StatusCode},
+ };
+ use axum_extra::headers::authorization::Credentials;
+ use sqlx::PgPool;
+ use tower::ServiceExt;
+
+ use crate::{
+ init_router,
+ tests::{setup_test_env, TestResult},
+ };
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_issue_unauthorized(pool: PgPool) -> TestResult {
+ setup_test_env();
+
+ let state = AppState { pool };
+ let router = init_router(state.clone());
+
+ let auth = Authorization::basic("adent@earth.sol", "hunter2");
+ tracing::debug!(?auth, "Auth");
+
+ let request = Request::builder()
+ .uri("/api/auth/issue")
+ .method("GET")
+ .header(AUTHORIZATION, auth.0.encode())
+ .body(Body::empty())?;
+
+ let response = router.oneshot(dbg!(request)).await?;
+
+ tracing::error!(?response);
+
+ assert_eq!(StatusCode::UNAUTHORIZED, response.status());
+
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../../fixtures", scripts("users")))]
+ async fn test_login_ok(pool: PgPool) -> TestResult {
+ setup_test_env();
+
+ let state = AppState { pool };
+ let router = init_router(state.clone());
+
+ let auth = Authorization::basic("adent@earth.sol", "solongandthanksforallthefish");
+
+ let request = Request::builder()
+ .uri("/api/auth/issue")
+ .method("GET")
+ .header(AUTHORIZATION, auth.0.encode())
+ .body(Body::empty())?;
+
+ let response = router.oneshot(request).await?;
+
+ tracing::error!(?response);
+
+ assert_eq!(StatusCode::OK, response.status());
+
+ Ok(())
+ }
}