summaryrefslogtreecommitdiffstats
path: root/src/config.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/config.rs')
-rw-r--r--src/config.rs129
1 files changed, 129 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..0b84a5e
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,129 @@
+use std::{net::SocketAddr, sync::Arc};
+
+use axum::Router;
+use serde::{Deserialize, Serialize};
+use time::Duration;
+use tokio::net::TcpListener;
+use unnamed_server::{state::AppState, Error};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Config {
+ listen_addr: Option<SocketAddr>,
+ jwt_max_age: Option<String>,
+ jwt_secret: Option<String>,
+ database_url: Option<String>,
+}
+
+impl Config {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> {
+ match std::fs::read_to_string(file) {
+ Ok(s) => Ok(self.merge(toml::from_str(&s)?)),
+ Err(err) => {
+ tracing::warn!("Error reading config file: {err}");
+ Ok(self)
+ }
+ }
+ }
+
+ pub fn env(self, prefix: &str) -> Result<Self, Error> {
+ Ok(self.merge(Self {
+ listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR"))
+ .ok()
+ .and_then(|v| v.parse().ok()),
+ jwt_max_age: std::env::var(format!("{prefix}JWT_MAX_AGE")).ok(),
+ jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(),
+ database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(),
+ }))
+ }
+
+ pub async fn build(self) -> Result<(TcpListener, Router), Error> {
+ macro_rules! try_extract {
+ ($($i:ident),+) => {
+ $(let Some($i) = self.$i else {
+ return Err(Error::Config(format!("Missing value: {}", stringify!($i))))
+ };)+
+ };
+ }
+
+ try_extract!(listen_addr, jwt_secret, jwt_max_age, database_url);
+
+ let listener = TcpListener::bind(listen_addr).await?;
+ let pool = init_db(&database_url).await?;
+ let jwt_max_age = parse_duration(jwt_max_age)?;
+ let app_state = Arc::new(AppState {
+ pool,
+ jwt_secret,
+ jwt_max_age,
+ });
+ let app = unnamed_server::init_router(app_state);
+
+ Ok((listener, app))
+ }
+
+ /// Merge self with other, overwriting any existing values on self with other's.
+ fn merge(self, other: Self) -> Self {
+ Self {
+ listen_addr: other.listen_addr.or(self.listen_addr),
+ jwt_max_age: other.jwt_max_age.or(self.jwt_max_age),
+ jwt_secret: other.jwt_secret.or(self.jwt_secret),
+ database_url: other.database_url.or(self.database_url),
+ }
+ }
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))),
+ jwt_max_age: Some("1h".to_string()),
+ jwt_secret: None,
+ database_url: None,
+ }
+ }
+}
+
+async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> {
+ let pool = sqlx::postgres::PgPoolOptions::new()
+ .max_connections(10)
+ .connect(uri)
+ .await?;
+
+ sqlx::migrate!().run(&pool).await?;
+
+ Ok(pool)
+}
+
+fn parse_duration<S: AsRef<str>>(s: S) -> Result<Duration, Error> {
+ let chars = &mut s.as_ref().chars();
+ let mut nums: i64 = 0;
+ let mut unit = String::new();
+
+ for c in chars.by_ref() {
+ if c.is_ascii_digit() {
+ nums = nums * 10 + c.to_digit(10).unwrap() as i64;
+ } else {
+ unit.push(c);
+ break;
+ }
+ }
+
+ unit.extend(chars);
+
+ if "weeks".contains(&unit) {
+ Ok(Duration::weeks(nums))
+ } else if "days".contains(&unit) {
+ Ok(Duration::days(nums))
+ } else if "hours".contains(&unit) {
+ Ok(Duration::hours(nums))
+ } else if "minutes".contains(&unit) {
+ Ok(Duration::minutes(nums))
+ } else if "seconds".contains(&unit) {
+ Ok(Duration::seconds(nums))
+ } else {
+ Err(Error::Config("Invalid jwt_max_age".to_string()))
+ }
+}