diff options
-rw-r--r-- | Cargo.lock | 60 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | config.toml | 3 | ||||
-rw-r--r-- | src/config.rs | 129 | ||||
-rw-r--r-- | src/error.rs | 7 | ||||
-rw-r--r-- | src/main.rs | 50 |
6 files changed, 215 insertions, 36 deletions
@@ -869,6 +869,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] +name = "main_error" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "155db5e86c6e45ee456bf32fad5a290ee1f7151c2faca27ea27097568da67d1a" + +[[package]] name = "matchers" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1422,6 +1428,15 @@ dependencies = [ ] [[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + +[[package]] name = "serde_urlencoded" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1947,6 +1962,40 @@ dependencies = [ ] [[package]] +name = "toml" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] name = "tower" version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2086,6 +2135,7 @@ dependencies = [ "email_address", "http-body-util", "jsonwebtoken", + "main_error", "mime", "pgtemp", "serde", @@ -2094,6 +2144,7 @@ dependencies = [ "thiserror", "time", "tokio", + "toml", "tower", "tracing", "tracing-subscriber", @@ -2381,6 +2432,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] +name = "winnow" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +dependencies = [ + "memchr", +] + +[[package]] name = "zerocopy" version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -12,12 +12,14 @@ axum-extra = { version = "0.9.3", features = ["typed-routing", "cookie", "typed- dotenvy = "0.15.7" email_address = "0.2.4" jsonwebtoken = "9.3.0" +main_error = "0.1.2" serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio", "uuid", "time"] } thiserror = "1.0.58" time = { version = "0.3.34", features = ["serde", "serde-human-readable"] } tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread", "signal"] } +toml = "0.8.12" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } uuid = { version = "1.8.0", features = ["serde"] } diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..0a490d3 --- /dev/null +++ b/config.toml @@ -0,0 +1,3 @@ +database_url = "postgres://localhost/unnamed" +jwt_secret = "i-am-a-secret-token-and-i-am-proud" +jwt_max_age = "1h" diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..0b84a5e --- /dev/null +++ b/src/config.rs @@ -0,0 +1,129 @@ +use std::{net::SocketAddr, sync::Arc}; + +use axum::Router; +use serde::{Deserialize, Serialize}; +use time::Duration; +use tokio::net::TcpListener; +use unnamed_server::{state::AppState, Error}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + listen_addr: Option<SocketAddr>, + jwt_max_age: Option<String>, + jwt_secret: Option<String>, + database_url: Option<String>, +} + +impl Config { + pub fn new() -> Self { + Self::default() + } + + pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> { + match std::fs::read_to_string(file) { + Ok(s) => Ok(self.merge(toml::from_str(&s)?)), + Err(err) => { + tracing::warn!("Error reading config file: {err}"); + Ok(self) + } + } + } + + pub fn env(self, prefix: &str) -> Result<Self, Error> { + Ok(self.merge(Self { + listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR")) + .ok() + .and_then(|v| v.parse().ok()), + jwt_max_age: std::env::var(format!("{prefix}JWT_MAX_AGE")).ok(), + jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(), + database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(), + })) + } + + pub async fn build(self) -> Result<(TcpListener, Router), Error> { + macro_rules! try_extract { + ($($i:ident),+) => { + $(let Some($i) = self.$i else { + return Err(Error::Config(format!("Missing value: {}", stringify!($i)))) + };)+ + }; + } + + try_extract!(listen_addr, jwt_secret, jwt_max_age, database_url); + + let listener = TcpListener::bind(listen_addr).await?; + let pool = init_db(&database_url).await?; + let jwt_max_age = parse_duration(jwt_max_age)?; + let app_state = Arc::new(AppState { + pool, + jwt_secret, + jwt_max_age, + }); + let app = unnamed_server::init_router(app_state); + + Ok((listener, app)) + } + + /// Merge self with other, overwriting any existing values on self with other's. + fn merge(self, other: Self) -> Self { + Self { + listen_addr: other.listen_addr.or(self.listen_addr), + jwt_max_age: other.jwt_max_age.or(self.jwt_max_age), + jwt_secret: other.jwt_secret.or(self.jwt_secret), + database_url: other.database_url.or(self.database_url), + } + } +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))), + jwt_max_age: Some("1h".to_string()), + jwt_secret: None, + database_url: None, + } + } +} + +async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(10) + .connect(uri) + .await?; + + sqlx::migrate!().run(&pool).await?; + + Ok(pool) +} + +fn parse_duration<S: AsRef<str>>(s: S) -> Result<Duration, Error> { + let chars = &mut s.as_ref().chars(); + let mut nums: i64 = 0; + let mut unit = String::new(); + + for c in chars.by_ref() { + if c.is_ascii_digit() { + nums = nums * 10 + c.to_digit(10).unwrap() as i64; + } else { + unit.push(c); + break; + } + } + + unit.extend(chars); + + if "weeks".contains(&unit) { + Ok(Duration::weeks(nums)) + } else if "days".contains(&unit) { + Ok(Duration::days(nums)) + } else if "hours".contains(&unit) { + Ok(Duration::hours(nums)) + } else if "minutes".contains(&unit) { + Ok(Duration::minutes(nums)) + } else if "seconds".contains(&unit) { + Ok(Duration::seconds(nums)) + } else { + Err(Error::Config("Invalid jwt_max_age".to_string())) + } +} diff --git a/src/error.rs b/src/error.rs index 2824e49..48360de 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,8 +8,11 @@ pub enum Error { #[error("IO error: {0}")] IO(#[from] std::io::Error), - #[error("Env variable error: {0}")] - Env(#[from] dotenvy::Error), + #[error("Config file error: {0}")] + Toml(#[from] toml::de::Error), + + #[error("Config error: {0}")] + Config(String), #[error("Axum error: {0}")] Axum(#[from] axum::Error), diff --git a/src/main.rs b/src/main.rs index a926916..353c7fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,14 @@ -use std::sync::Arc; - -use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; -use tokio::net::TcpListener; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use unnamed_server::{state::AppState, Error}; + +use crate::config::Config; + +mod config; #[tokio::main] #[tracing::instrument] -async fn main() -> Result<(), Error> { +async fn main() -> Result<(), main_error::MainError> { + let _ = dotenvy::dotenv(); + tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -16,36 +17,17 @@ async fn main() -> Result<(), Error> { .with(tracing_subscriber::fmt::layer()) .init(); - // TODO: Migrate all of these into a struct parsed from env, cli, and file. - let _ = dotenvy::dotenv(); - let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string()); - let jwt_max_age: time::Duration = time::Duration::HOUR; - // serde_json::from_str(&std::env::var("JWT_MAX_AGE").unwrap_or_else(|_| "1h".to_string()))?; - let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET is not set"); - let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set"); - - let pool = init_db(&database_url).await?; - let state = Arc::new(AppState { - pool, - jwt_secret, - jwt_max_age, - }); - let app = unnamed_server::init_router(state); - - let listener = TcpListener::bind(listen_addr).await?; + let config_file = std::env::args() + .nth(1) + .unwrap_or("/etc/unnamed_app.toml".to_string()); - tracing::info!("Server listening on http://{}", listener.local_addr()?); - - axum::serve(listener, app).await.map_err(From::from) -} - -async fn init_db(uri: &str) -> Result<Pool<Postgres>, sqlx::Error> { - let pool = PgPoolOptions::new() - .max_connections(10) - .connect(uri) + let (listener, router) = Config::new() + .file(config_file)? + .env("UNNAMED_")? + .build() .await?; - sqlx::migrate!().run(&pool).await?; + tracing::info!("Server listening on http://{}", listener.local_addr()?); - Ok(pool) + axum::serve(listener, router).await.map_err(From::from) } |