summaryrefslogtreecommitdiffstats
path: root/src/config.rs
blob: d36b8fdfdd1e261ed4d93dfc7ac82ca6b142f61e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use std::{net::SocketAddr, sync::Arc};

use axum::Router;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use unnamed_server::{state::AppState, Error};

/// Server configuration, layered from [`Default`], a TOML file ([`Config::file`])
/// and environment variables ([`Config::env`]), then turned into a running
/// listener and router by [`Config::build`].
///
/// Every field is optional so partial sources can be merged on top of each
/// other; `build` rejects the configuration if any value is still unset.
///
/// `Debug` is implemented by hand so that secrets are never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct Config {
    /// Address the [`TcpListener`] is bound to.
    listen_addr: Option<SocketAddr>,
    /// Secret handed to the shared application state as `jwt_secret`.
    jwt_secret: Option<String>,
    /// Connection string used to open the Postgres pool.
    database_url: Option<String>,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The JWT secret and the database URL (which may embed a password) must
        // not end up in logs; only reveal whether they are set.
        let redact = |value: &Option<String>| value.as_ref().map(|_| "<redacted>");
        f.debug_struct("Config")
            .field("listen_addr", &self.listen_addr)
            .field("jwt_secret", &redact(&self.jwt_secret))
            .field("database_url", &redact(&self.database_url))
            .finish()
    }
}

impl Config {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> {
        match std::fs::read_to_string(file) {
            Ok(s) => Ok(self.merge(toml::from_str(&s)?)),
            Err(err) => {
                tracing::warn!("Error reading config file: {err}");
                Ok(self)
            }
        }
    }

    pub fn env(self, prefix: &str) -> Result<Self, Error> {
        Ok(self.merge(Self {
            listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR"))
                .ok()
                .and_then(|v| v.parse().ok()),
            jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(),
            database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(),
        }))
    }

    pub async fn build(self) -> Result<(TcpListener, Router), Error> {
        macro_rules! try_extract {
            ($($i:ident),+) => {
                $(let Some($i) = self.$i else {
                    return Err(Error::Config(format!("Missing value: {}", stringify!($i))))
                };)+
            };
        }

        try_extract!(listen_addr, jwt_secret, database_url);

        let listener = TcpListener::bind(listen_addr).await?;
        let pool = init_db(&database_url).await?;
        let app_state = Arc::new(AppState { pool, jwt_secret });
        let app = unnamed_server::init_router(app_state);

        Ok((listener, app))
    }

    /// Merge self with other, overwriting any existing values on self with other's.
    fn merge(self, other: Self) -> Self {
        Self {
            listen_addr: other.listen_addr.or(self.listen_addr),
            jwt_secret: other.jwt_secret.or(self.jwt_secret),
            database_url: other.database_url.or(self.database_url),
        }
    }
}

impl Default for Config {
    /// Listen on the IPv4 loopback interface, port 30000.
    ///
    /// The JWT secret and database URL have no default, so `build` fails until
    /// another source supplies them.
    fn default() -> Self {
        let loopback = SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, 30000));

        Self {
            jwt_secret: None,
            database_url: None,
            listen_addr: Some(loopback),
        }
    }
}

/// Open a Postgres connection pool for `uri` and apply the migrations from
/// `sqlx::migrate!()` before handing the pool back.
///
/// The pool is capped at ten connections.
///
/// # Errors
///
/// Propagates the error if connecting fails or any migration fails.
async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> {
    const MAX_CONNECTIONS: u32 = 10;

    let options = sqlx::postgres::PgPoolOptions::new().max_connections(MAX_CONNECTIONS);
    let db = options.connect(uri).await?;

    sqlx::migrate!().run(&db).await?;

    Ok(db)
}