summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/config.rs31
-rw-r--r--src/error.rs75
-rw-r--r--src/lib.rs2
-rw-r--r--src/main.rs48
-rw-r--r--src/model.rs55
-rw-r--r--src/routes.rs167
-rw-r--r--src/state.rs15
7 files changed, 244 insertions, 149 deletions
diff --git a/src/config.rs b/src/config.rs
deleted file mode 100644
index dc132b8..0000000
--- a/src/config.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-#[derive(Debug, Default, Clone)]
-pub struct Config {
- pub database_url: String,
- pub jwt_secret: String,
- pub jwt_expires_in: String,
- pub jwt_maxage: i32,
-}
-
-impl Config {
- pub fn init() -> Config {
- let mut config = Config::default();
-
- if let Ok(database_url) = std::env::var("DATABASE_URL") {
- config.database_url = database_url;
- };
-
- if let Ok(jwt_secret) = std::env::var("JWT_SECRET") {
- config.jwt_secret = jwt_secret;
- };
-
- if let Ok(jwt_expires_in) = std::env::var("JWT_EXPIRED_IN") {
- config.jwt_expires_in = jwt_expires_in;
- };
-
- if let Ok(jwt_maxage) = std::env::var("JWT_MAXAGE") {
- config.jwt_maxage = jwt_maxage.parse::<i32>().unwrap();
- };
-
- config
- }
-}
diff --git a/src/error.rs b/src/error.rs
index a5b48ff..54075da 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,46 +1,69 @@
+use axum::{http::StatusCode, Json};
+use serde_json::json;
+
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(thiserror::Error, Debug)]
pub enum Error {
- #[error(transparent)]
+ #[error("IO error: {0}")]
IO(#[from] std::io::Error),
- #[error(transparent)]
- TaskJoin(#[from] tokio::task::JoinError),
+ #[error("Env variable error: {0}")]
+ Env(#[from] dotenvy::Error),
- #[error(transparent)]
+ #[error("Axum error: {0}")]
Axum(#[from] axum::Error),
- #[error(transparent)]
+ #[error("Http error: {0}")]
+ Http(#[from] axum::http::Error),
+
+ #[error("Json error: {0}")]
+ Json(#[from] serde_json::Error),
+
+ #[error("Database error: {0}")]
Sqlx(#[from] sqlx::Error),
- #[error(transparent)]
+ #[error("Migration error: {0}")]
Migration(#[from] sqlx::migrate::MigrateError),
- #[error("User not found: {0}")]
- UserNotFound(uuid::Uuid),
+ #[error("Failed to hash password: {0}")]
+ PasswordHash(#[from] argon2::password_hash::Error),
+
+ #[error("User not found")]
+ UserNotFound,
+
+ #[error("User with that email already exists")]
+ EmailExists,
+
+ #[error("Email is invalid")]
+ EmailInvalid,
+
+ #[error("Password is invalid")]
+ PasswordInvalid,
+
+ #[error("{0}")]
+ Other(String),
+}
+
+impl From<&Error> for StatusCode {
+ fn from(value: &Error) -> Self {
+ match value {
+ Error::UserNotFound => StatusCode::NOT_FOUND,
+ Error::EmailExists => StatusCode::CONFLICT,
+ Error::EmailInvalid | Error::PasswordInvalid => StatusCode::UNPROCESSABLE_ENTITY,
+ _ => StatusCode::INTERNAL_SERVER_ERROR,
+ }
+ }
}
impl axum::response::IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
- use axum::{http::StatusCode, Json};
- use serde_json::json;
-
- match self {
- Error::UserNotFound(uuid) => (
- StatusCode::BAD_REQUEST,
- Json(json!({
- "status": "fail",
- "message": uuid,
- })),
- ),
- err => (
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(json!({
- "error": err.to_string(),
- })),
- ),
- }
+ // TODO: implement [rfc7807](https://www.rfc-editor.org/rfc/rfc7807.html)
+
+ Json(json!({
+ "status": StatusCode::from(&self).to_string(),
+ "detail": self.to_string(),
+ }))
.into_response()
}
}
diff --git a/src/lib.rs b/src/lib.rs
index ad634d0..231c1c1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,6 @@ pub use error::{Error, Result};
pub use routes::router;
pub mod error;
+pub mod model;
pub mod routes;
pub mod state;
-pub mod model;
diff --git a/src/main.rs b/src/main.rs
index f2b81ae..1edf738 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,14 +1,12 @@
+use std::sync::Arc;
+
use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
-use unnamed_server::state::AppState;
-
-use crate::config::Config;
-
-mod config;
+use unnamed_server::{state::AppState, Error};
#[tokio::main]
#[tracing::instrument]
-async fn main() -> Result<(), unnamed_server::Error> {
+async fn main() -> Result<(), Error> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
@@ -18,43 +16,15 @@ async fn main() -> Result<(), unnamed_server::Error> {
.init();
let _ = dotenvy::dotenv();
+ let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string());
+ let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set");
- let config = Config::init();
-
- let state = AppState::new(config.database_url).await?;
-
+ let state = Arc::new(AppState::init(&database_url).await?);
let app = unnamed_server::router(state);
- let listener = TcpListener::bind("127.0.0.1:30000").await?;
+ let listener = TcpListener::bind(listen_addr).await?;
tracing::info!("Server listening on http://{}", listener.local_addr()?);
- axum::serve(listener, app)
- .with_graceful_shutdown(shutdown_signal())
- .await
- .map_err(From::from)
-}
-
-async fn shutdown_signal() {
- let ctrl_c = async {
- tokio::signal::ctrl_c()
- .await
- .expect("failed to install Ctrl+C handler");
- };
-
- #[cfg(unix)]
- let terminate = async {
- tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
- .expect("failed to install signal handler")
- .recv()
- .await;
- };
-
- #[cfg(not(unix))]
- let terminate = std::future::pending::<()>();
-
- tokio::select! {
- _ = ctrl_c => {},
- _ = terminate => {},
- }
+ axum::serve(listener, app).await.map_err(From::from)
}
diff --git a/src/model.rs b/src/model.rs
index 9c1bfe6..5f6111e 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,28 +1,16 @@
-use chrono::prelude::*;
use serde::{Deserialize, Serialize};
+use time::OffsetDateTime;
-#[allow(non_snake_case)]
-#[derive(Debug, Deserialize, sqlx::FromRow, Serialize, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
+#[serde(rename_all = "camelCase")]
pub struct User {
pub id: uuid::Uuid,
pub name: String,
pub email: String,
+ #[serde(default, skip_serializing)]
pub password: String,
- #[serde(rename = "createdAt")]
- pub created_at: Option<DateTime<Utc>>,
- #[serde(rename = "updatedAt")]
- pub updated_at: Option<DateTime<Utc>>,
-}
-
-impl User {
- pub fn into_query_response(self) -> axum::Json<serde_json::Value> {
- axum::Json(serde_json::json!({
- "status": "success",
- "data": serde_json::json!({
- "user": self
- })
- }))
- }
+ pub created_at: Option<OffsetDateTime>,
+ pub updated_at: Option<OffsetDateTime>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -32,15 +20,38 @@ pub struct TokenClaims {
pub exp: usize,
}
-#[derive(Debug, Deserialize)]
-pub struct RegisterUserSchema {
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct RegisterSchema {
pub name: String,
pub email: String,
+ #[serde(default, skip_serializing)]
pub password: String,
}
-#[derive(Debug, Deserialize)]
-pub struct LoginUserSchema {
+#[derive(Debug, Serialize, Deserialize)]
+pub struct LoginSchema {
pub email: String,
+ #[serde(default, skip_serializing)]
pub password: String,
}
+
+macro_rules! impl_from_superset {
+ ($from:tt, $to:ty, $($field:tt)*) => {
+ impl From<$from> for $to {
+ fn from(value: $from) -> Self {
+ let $from {
+ $($field)*,
+ ..
+ } = value;
+
+ Self {
+ $($field)*,
+ }
+ }
+ }
+ };
+}
+
+impl_from_superset!(User, RegisterSchema, name, email, password);
+impl_from_superset!(User, LoginSchema, email, password);
+impl_from_superset!(RegisterSchema, LoginSchema, email, password);
diff --git a/src/routes.rs b/src/routes.rs
index 0a81317..0bf34b2 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -1,22 +1,33 @@
use std::sync::Arc;
+use argon2::{
+ password_hash::{rand_core::OsRng, SaltString},
+ Argon2, PasswordHasher,
+};
use axum::{
extract::State,
http::{StatusCode, Uri},
+ response::IntoResponse,
Json,
};
use axum_extra::routing::{RouterExt, TypedPath};
use serde::Deserialize;
-use crate::{model::User, state::AppState, Error};
+use crate::{
+ model::{RegisterSchema, User},
+ state::AppState,
+ Error,
+};
-pub fn router(state: AppState) -> axum::Router {
+#[tracing::instrument]
+pub fn router(state: Arc<AppState>) -> axum::Router {
axum::Router::new()
// .route("/api/user", get(get_user))
.typed_get(HealthCheck::get)
- .typed_get(UserId::get)
+ .typed_get(UserUuid::get)
+ .typed_post(Register::post)
.fallback(fallback)
- .with_state(Arc::new(state))
+ .with_state(state)
}
#[derive(Debug, Deserialize, TypedPath)]
@@ -24,7 +35,8 @@ pub fn router(state: AppState) -> axum::Router {
pub struct HealthCheck;
impl HealthCheck {
- pub async fn get(self) -> Json<serde_json::Value> {
+ #[tracing::instrument]
+ pub async fn get(self) -> impl IntoResponse {
const MESSAGE: &str = "Unnamed server";
let json_response = serde_json::json!({
@@ -38,26 +50,62 @@ impl HealthCheck {
#[derive(Debug, Deserialize, TypedPath)]
#[typed_path("/api/user/:uuid")]
-pub struct UserId {
+pub struct UserUuid {
pub uuid: uuid::Uuid,
}
-impl UserId {
- /// Get a user via their `id`
- #[tracing::instrument(ret, skip(data))]
- pub async fn get(
- self,
- State(data): State<Arc<AppState>>,
- ) -> Result<Json<serde_json::Value>, Error> {
+impl UserUuid {
+ /// Get a user with a specific `uuid`
+ #[tracing::instrument]
+ pub async fn get(self, State(state): State<Arc<AppState>>) -> impl IntoResponse {
sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", self.uuid)
- .fetch_optional(&data.db_pool)
+ .fetch_optional(&state.pool)
.await?
- .ok_or_else(|| Error::UserNotFound(self.uuid))
- .map(User::into_query_response)
+ .ok_or_else(|| Error::UserNotFound)
+ .map(Json)
+ }
+}
+
+#[derive(Debug, Deserialize, TypedPath)]
+#[typed_path("/api/user/register")]
+pub struct Register;
+
+impl Register {
+ #[tracing::instrument(skip(register_schema))]
+ pub async fn post(
+ self,
+ State(state): State<Arc<AppState>>,
+ Json(register_schema): Json<RegisterSchema>,
+ ) -> impl IntoResponse {
+ let exists: Option<bool> =
+ sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)")
+ .bind(register_schema.email.to_ascii_lowercase())
+ .fetch_one(&state.pool)
+ .await?;
+
+ if exists.is_some_and(|b| b) {
+ return Err(Error::EmailExists);
+ }
+
+ let salt = SaltString::generate(&mut OsRng);
+ let hashed_password =
+ Argon2::default().hash_password(register_schema.password.as_bytes(), &salt)?;
+
+ let user = sqlx::query_as!(
+ User,
+ "INSERT INTO users (name,email,password) VALUES ($1, $2, $3) RETURNING *",
+ register_schema.name,
+ register_schema.email.to_ascii_lowercase(),
+ hashed_password.to_string()
+ )
+ .fetch_one(&state.pool)
+ .await?;
+
+ Ok((StatusCode::CREATED, Json(user)))
}
}
-pub async fn fallback(uri: Uri) -> (StatusCode, String) {
+pub async fn fallback(uri: Uri) -> impl IntoResponse {
(StatusCode::NOT_FOUND, format!("Route not found: {uri}"))
}
@@ -65,15 +113,88 @@ pub async fn fallback(uri: Uri) -> (StatusCode, String) {
mod tests {
use super::*;
- use axum_test::TestServer;
+ use axum::{
+ body::Body,
+ http::{header, Request, StatusCode},
+ };
+ use http_body_util::BodyExt;
+ use sqlx::PgPool;
+ use tower::ServiceExt;
+
+ #[sqlx::test]
+ async fn test_fallback(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState { pool });
+ let router = router(state.clone());
+
+ let response = router
+ .oneshot(
+ Request::builder()
+ .uri("/does-not-exist")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(StatusCode::NOT_FOUND, response.status());
- #[tokio::test]
- async fn test_fallback() -> Result<(), Box<dyn std::error::Error>> {
- let server = TestServer::new(axum::Router::new().fallback(fallback))?;
+ Ok(())
+ }
+
+ #[sqlx::test(fixtures(path = "../fixtures", scripts("users")))]
+ async fn test_user(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState { pool });
+ let router = router(state.clone());
+
+ let user = sqlx::query_as!(User, "SELECT * FROM users LIMIT 1")
+ .fetch_one(&state.pool)
+ .await?;
- let response = server.get("/fallback").await;
+ let response = router
+ .oneshot(
+ Request::builder()
+ .uri(format!("/api/user/{}", user.id))
+ .body(Body::empty())?,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(StatusCode::OK, response.status());
+
+ Ok(())
+ }
- assert_eq!(StatusCode::NOT_FOUND, response.status_code());
+ #[sqlx::test]
+ async fn test_user_register(pool: PgPool) -> Result<(), Error> {
+ let state = Arc::new(AppState { pool });
+ let router = router(state.clone());
+
+ let register_user = RegisterSchema {
+ name: "Ford Prefect".to_string(),
+ email: "fprefect@heartofgold.galaxy".to_string(),
+ password: "42".to_string(),
+ };
+
+ let response = router
+ .oneshot(
+ Request::builder()
+ .uri("/api/user/register")
+ .method("POST")
+ .header(header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
+ .body(Body::from(
+ serde_json::to_vec(&serde_json::json!(register_user)).unwrap(),
+ ))?,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(StatusCode::CREATED, response.status());
+
+ let body_bytes = response.into_body().collect().await?.to_bytes();
+ let user: User = serde_json::from_slice(&body_bytes)?;
+
+ assert_eq!(register_user.name, user.name);
+ assert_eq!(register_user.email, user.email);
Ok(())
}
diff --git a/src/state.rs b/src/state.rs
index efe2192..614688b 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,18 +1,19 @@
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
-use crate::Error;
-
+#[derive(Debug)]
pub struct AppState {
- pub db_pool: Pool<Postgres>,
+ pub pool: Pool<Postgres>,
}
impl AppState {
- pub async fn new<S: AsRef<str>>(db_url: S) -> Result<Self, Error> {
- let db_pool = PgPoolOptions::new()
+ pub async fn init(database_uri: &str) -> Result<Self, sqlx::Error> {
+ let pool = PgPoolOptions::new()
.max_connections(10)
- .connect(db_url.as_ref())
+ .connect(database_uri)
.await?;
- Ok(Self { db_pool })
+ sqlx::migrate!().run(&pool).await?;
+
+ Ok(Self { pool })
}
}