summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-04-02 17:23:05 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-04-02 17:32:38 -0500
commitd3a09372b5b945a609cce5e28c4d4233e3b134e8 (patch)
tree9f888aa393948e1c90d76d37349ffe9738b95f24
parentb263c6637ce8b7c83e4d01d1ef2e90e195a155fb (diff)
feat: impl toml file and env config layering
-rw-r--r--Cargo.lock60
-rw-r--r--Cargo.toml2
-rw-r--r--config.toml3
-rw-r--r--src/config.rs129
-rw-r--r--src/error.rs7
-rw-r--r--src/main.rs50
6 files changed, 215 insertions, 36 deletions
diff --git a/Cargo.lock b/Cargo.lock
index d6f7678..16fb319 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -869,6 +869,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
+name = "main_error"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "155db5e86c6e45ee456bf32fad5a290ee1f7151c2faca27ea27097568da67d1a"
+
+[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1422,6 +1428,15 @@ dependencies = [
]
[[package]]
+name = "serde_spanned"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1"
+dependencies = [
+ "serde",
+]
+
+[[package]]
name = "serde_urlencoded"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1947,6 +1962,40 @@ dependencies = [
]
[[package]]
+name = "toml"
+version = "0.8.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3"
+dependencies = [
+ "serde",
+ "serde_spanned",
+ "toml_datetime",
+ "toml_edit",
+]
+
+[[package]]
+name = "toml_datetime"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "toml_edit"
+version = "0.22.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4"
+dependencies = [
+ "indexmap",
+ "serde",
+ "serde_spanned",
+ "toml_datetime",
+ "winnow",
+]
+
+[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2086,6 +2135,7 @@ dependencies = [
"email_address",
"http-body-util",
"jsonwebtoken",
+ "main_error",
"mime",
"pgtemp",
"serde",
@@ -2094,6 +2144,7 @@ dependencies = [
"thiserror",
"time",
"tokio",
+ "toml",
"tower",
"tracing",
"tracing-subscriber",
@@ -2381,6 +2432,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
[[package]]
+name = "winnow"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
name = "zerocopy"
version = "0.7.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 1ae7a29..430f228 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,12 +12,14 @@ axum-extra = { version = "0.9.3", features = ["typed-routing", "cookie", "typed-
dotenvy = "0.15.7"
email_address = "0.2.4"
jsonwebtoken = "9.3.0"
+main_error = "0.1.2"
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio", "uuid", "time"] }
thiserror = "1.0.58"
time = { version = "0.3.34", features = ["serde", "serde-human-readable"] }
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread", "signal"] }
+toml = "0.8.12"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.8.0", features = ["serde"] }
diff --git a/config.toml b/config.toml
new file mode 100644
index 0000000..0a490d3
--- /dev/null
+++ b/config.toml
@@ -0,0 +1,3 @@
+database_url = "postgres://localhost/unnamed"
+jwt_secret = "i-am-a-secret-token-and-i-am-proud"
+jwt_max_age = "1h"
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..0b84a5e
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,129 @@
+use std::{net::SocketAddr, sync::Arc};
+
+use axum::Router;
+use serde::{Deserialize, Serialize};
+use time::Duration;
+use tokio::net::TcpListener;
+use unnamed_server::{state::AppState, Error};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Config {
+ listen_addr: Option<SocketAddr>,
+ jwt_max_age: Option<String>,
+ jwt_secret: Option<String>,
+ database_url: Option<String>,
+}
+
+impl Config {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> {
+ match std::fs::read_to_string(file) {
+ Ok(s) => Ok(self.merge(toml::from_str(&s)?)),
+ Err(err) => {
+ tracing::warn!("Error reading config file: {err}");
+ Ok(self)
+ }
+ }
+ }
+
+ pub fn env(self, prefix: &str) -> Result<Self, Error> {
+ Ok(self.merge(Self {
+ listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR"))
+ .ok()
+ .and_then(|v| v.parse().ok()),
+ jwt_max_age: std::env::var(format!("{prefix}JWT_MAX_AGE")).ok(),
+ jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(),
+ database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(),
+ }))
+ }
+
+ pub async fn build(self) -> Result<(TcpListener, Router), Error> {
+ macro_rules! try_extract {
+ ($($i:ident),+) => {
+ $(let Some($i) = self.$i else {
+ return Err(Error::Config(format!("Missing value: {}", stringify!($i))))
+ };)+
+ };
+ }
+
+ try_extract!(listen_addr, jwt_secret, jwt_max_age, database_url);
+
+ let listener = TcpListener::bind(listen_addr).await?;
+ let pool = init_db(&database_url).await?;
+ let jwt_max_age = parse_duration(jwt_max_age)?;
+ let app_state = Arc::new(AppState {
+ pool,
+ jwt_secret,
+ jwt_max_age,
+ });
+ let app = unnamed_server::init_router(app_state);
+
+ Ok((listener, app))
+ }
+
+ /// Merge self with other, overwriting any existing values on self with other's.
+ fn merge(self, other: Self) -> Self {
+ Self {
+ listen_addr: other.listen_addr.or(self.listen_addr),
+ jwt_max_age: other.jwt_max_age.or(self.jwt_max_age),
+ jwt_secret: other.jwt_secret.or(self.jwt_secret),
+ database_url: other.database_url.or(self.database_url),
+ }
+ }
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))),
+ jwt_max_age: Some("1h".to_string()),
+ jwt_secret: None,
+ database_url: None,
+ }
+ }
+}
+
+async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> {
+ let pool = sqlx::postgres::PgPoolOptions::new()
+ .max_connections(10)
+ .connect(uri)
+ .await?;
+
+ sqlx::migrate!().run(&pool).await?;
+
+ Ok(pool)
+}
+
+fn parse_duration<S: AsRef<str>>(s: S) -> Result<Duration, Error> {
+ let chars = &mut s.as_ref().chars();
+ let mut nums: i64 = 0;
+ let mut unit = String::new();
+
+ for c in chars.by_ref() {
+ if c.is_ascii_digit() {
+ nums = nums * 10 + c.to_digit(10).unwrap() as i64;
+ } else {
+ unit.push(c);
+ break;
+ }
+ }
+
+ unit.extend(chars);
+
+ if "weeks".contains(&unit) {
+ Ok(Duration::weeks(nums))
+ } else if "days".contains(&unit) {
+ Ok(Duration::days(nums))
+ } else if "hours".contains(&unit) {
+ Ok(Duration::hours(nums))
+ } else if "minutes".contains(&unit) {
+ Ok(Duration::minutes(nums))
+ } else if "seconds".contains(&unit) {
+ Ok(Duration::seconds(nums))
+ } else {
+ Err(Error::Config("Invalid jwt_max_age".to_string()))
+ }
+}
diff --git a/src/error.rs b/src/error.rs
index 2824e49..48360de 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -8,8 +8,11 @@ pub enum Error {
#[error("IO error: {0}")]
IO(#[from] std::io::Error),
- #[error("Env variable error: {0}")]
- Env(#[from] dotenvy::Error),
+ #[error("Config file error: {0}")]
+ Toml(#[from] toml::de::Error),
+
+ #[error("Config error: {0}")]
+ Config(String),
#[error("Axum error: {0}")]
Axum(#[from] axum::Error),
diff --git a/src/main.rs b/src/main.rs
index a926916..353c7fc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,13 +1,14 @@
-use std::sync::Arc;
-
-use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
-use tokio::net::TcpListener;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
-use unnamed_server::{state::AppState, Error};
+
+use crate::config::Config;
+
+mod config;
#[tokio::main]
#[tracing::instrument]
-async fn main() -> Result<(), Error> {
+async fn main() -> Result<(), main_error::MainError> {
+ let _ = dotenvy::dotenv();
+
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
@@ -16,36 +17,17 @@ async fn main() -> Result<(), Error> {
.with(tracing_subscriber::fmt::layer())
.init();
- // TODO: Migrate all of these into a struct parsed from env, cli, and file.
- let _ = dotenvy::dotenv();
- let listen_addr = std::env::var("ADDRESS").unwrap_or("127.0.0.1:30000".to_string());
- let jwt_max_age: time::Duration = time::Duration::HOUR;
- // serde_json::from_str(&std::env::var("JWT_MAX_AGE").unwrap_or_else(|_| "1h".to_string()))?;
- let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET is not set");
- let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL is not set");
-
- let pool = init_db(&database_url).await?;
- let state = Arc::new(AppState {
- pool,
- jwt_secret,
- jwt_max_age,
- });
- let app = unnamed_server::init_router(state);
-
- let listener = TcpListener::bind(listen_addr).await?;
+ let config_file = std::env::args()
+ .nth(1)
+ .unwrap_or("/etc/unnamed_app.toml".to_string());
- tracing::info!("Server listening on http://{}", listener.local_addr()?);
-
- axum::serve(listener, app).await.map_err(From::from)
-}
-
-async fn init_db(uri: &str) -> Result<Pool<Postgres>, sqlx::Error> {
- let pool = PgPoolOptions::new()
- .max_connections(10)
- .connect(uri)
+ let (listener, router) = Config::new()
+ .file(config_file)?
+ .env("UNNAMED_")?
+ .build()
.await?;
- sqlx::migrate!().run(&pool).await?;
+ tracing::info!("Server listening on http://{}", listener.local_addr()?);
- Ok(pool)
+ axum::serve(listener, router).await.map_err(From::from)
}