use axum::{ async_trait, extract::FromRequestParts, http::{ header::{AUTHORIZATION, SET_COOKIE}, request::Parts, HeaderValue, }, response::{IntoResponse, IntoResponseParts}, RequestPartsExt, }; use axum_extra::{ extract::{cookie::Cookie, CookieJar}, headers::{authorization::Bearer, Authorization}, TypedHeader, }; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use time::{Duration, OffsetDateTime}; use uuid::Uuid; use super::{Error, JWT}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(remote = "Self")] pub struct Claims { pub sub: Uuid, #[serde(with = "numeric_date")] iat: OffsetDateTime, #[serde(with = "numeric_date")] exp: OffsetDateTime, jti: Uuid, } impl Claims { pub fn new(sub: Uuid, iat: OffsetDateTime) -> Self { let iat = iat .replace_millisecond(0) .expect("Failed to remove millisecond from datetime. This should not have happened."); let exp = iat + Duration::new(LIFETIME, 0); let jti = uuid::Uuid::new_v4(); Self { sub, iat, exp, jti } } pub fn issue(uuid: Uuid) -> Self { Self::new(uuid, OffsetDateTime::now_utc()) } pub fn expired(&self) -> bool { self.exp > OffsetDateTime::now_utc() } } impl Serialize for Claims { fn serialize(&self, serializer: S) -> Result where S: Serializer, { Self::serialize(self, serializer) } } impl<'de, const LIFETIME: i64> Deserialize<'de> for Claims { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let claims = Self::deserialize(deserializer)?; if claims.exp - claims.iat != Duration::new(LIFETIME, 0) { return Err(serde::de::Error::custom( "Lifetime is invalid for Claim type", )); } Ok(claims) } } mod numeric_date { //! Custom serialization of OffsetDateTime to conform with the JWT spec (RFC 7519 section 2, "Numeric Date") use serde::{self, Deserialize, Deserializer, Serializer}; use time::OffsetDateTime; /// Serializes an OffsetDateTime to a Unix timestamp (milliseconds since 1970/1/1T00:00:00T) pub fn serialize(date: &OffsetDateTime, serializer: S) -> Result where S: Serializer, { serializer.serialize_i64(date.unix_timestamp()) } /// Attempts to deserialize an i64 and use as a Unix timestamp pub fn deserialize<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { OffsetDateTime::from_unix_timestamp(i64::deserialize(deserializer)?) .map_err(|_| serde::de::Error::custom("invalid Unix timestamp value")) } } // 1 day in seconds const ACCESS: i64 = 86400; pub type AccessClaims = Claims; impl From for AccessClaims { fn from(value: RefreshClaims) -> Self { Claims::issue(value.sub) } } impl TryFrom for Cookie<'_> { type Error = Error; fn try_from(value: AccessClaims) -> Result { Ok(Cookie::build(("token", JWT.encode(&value)?)) .expires(value.exp) .secure(true) .http_only(true) .build()) } } impl TryFrom for HeaderValue { type Error = Error; fn try_from(value: AccessClaims) -> Result { Cookie::try_from(value)? .to_string() .parse() .map_err(Into::into) } } impl IntoResponse for AccessClaims { fn into_response(self) -> axum::response::Response { (self, ()).into_response() } } impl IntoResponseParts for AccessClaims { type Error = Error; fn into_response_parts( self, mut res: axum::response::ResponseParts, ) -> Result { res.headers_mut() .append(SET_COOKIE, HeaderValue::try_from(self)?); Ok(res) } } #[async_trait] impl FromRequestParts for AccessClaims where S: Send + Sync, { type Rejection = Error; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { let jar = parts .extract::() .await .expect("Infallable result was in fact, fallable"); JWT.decode(jar.get("token").ok_or(Error::JwtNotFound)?.value()) .map(|t| t.claims) } } // 30 days in seconds const REFRESH: i64 = 2_592_000; pub type RefreshClaims = Claims; impl RefreshClaims { pub fn refresh(self) -> AccessClaims { self.into() } } impl IntoResponseParts for RefreshClaims { type Error = Error; fn into_response_parts( self, mut res: axum::response::ResponseParts, ) -> Result { res.headers_mut().append( AUTHORIZATION, HeaderValue::try_from(format!("Bearer {}", JWT.encode(&self)?))?, ); Ok(res) } } impl IntoResponse for RefreshClaims { fn into_response(self) -> axum::response::Response { JWT.encode(&self).into_response() } } #[async_trait] impl FromRequestParts for RefreshClaims where S: Send + Sync, { type Rejection = Error; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| Error::JwtNotFound)?; JWT.decode(bearer.token()).map(|jwt| jwt.claims) } }