summaryrefslogtreecommitdiffstats
path: root/src/routes.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/routes.rs')
-rw-r--r--src/routes.rs167
1 files changed, 144 insertions, 23 deletions
diff --git a/src/routes.rs b/src/routes.rs
index 0a81317..0bf34b2 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -1,22 +1,33 @@
use std::sync::Arc;
+use argon2::{
+ password_hash::{rand_core::OsRng, SaltString},
+ Argon2, PasswordHasher,
+};
use axum::{
extract::State,
http::{StatusCode, Uri},
+ response::IntoResponse,
Json,
};
use axum_extra::routing::{RouterExt, TypedPath};
use serde::Deserialize;
-use crate::{model::User, state::AppState, Error};
+use crate::{
+ model::{RegisterSchema, User},
+ state::AppState,
+ Error,
+};
-pub fn router(state: AppState) -> axum::Router {
+#[tracing::instrument]
+pub fn router(state: Arc<AppState>) -> axum::Router {
axum::Router::new()
// .route("/api/user", get(get_user))
.typed_get(HealthCheck::get)
- .typed_get(UserId::get)
+ .typed_get(UserUuid::get)
+ .typed_post(Register::post)
.fallback(fallback)
- .with_state(Arc::new(state))
+ .with_state(state)
}
#[derive(Debug, Deserialize, TypedPath)]
@@ -24,7 +35,8 @@ pub fn router(state: AppState) -> axum::Router {
pub struct HealthCheck;
impl HealthCheck {
- pub async fn get(self) -> Json<serde_json::Value> {
+ #[tracing::instrument]
+ pub async fn get(self) -> impl IntoResponse {
const MESSAGE: &str = "Unnamed server";
let json_response = serde_json::json!({
@@ -38,26 +50,62 @@ impl HealthCheck {
#[derive(Debug, Deserialize, TypedPath)]
#[typed_path("/api/user/:uuid")]
-pub struct UserId {
+pub struct UserUuid {
pub uuid: uuid::Uuid,
}
-impl UserId {
- /// Get a user via their `id`
- #[tracing::instrument(ret, skip(data))]
- pub async fn get(
- self,
- State(data): State<Arc<AppState>>,
- ) -> Result<Json<serde_json::Value>, Error> {
+impl UserUuid {
+ /// Get a user with a specific `uuid`
+ #[tracing::instrument]
+ pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse {
sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid)
- .fetch_optional(&data.db_pool)
+ .fetch_optional(&state.pool)
.await?
- .ok_or_else(|| Error::UserNotFound(self.uuid))
- .map(User::into_query_response)
+ .ok_or_else(|| Error::UserNotFound)
+ .map(Json)
+ }
+}
+
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/api/user/register")]
+pub struct Register;
+
+impl Register {
+ #[tracing::instrument(skip(register_schema))]
+ pub async fn post(
+ self,
+ State(state): State<Arc<AppState>>,
+ Json(register_schema): Json<RegisterSchema>,
+ ) -> impl IntoResponse {
+ let exists: Option<bool> =
+ sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)")
+ .bind(register_schema.email.to_ascii_lowercase())
+ .fetch_one(&state.pool)
+ .await?;
+
+ if exists.is_some_and(|b| b) {
+ return Err(Error::EmailExists);
+ }
+
+ let salt = SaltString::generate(&mut OsRng);
+ let hashed_password =
+ Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?;
+
+ let user = sqlx::query_as!(
+ User,
+ "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *",
+ register_schema.name,
+ register_schema.email.to_ascii_lowercase(),
+ hashed_password.to_string()
+ )
+ .fetch_one(&state.pool)
+ .await?;
+
+ Ok((StatusCode::CREATED, Json(user)))
}
}
-pub async fn fallback(uri: Uri) -> (StatusCode, String) {
+pub async fn fallback(uri: Uri) -> impl IntoResponse {
(StatusCode::NOT_FOUND, format!("Route not found: {uri}"))
}
@@ -65,15 +113,88 @@ pub async fn fallback(uri: Uri) -> (StatusCode, String) {
mod tests {
use super::*;
- use axum_test::TestServer;
+ use axum::{
+ body::Body,
+ http::{header, Request, StatusCode},
+ };
+ use http_body_util::BodyExt;
+ use sqlx::PgPool;
+ use tower::ServiceExt;
+
+ #[sqlx::test]
+ async fn test_fallback(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState { pool });
+ let router = router(state.clone());
+
+ let response = router
+ .oneshot(
+ Request::builder()
+ .uri("/does-not-exist")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(StatusCode::NOT_FOUND, response.status());
- #[tokio::test]
- async fn test_fallback() -> Result<(), Box<dyn std::error::Error>> {
- let server = TestServer::new(axum::Router::new().fallback(fallback))?;
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))]
+ async fn test_user(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState { pool });
+ let router = router(state.clone());
+
+ let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1")
+ .fetch_one(&state.pool)
+ .await?;
- let response = server.get("/fallback").await;
+ let response = router
+ .oneshot(
+ Request::builder()
+ .uri(format!("/api/user/{}", user.id))
+ .body(Body::empty())?,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(StatusCode::OK, response.status());
+
+ Ok(())
+ }
- assert_eq!(StatusCode::NOT_FOUND, response.status_code());
+ #[sqlx::test]
+ async fn test_user_register(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState { pool });
+ let router = router(state.clone());
+
+ let register_user = RegisterSchema {
+ name: "Ford Prefect".to_string(),
+ email: "fprefect@heartofgold.galaxy".to_string(),
+ password: "42".to_string(),
+ };
+
+ let response = router
+ .oneshot(
+ Request::builder()
+ .uri("/api/user/register")
+ .method("POST")
+ .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
+ .body(Body::from(
+ serde_json::to_vec(&serde_json::json!(register_user)).unwrap(),
+ ))?,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(StatusCode::CREATED, response.status());
+
+ let body_bytes = response.into_body().collect().await?.to_bytes();
+ let user: User = serde_json::from_slice(&body_bytes)?;
+
+ assert_eq!(register_user.name, user.name);
+ assert_eq!(register_user.email, user.email);
Ok(())
}