diff options
Diffstat (limited to 'src/config.rs')
-rw-r--r-- | src/config.rs | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..0b84a5e --- /dev/null +++ b/src/config.rs @@ -0,0 +1,129 @@ +use std::{net::SocketAddr, sync::Arc}; + +use axum::Router; +use serde::{Deserialize, Serialize}; +use time::Duration; +use tokio::net::TcpListener; +use unnamed_server::{state::AppState, Error}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + listen_addr: Option<SocketAddr>, + jwt_max_age: Option<String>, + jwt_secret: Option<String>, + database_url: Option<String>, +} + +impl Config { + pub fn new() -> Self { + Self::default() + } + + pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> { + match std::fs::read_to_string(file) { + Ok(s) => Ok(self.merge(toml::from_str(&s)?)), + Err(err) => { + tracing::warn!("Error reading config file: {err}"); + Ok(self) + } + } + } + + pub fn env(self, prefix: &str) -> Result<Self, Error> { + Ok(self.merge(Self { + listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR")) + .ok() + .and_then(|v| v.parse().ok()), + jwt_max_age: std::env::var(format!("{prefix}JWT_MAX_AGE")).ok(), + jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(), + database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(), + })) + } + + pub async fn build(self) -> Result<(TcpListener, Router), Error> { + macro_rules! try_extract { + ($($i:ident),+) => { + $(let Some($i) = self.$i else { + return Err(Error::Config(format!("Missing value: {}", stringify!($i)))) + };)+ + }; + } + + try_extract!(listen_addr, jwt_secret, jwt_max_age, database_url); + + let listener = TcpListener::bind(listen_addr).await?; + let pool = init_db(&database_url).await?; + let jwt_max_age = parse_duration(jwt_max_age)?; + let app_state = Arc::new(AppState { + pool, + jwt_secret, + jwt_max_age, + }); + let app = unnamed_server::init_router(app_state); + + Ok((listener, app)) + } + + /// Merge self with other, overwriting any existing values on self with other's. + fn merge(self, other: Self) -> Self { + Self { + listen_addr: other.listen_addr.or(self.listen_addr), + jwt_max_age: other.jwt_max_age.or(self.jwt_max_age), + jwt_secret: other.jwt_secret.or(self.jwt_secret), + database_url: other.database_url.or(self.database_url), + } + } +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))), + jwt_max_age: Some("1h".to_string()), + jwt_secret: None, + database_url: None, + } + } +} + +async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(10) + .connect(uri) + .await?; + + sqlx::migrate!().run(&pool).await?; + + Ok(pool) +} + +fn parse_duration<S: AsRef<str>>(s: S) -> Result<Duration, Error> { + let chars = &mut s.as_ref().chars(); + let mut nums: i64 = 0; + let mut unit = String::new(); + + for c in chars.by_ref() { + if c.is_ascii_digit() { + nums = nums * 10 + c.to_digit(10).unwrap() as i64; + } else { + unit.push(c); + break; + } + } + + unit.extend(chars); + + if "weeks".contains(&unit) { + Ok(Duration::weeks(nums)) + } else if "days".contains(&unit) { + Ok(Duration::days(nums)) + } else if "hours".contains(&unit) { + Ok(Duration::hours(nums)) + } else if "minutes".contains(&unit) { + Ok(Duration::minutes(nums)) + } else if "seconds".contains(&unit) { + Ok(Duration::seconds(nums)) + } else { + Err(Error::Config("Invalid jwt_max_age".to_string())) + } +} |