diff options
author | Toby Vincent <tobyv@tobyvin.dev> | 2024-04-02 17:23:05 -0500 |
---|---|---|
committer | Toby Vincent <tobyv@tobyvin.dev> | 2024-04-02 17:32:38 -0500 |
commit | d3a09372b5b945a609cce5e28c4d4233e3b134e8 (patch) | |
tree | 9f888aa393948e1c90d76d37349ffe9738b95f24 /src | |
parent | b263c6637ce8b7c83e4d01d1ef2e90e195a155fb (diff) |
feat: impl toml file and env config layering
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 129 | ||||
-rw-r--r-- | src/error.rs | 7 | ||||
-rw-r--r-- | src/main.rs | 50 |
3 files changed, 150 insertions, 36 deletions
diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..0b84a5e --- /dev/null +++ b/src/config.rs @@ -0,0 +1,129 @@ +use std::{net::SocketAddr, sync::Arc}; + +use axum::Router; +use serde::{Deserialize, Serialize}; +use time::Duration; +use tokio::net::TcpListener; +use unnamed_server::{state::AppState, Error}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + listen_addr: Option<SocketAddr>, + jwt_max_age: Option<String>, + jwt_secret: Option<String>, + database_url: Option<String>, +} + +impl Config { + pub fn new() -> Self { + Self::default() + } + + pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> { + match std::fs::read_to_string(file) { + Ok(s) => Ok(self.merge(toml::from_str(&s)?)), + Err(err) => { + tracing::warn!("Error reading config file: {err}"); + Ok(self) + } + } + } + + pub fn env(self, prefix: &str) -> Result<Self, Error> { + Ok(self.merge(Self { + listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR")) + .ok() + .and_then(|v| v.parse().ok()), + jwt_max_age: std::env::var(format!("{prefix}JWT_MAX_AGE")).ok(), + jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(), + database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(), + })) + } + + pub async fn build(self) -> Result<(TcpListener, Router), Error> { + macro_rules! try_extract { + ($($i:ident),+) => { + $(let Some($i) = self.$i else { + return Err(Error::Config(format!("Missing value: {}", stringify!($i)))) + };)+ + }; + } + + try_extract!(listen_addr, jwt_secret, jwt_max_age, database_url); + + let listener = TcpListener::bind(listen_addr).await?; + let pool = init_db(&database_url).await?; + let jwt_max_age = parse_duration(jwt_max_age)?; + let app_state = Arc::new(AppState { + pool, + jwt_secret, + jwt_max_age, + }); + let app = unnamed_server::init_router(app_state); + + Ok((listener, app)) + } + + /// Merge self with other, overwriting any existing values on self with other's. + fn merge(self, other: Self) -> Self { + Self { + listen_addr: other.listen_addr.or(self.listen_addr), + jwt_max_age: other.jwt_max_age.or(self.jwt_max_age), + jwt_secret: other.jwt_secret.or(self.jwt_secret), + database_url: other.database_url.or(self.database_url), + } + } +} + +impl Default for Config { + fn default() -> Self { + Self { + listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))), + jwt_max_age: Some("1h".to_string()), + jwt_secret: None, + database_url: None, + } + } +} + +async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(10) + .connect(uri) + .await?; + + sqlx::migrate!().run(&pool).await?; + + Ok(pool) +} + +fn parse_duration<S: AsRef<str>>(s: S) -> Result<Duration, Error> { + let chars = &mut s.as_ref().chars(); + let mut nums: i64 = 0; + let mut unit = String::new(); + + for c in chars.by_ref() { + if c.is_ascii_digit() { + nums = nums * 10 + c.to_digit(10).unwrap() as i64; + } else { + unit.push(c); + break; + } + } + + unit.extend(chars); + + if "weeks".contains(&unit) { + Ok(Duration::weeks(nums)) + } else if "days".contains(&unit) { + Ok(Duration::days(nums)) + } else if "hours".contains(&unit) { + Ok(Duration::hours(nums)) + } else if "minutes".contains(&unit) { + Ok(Duration::minutes(nums)) + } else if "seconds".contains(&unit) { + Ok(Duration::seconds(nums)) + } else { + Err(Error::Config("Invalid jwt_max_age".to_string())) + } +} diff --git a/src/error.rs b/src/error.rs index 2824e49..48360de 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,8 +8,11 @@ pub enum Error { #[error("IO error: {0}")] IO(#[from] std::io::Error), - #[error("Env variable error: {0}")] - Env(#[from] dotenvy::Error), + #[error("Config file error: {0}")] + Toml(#[from] toml::de::Error), + + #[error("Config error: {0}")] + Config(String), #[error("Axum error: {0}")] Axum(#[from] axum::Error), diff --git a/src/main.rs b/src/main.rs index a926916..353c7fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,14 @@ -use std::sync::Arc; - -use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; -use tokio::net::TcpListener; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use unnamed_server::{state::AppState, Error}; + +use crate::config::Config; + +mod config; #[tokio::main] #[tracing::instrument] -async fn main() -> Result<(), Error> { +async fn main() -> Result<(), main_error::MainError> { + let _ = dotenvy::dotenv(); + tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -16,36 +17,17 @@ async fn main() -> Result<(), Error> { .with(tracing_subscriber::fmt::layer()) .init(); - // TODO: Migrate all of these into a struct parsed from env, cli, and file. - let _ = dotenvy::dotenv(); - let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string()); - let jwt_max_age: time::Duration = time::Duration::HOUR; - // serde_json::from_str(&std::env::var("JWT_MAX_AGE").unwrap_or_else(|_| "1h".to_string()))?; - let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET is not set"); - let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set"); - - let pool = init_db(&database_url).await?; - let state = Arc::new(AppState { - pool, - jwt_secret, - jwt_max_age, - }); - let app = unnamed_server::init_router(state); - - let listener = TcpListener::bind(listen_addr).await?; + let config_file = std::env::args() + .nth(1) + .unwrap_or("/etc/unnamed_app.toml".to_string()); - tracing::info!("Server listening on http://{}", listener.local_addr()?); - - axum::serve(listener, app).await.map_err(From::from) -} - -async fn init_db(uri: &str) -> Result<Pool<Postgres>, sqlx::Error> { - let pool = PgPoolOptions::new() - .max_connections(10) - .connect(uri) + let (listener, router) = Config::new() + .file(config_file)? + .env("UNNAMED_")? + .build() .await?; - sqlx::migrate!().run(&pool).await?; + tracing::info!("Server listening on http://{}", listener.local_addr()?); - Ok(pool) + axum::serve(listener, router).await.map_err(From::from) } |