summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-17 13:01:21 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-17 13:01:21 -0500
commit7ca3d4df5cfb05b750a014c87d3f11ad32406316 (patch)
treee63ae261faa659aef4985e8b6dd2677db55431e8
parenta2860a1294b250402114fa016c4639881abc2172 (diff)
refactor: extract PgPool from AppState via AsRef
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml2
-rw-r--r--src/api/account.rs6
-rw-r--r--src/api/users.rs17
-rw-r--r--src/auth.rs13
-rw-r--r--src/auth/credentials.rs9
-rw-r--r--src/state.rs7
7 files changed, 30 insertions, 25 deletions
diff --git a/Cargo.lock b/Cargo.lock
index a51f5a6..c33e307 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -97,6 +97,7 @@ checksum = "1236b4b292f6c4d6dc34604bb5120d85c3fe1d1aa596bd5cc52ca054d13e7b9e"
dependencies = [
"async-trait",
"axum-core",
+ "axum-macros",
"bytes",
"futures-util",
"http",
diff --git a/Cargo.toml b/Cargo.toml
index 9ee0b70..37ef445 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
anyhow = "1.0.82"
argon2 = { version = "0.5.3", features = ["std"] }
-axum = "0.7.4"
+axum = { version = "0.7.4", features = ["macros"] }
axum-extra = { version = "0.9.3", features = ["typed-routing", "cookie", "typed-header"] }
dotenvy = "0.15.7"
email_address = "0.2.4"
diff --git a/src/api/account.rs b/src/api/account.rs
index 0087df7..598d172 100644
--- a/src/api/account.rs
+++ b/src/api/account.rs
@@ -11,6 +11,7 @@ use axum_extra::{
headers::{authorization::Basic, Authorization},
TypedHeader,
};
+use sqlx::PgPool;
use crate::{
auth::{AccessClaims, Account, RefreshClaims},
@@ -26,12 +27,12 @@ pub fn router() -> Router<AppState> {
}
pub async fn login(
- State(state): State<AppState>,
+ State(pool): State<PgPool>,
auth: Either<RefreshClaims, Login>,
) -> Result<(AccessClaims, RefreshClaims), crate::auth::error::Error> {
match auth {
Either::E1(token) => Ok((token.refresh(), token)),
- Either::E2(Login(account)) => crate::auth::issue(State(state.clone()), account).await,
+ Either::E2(Login(account)) => crate::auth::issue(State(pool), account).await,
}
}
@@ -80,7 +81,6 @@ mod tests {
use axum_extra::headers::{authorization::Credentials, Authorization};
use http_body_util::BodyExt;
- use sqlx::PgPool;
use tower::ServiceExt;
use uuid::Uuid;
diff --git a/src/api/users.rs b/src/api/users.rs
index 6ac0bb8..e73e229 100644
--- a/src/api/users.rs
+++ b/src/api/users.rs
@@ -7,7 +7,7 @@ use axum::{
};
use axum_extra::routing::Resource;
use serde::{Deserialize, Serialize};
-use sqlx::FromRow;
+use sqlx::PgPool;
use time::OffsetDateTime;
use uuid::Uuid;
@@ -22,7 +22,7 @@ pub fn router() -> Resource<AppState> {
Resource::named("users").create(create).show(show)
}
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, FromRow)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct User {
pub id: Uuid,
@@ -40,7 +40,7 @@ pub struct Registration {
}
pub async fn create(
- State(state): State<AppState>,
+ State(pool): State<PgPool>,
Json(Registration {
name,
email,
@@ -53,7 +53,7 @@ pub async fn create(
"SELECT EXISTS(SELECT 1 FROM user_ WHERE email = $1 LIMIT 1)",
email.to_ascii_lowercase()
)
- .fetch_one(&state.pool)
+ .fetch_one(&pool)
.await?;
if exists.is_some_and(|b| b) {
@@ -62,7 +62,7 @@ pub async fn create(
// TODO: Move this into a micro service, possibly behind a feature flag.
let (status, (access, refresh)) =
- crate::auth::credentials::create(State(state.clone()), Json(Credential { password }))
+ crate::auth::credentials::create(State(pool.clone()), Json(Credential { password }))
.await?;
let user = sqlx::query_as!(
@@ -72,7 +72,7 @@ pub async fn create(
name,
email.to_ascii_lowercase(),
)
- .fetch_one(&state.pool)
+ .fetch_one(&pool)
.await?;
Ok((status, access, refresh, Json(user)))
@@ -80,7 +80,7 @@ pub async fn create(
pub async fn show(
Path(uuid): Path<Uuid>,
- State(state): State<AppState>,
+ State(pool): State<PgPool>,
AccessClaims { sub, .. }: AccessClaims,
) -> Result<impl IntoResponse, Error> {
if uuid != sub {
@@ -88,7 +88,7 @@ pub async fn show(
}
sqlx::query_as!(User, "SELECT * FROM user_ WHERE id = $1 LIMIT 1", sub)
- .fetch_optional(&state.pool)
+ .fetch_optional(&pool)
.await?
.ok_or_else(|| Error::UserNotFound)
.map(Json)
@@ -108,7 +108,6 @@ mod tests {
};
use http_body_util::BodyExt;
- use sqlx::PgPool;
use tower::ServiceExt;
use crate::{
diff --git a/src/auth.rs b/src/auth.rs
index 6bf0ddf..909534e 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -9,6 +9,7 @@ use axum_extra::{
headers::{authorization::Basic, Authorization},
TypedHeader,
};
+use sqlx::PgPool;
use uuid::Uuid;
use crate::state::AppState;
@@ -32,11 +33,11 @@ pub fn router() -> axum::Router<AppState> {
}
pub async fn issue(
- State(state): State<AppState>,
+ State(pool): State<PgPool>,
Account { id, password }: Account,
) -> Result<(AccessClaims, RefreshClaims), Error> {
let p: String = sqlx::query_scalar!("SELECT password_hash FROM credential WHERE id = $1", id)
- .fetch_optional(&state.pool)
+ .fetch_optional(&pool)
.await?
.ok_or(Error::LoginInvalid)?;
@@ -58,10 +59,13 @@ pub struct Account {
}
#[async_trait]
-impl FromRequestParts<AppState> for Account {
+impl<S> FromRequestParts<S> for Account
+where
+ S: Send + Sync,
+{
type Rejection = Error;
- async fn from_request_parts(parts: &mut Parts, _: &AppState) -> Result<Self, Self::Rejection> {
+ async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(basic)) =
parts.extract::<TypedHeader<Authorization<Basic>>>().await?;
@@ -82,7 +86,6 @@ mod tests {
Router,
};
use axum_extra::headers::{authorization::Credentials, Authorization};
- use sqlx::PgPool;
use tower::ServiceExt;
use crate::tests::{setup_test_env, TestResult};
diff --git a/src/auth/credentials.rs b/src/auth/credentials.rs
index 88253b3..2ba3f29 100644
--- a/src/auth/credentials.rs
+++ b/src/auth/credentials.rs
@@ -9,6 +9,7 @@ use axum::{
};
use axum_extra::routing::Resource;
use serde::{Deserialize, Serialize};
+use sqlx::PgPool;
use uuid::Uuid;
use crate::state::AppState;
@@ -28,7 +29,7 @@ pub fn router() -> Resource<AppState> {
}
pub async fn create(
- State(state): State<AppState>,
+ State(pool): State<PgPool>,
Json(Credential { password }): Json<Credential>,
) -> Result<(StatusCode, (AccessClaims, RefreshClaims)), Error> {
let salt = SaltString::generate(&mut OsRng);
@@ -38,7 +39,7 @@ pub async fn create(
"INSERT INTO credential (password_hash) VALUES ($1) RETURNING id",
password_hash.to_string()
)
- .fetch_optional(&state.pool)
+ .fetch_optional(&pool)
.await?
.ok_or(Error::Registration)?
.id;
@@ -49,8 +50,8 @@ pub async fn create(
Ok((StatusCode::CREATED, (access, refresh)))
}
-pub async fn destroy(State(state): State<AppState>, Path(uuid): Path<Uuid>) -> Result<(), Error> {
- let mut tx = state.pool.begin().await?;
+pub async fn destroy(State(pool): State<PgPool>, Path(uuid): Path<Uuid>) -> Result<(), Error> {
+ let mut tx = pool.begin().await?;
let rows = sqlx::query!("DELETE FROM credential WHERE id = $1", uuid)
.execute(&mut *tx)
.await?
diff --git a/src/state.rs b/src/state.rs
index 75c1e11..771647d 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,12 +1,13 @@
use std::fmt::Debug;
-use sqlx::{Pool, Postgres};
+use axum::extract::FromRef;
+use sqlx::PgPool;
use crate::Error;
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, FromRef)]
pub struct AppState {
- pub pool: Pool<Postgres>,
+ pub pool: PgPool,
}
impl AppState {