diff options
author | Toby Vincent <tobyv@tobyvin.dev> | 2024-03-26 21:04:02 -0500 |
---|---|---|
committer | Toby Vincent <tobyv@tobyvin.dev> | 2024-03-26 21:04:02 -0500 |
commit | ce961ca85ba96813ccdca9be1d18ee11e4e0d25c (patch) | |
tree | 4e368e35c275ea1be97473ae69aa052f4a0e98a2 /src | |
parent | fd1447999d9665866d65002b2c2317b8b150225f (diff) |
feat: add user database and registration
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 31 | ||||
-rw-r--r-- | src/error.rs | 75 | ||||
-rw-r--r-- | src/lib.rs | 2 | ||||
-rw-r--r-- | src/main.rs | 48 | ||||
-rw-r--r-- | src/model.rs | 55 | ||||
-rw-r--r-- | src/routes.rs | 167 | ||||
-rw-r--r-- | src/state.rs | 15 |
7 files changed, 244 insertions, 149 deletions
diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index dc132b8..0000000 --- a/src/config.rs +++ /dev/null @@ -1,31 +0,0 @@ -#[derive(Debug, Default, Clone)] -pub struct Config { - pub database_url: String, - pub jwt_secret: String, - pub jwt_expires_in: String, - pub jwt_maxage: i32, -} - -impl Config { - pub fn init() -> Config { - let mut config = Config::default(); - - if let Ok(database_url) = std::env::var("DATABASE_URL") { - config.database_url = database_url; - }; - - if let Ok(jwt_secret) = std::env::var("JWT_SECRET") { - config.jwt_secret = jwt_secret; - }; - - if let Ok(jwt_expires_in) = std::env::var("JWT_EXPIRED_IN") { - config.jwt_expires_in = jwt_expires_in; - }; - - if let Ok(jwt_maxage) = std::env::var("JWT_MAXAGE") { - config.jwt_maxage = jwt_maxage.parse::<i32>().unwrap(); - }; - - config - } -} diff --git a/src/error.rs b/src/error.rs index a5b48ff..54075da 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,46 +1,69 @@ +use axum::{http::StatusCode, Json}; +use serde_json::json; + pub type Result<T, E = Error> = std::result::Result<T, E>; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error(transparent)] + #[error("IO error: {0}")] IO(#[from] std::io::Error), - #[error(transparent)] - TaskJoin(#[from] tokio::task::JoinError), + #[error("Env variable error: {0}")] + Env(#[from] dotenvy::Error), - #[error(transparent)] + #[error("Axum error: {0}")] Axum(#[from] axum::Error), - #[error(transparent)] + #[error("Http error: {0}")] + Http(#[from] axum::http::Error), + + #[error("Json error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Database error: {0}")] Sqlx(#[from] sqlx::Error), - #[error(transparent)] + #[error("Migration error: {0}")] Migration(#[from] sqlx::migrate::MigrateError), - #[error("User not found: {0}")] - UserNotFound(uuid::Uuid), + #[error("Failed to hash password: {0}")] + PasswordHash(#[from] argon2::password_hash::Error), + + #[error("User not found")] + UserNotFound, + + #[error("User with that email already exists")] + EmailExists, + + #[error("Email is invalid")] + EmailInvalid, + + #[error("Password is invalid")] + PasswordInvalid, + + #[error("{0}")] + Other(String), +} + +impl From<&Error> for StatusCode { + fn from(value: &Error) -> Self { + match value { + Error::UserNotFound => StatusCode::NOT_FOUND, + Error::EmailExists => StatusCode::CONFLICT, + Error::EmailInvalid | Error::PasswordInvalid => StatusCode::UNPROCESSABLE_ENTITY, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } } impl axum::response::IntoResponse for Error { fn into_response(self) -> axum::response::Response { - use axum::{http::StatusCode, Json}; - use serde_json::json; - - match self { - Error::UserNotFound(uuid) => ( - StatusCode::BAD_REQUEST, - Json(json!({ - "status": "fail", - "message": uuid, - })), - ), - err => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": err.to_string(), - })), - ), - } + // TODO: implement [rfc7807](https://www.rfc-editor.org/rfc/rfc7807.html) + + Json(json!({ + "status": StatusCode::from(&self).to_string(), + "detail": self.to_string(), + })) .into_response() } } @@ -2,6 +2,6 @@ pub use error::{Error, Result}; pub use routes::router; pub mod error; +pub mod model; pub mod routes; pub mod state; -pub mod model; diff --git a/src/main.rs b/src/main.rs index f2b81ae..1edf738 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,12 @@ +use std::sync::Arc; + use tokio::net::TcpListener; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use unnamed_server::state::AppState; - -use crate::config::Config; - -mod config; +use unnamed_server::{state::AppState, Error}; #[tokio::main] #[tracing::instrument] -async fn main() -> Result<(), unnamed_server::Error> { +async fn main() -> Result<(), Error> { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -18,43 +16,15 @@ async fn main() -> Result<(), unnamed_server::Error> { .init(); let _ = dotenvy::dotenv(); + let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string()); + let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set"); - let config = Config::init(); - - let state = AppState::new(config.database_url).await?; - + let state = Arc::new(AppState::init(&database_url).await?); let app = unnamed_server::router(state); - let listener = TcpListener::bind("127.0.0.1:30000").await?; + let listener = TcpListener::bind(listen_addr).await?; tracing::info!("Server listening on http://{}", listener.local_addr()?); - axum::serve(listener, app) - .with_graceful_shutdown(shutdown_signal()) - .await - .map_err(From::from) -} - -async fn shutdown_signal() { - let ctrl_c = async { - tokio::signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } + axum::serve(listener, app).await.map_err(From::from) } diff --git a/src/model.rs b/src/model.rs index 9c1bfe6..5f6111e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,28 +1,16 @@ -use chrono::prelude::*; use serde::{Deserialize, Serialize}; +use time::OffsetDateTime; -#[allow(non_snake_case)] -#[derive(Debug, Deserialize, sqlx::FromRow, Serialize, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)] +#[serde(rename_all = "camelCase")] pub struct User { pub id: uuid::Uuid, pub name: String, pub email: String, + #[serde(default, skip_serializing)] pub password: String, - #[serde(rename = "createdAt")] - pub created_at: Option<DateTime<Utc>>, - #[serde(rename = "updatedAt")] - pub updated_at: Option<DateTime<Utc>>, -} - -impl User { - pub fn into_query_response(self) -> axum::Json<serde_json::Value> { - axum::Json(serde_json::json!({ - "status": "success", - "data": serde_json::json!({ - "user": self - }) - })) - } + pub created_at: Option<OffsetDateTime>, + pub updated_at: Option<OffsetDateTime>, } #[derive(Debug, Serialize, Deserialize)] @@ -32,15 +20,38 @@ pub struct TokenClaims { pub exp: usize, } -#[derive(Debug, Deserialize)] -pub struct RegisterUserSchema { +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RegisterSchema { pub name: String, pub email: String, + #[serde(default, skip_serializing)] pub password: String, } -#[derive(Debug, Deserialize)] -pub struct LoginUserSchema { +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginSchema { pub email: String, + #[serde(default, skip_serializing)] pub password: String, } + +macro_rules! impl_from_superset { + ($from:tt, $to:ty, $($field:tt)*) => { + impl From<$from> for $to { + fn from(value: $from) -> Self { + let $from { + $($field)*, + .. + } = value; + + Self { + $($field)*, + } + } + } + }; +} + +impl_from_superset!(User, RegisterSchema, name, email, password); +impl_from_superset!(User, LoginSchema, email, password); +impl_from_superset!(RegisterSchema, LoginSchema, email, password); diff --git a/src/routes.rs b/src/routes.rs index 0a81317..0bf34b2 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,22 +1,33 @@ use std::sync::Arc; +use argon2::{ + password_hash::{rand_core::OsRng, SaltString}, + Argon2, PasswordHasher, +}; use axum::{ extract::State, http::{StatusCode, Uri}, + response::IntoResponse, Json, }; use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; -use crate::{model::User, state::AppState, Error}; +use crate::{ + model::{RegisterSchema, User}, + state::AppState, + Error, +}; -pub fn router(state: AppState) -> axum::Router { +#[tracing::instrument] +pub fn router(state: Arc<AppState>) -> axum::Router { axum::Router::new() // .route("/api/user", get(get_user)) .typed_get(HealthCheck::get) - .typed_get(UserId::get) + .typed_get(UserUuid::get) + .typed_post(Register::post) .fallback(fallback) - .with_state(Arc::new(state)) + .with_state(state) } #[derive(Debug, Deserialize, TypedPath)] @@ -24,7 +35,8 @@ pub fn router(state: AppState) -> axum::Router { pub struct HealthCheck; impl HealthCheck { - pub async fn get(self) -> Json<serde_json::Value> { + #[tracing::instrument] + pub async fn get(self) -> impl IntoResponse { const MESSAGE: &str = "Unnamed server"; let json_response = serde_json::json!({ @@ -38,26 +50,62 @@ impl HealthCheck { #[derive(Debug, Deserialize, TypedPath)] #[typed_path("/api/user/:uuid")] -pub struct UserId { +pub struct UserUuid { pub uuid: uuid::Uuid, } -impl UserId { - /// Get a user via their `id` - #[tracing::instrument(ret, skip(data))] - pub async fn get( - self, - State(data): State<Arc<AppState>>, - ) -> Result<Json<serde_json::Value>, Error> { +impl UserUuid { + /// Get a user with a specific `uuid` + #[tracing::instrument] + pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse { sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid) - .fetch_optional(&data.db_pool) + .fetch_optional(&state.pool) .await? - .ok_or_else(|| Error::UserNotFound(self.uuid)) - .map(User::into_query_response) + .ok_or_else(|| Error::UserNotFound) + .map(Json) + } +} + +#[derive(Debug, Deserialize, TypedPath)] +#[typed_path("/api/user/register")] +pub struct Register; + +impl Register { + #[tracing::instrument(skip(register_schema))] + pub async fn post( + self, + State(state): State<Arc<AppState>>, + Json(register_schema): Json<RegisterSchema>, + ) -> impl IntoResponse { + let exists: Option<bool> = + sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)") + .bind(register_schema.email.to_ascii_lowercase()) + .fetch_one(&state.pool) + .await?; + + if exists.is_some_and(|b| b) { + return Err(Error::EmailExists); + } + + let salt = SaltString::generate(&mut OsRng); + let hashed_password = + Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?; + + let user = sqlx::query_as!( + User, + "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *", + register_schema.name, + register_schema.email.to_ascii_lowercase(), + hashed_password.to_string() + ) + .fetch_one(&state.pool) + .await?; + + Ok((StatusCode::CREATED, Json(user))) } } -pub async fn fallback(uri: Uri) -> (StatusCode, String) { +pub async fn fallback(uri: Uri) -> impl IntoResponse { (StatusCode::NOT_FOUND, format!("Route not found: {uri}")) } @@ -65,15 +113,88 @@ pub async fn fallback(uri: Uri) -> (StatusCode, String) { mod tests { use super::*; - use axum_test::TestServer; + use axum::{ + body::Body, + http::{header, Request, StatusCode}, + }; + use http_body_util::BodyExt; + use sqlx::PgPool; + use tower::ServiceExt; + + #[sqlx::test] + async fn test_fallback(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { pool }); + let router = router(state.clone()); + + let response = router + .oneshot( + Request::builder() + .uri("/does-not-exist") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(StatusCode::NOT_FOUND, response.status()); - #[tokio::test] - async fn test_fallback() -> Result<(), Box<dyn std::error::Error>> { - let server = TestServer::new(axum::Router::new().fallback(fallback))?; + Ok(()) + } + + #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))] + async fn test_user(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { pool }); + let router = router(state.clone()); + + let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1") + .fetch_one(&state.pool) + .await?; - let response = server.get("/fallback").await; + let response = router + .oneshot( + Request::builder() + .uri(format!("/api/user/{}", user.id)) + .body(Body::empty())?, + ) + .await + .unwrap(); + + assert_eq!(StatusCode::OK, response.status()); + + Ok(()) + } - assert_eq!(StatusCode::NOT_FOUND, response.status_code()); + #[sqlx::test] + async fn test_user_register(pool: PgPool) -> Result<(), Error> { + let state = Arc::new(AppState { pool }); + let router = router(state.clone()); + + let register_user = RegisterSchema { + name: "Ford Prefect".to_string(), + email: "fprefect@heartofgold.galaxy".to_string(), + password: "42".to_string(), + }; + + let response = router + .oneshot( + Request::builder() + .uri("/api/user/register") + .method("POST") + .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Body::from( + serde_json::to_vec(&serde_json::json!(register_user)).unwrap(), + ))?, + ) + .await + .unwrap(); + + assert_eq!(StatusCode::CREATED, response.status()); + + let body_bytes = response.into_body().collect().await?.to_bytes(); + let user: User = serde_json::from_slice(&body_bytes)?; + + assert_eq!(register_user.name, user.name); + assert_eq!(register_user.email, user.email); Ok(()) } diff --git a/src/state.rs b/src/state.rs index efe2192..614688b 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,18 +1,19 @@ use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; -use crate::Error; - +#[derive(Debug)] pub struct AppState { - pub db_pool: Pool<Postgres>, + pub pool: Pool<Postgres>, } impl AppState { - pub async fn new<S: AsRef<str>>(db_url: S) -> Result<Self, Error> { - let db_pool = PgPoolOptions::new() + pub async fn init(database_uri: &str) -> Result<Self, sqlx::Error> { + let pool = PgPoolOptions::new() .max_connections(10) - .connect(db_url.as_ref()) + .connect(database_uri) .await?; - Ok(Self { db_pool }) + sqlx::migrate!().run(&pool).await?; + + Ok(Self { pool }) } } |