1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
use std::{net::SocketAddr, sync::Arc};
use axum::Router;
use serde::{Deserialize, Serialize};
use time::Duration;
use tokio::net::TcpListener;
use unnamed_server::{state::AppState, Error};
// NOTE(review): the derived `Debug` prints `jwt_secret` and `database_url` verbatim
// (database URLs often embed credentials) — consider a redacting `Debug` impl before
// a `Config` is ever logged.
/// Server configuration assembled from layered sources (built-in defaults, a TOML
/// file, environment variables).
///
/// Every field is optional so partial sources can be merged; `Config::build` returns
/// an error naming the first field that is still `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Socket address the TCP listener binds to.
    listen_addr: Option<SocketAddr>,
    /// JWT lifetime as text such as `"1h"`; converted by `parse_duration` in `build`.
    jwt_max_age: Option<String>,
    /// Stored into `AppState::jwt_secret`; presumably the JWT signing key — confirm.
    jwt_secret: Option<String>,
    /// Postgres connection string handed to `init_db`.
    database_url: Option<String>,
}
impl Config {
pub fn new() -> Self {
Self::default()
}
pub fn file<P: AsRef<std::path::Path>>(self, file: P) -> Result<Self, Error> {
match std::fs::read_to_string(file) {
Ok(s) => Ok(self.merge(toml::from_str(&s)?)),
Err(err) => {
tracing::warn!("Error reading config file: {err}");
Ok(self)
}
}
}
pub fn env(self, prefix: &str) -> Result<Self, Error> {
Ok(self.merge(Self {
listen_addr: std::env::var(format!("{prefix}LISTEN_ADDR"))
.ok()
.and_then(|v| v.parse().ok()),
jwt_max_age: std::env::var(format!("{prefix}JWT_MAX_AGE")).ok(),
jwt_secret: std::env::var(format!("{prefix}JWT_SECRET")).ok(),
database_url: std::env::var(format!("{prefix}DATABASE_URL")).ok(),
}))
}
pub async fn build(self) -> Result<(TcpListener, Router), Error> {
macro_rules! try_extract {
($($i:ident),+) => {
$(let Some($i) = self.$i else {
return Err(Error::Config(format!("Missing value: {}", stringify!($i))))
};)+
};
}
try_extract!(listen_addr, jwt_secret, jwt_max_age, database_url);
let listener = TcpListener::bind(listen_addr).await?;
let pool = init_db(&database_url).await?;
let jwt_max_age = parse_duration(jwt_max_age)?;
let app_state = Arc::new(AppState {
pool,
jwt_secret,
jwt_max_age,
});
let app = unnamed_server::init_router(app_state);
Ok((listener, app))
}
/// Merge self with other, overwriting any existing values on self with other's.
fn merge(self, other: Self) -> Self {
Self {
listen_addr: other.listen_addr.or(self.listen_addr),
jwt_max_age: other.jwt_max_age.or(self.jwt_max_age),
jwt_secret: other.jwt_secret.or(self.jwt_secret),
database_url: other.database_url.or(self.database_url),
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
listen_addr: Some(SocketAddr::from(([127, 0, 0, 1], 30000))),
jwt_max_age: Some("1h".to_string()),
jwt_secret: None,
database_url: None,
}
}
}
/// Open a Postgres connection pool (capped at 10 connections) for `uri`, then apply
/// the migrations collected by `sqlx::migrate!()` before returning the pool.
///
/// # Errors
///
/// Connection failures and migration failures are converted into [`Error`] and
/// returned.
async fn init_db(uri: &str) -> Result<sqlx::Pool<sqlx::Postgres>, Error> {
    let options = sqlx::postgres::PgPoolOptions::new().max_connections(10);
    let pool = options.connect(uri).await?;
    let migrator = sqlx::migrate!();
    migrator.run(&pool).await?;
    Ok(pool)
}
fn parse_duration<S: AsRef<str>>(s: S) -> Result<Duration, Error> {
let chars = &mut s.as_ref().chars();
let mut nums: i64 = 0;
let mut unit = String::new();
for c in chars.by_ref() {
if c.is_ascii_digit() {
nums = nums * 10 + c.to_digit(10).unwrap() as i64;
} else {
unit.push(c);
break;
}
}
unit.extend(chars);
if "weeks".contains(&unit) {
Ok(Duration::weeks(nums))
} else if "days".contains(&unit) {
Ok(Duration::days(nums))
} else if "hours".contains(&unit) {
Ok(Duration::hours(nums))
} else if "minutes".contains(&unit) {
Ok(Duration::minutes(nums))
} else if "seconds".contains(&unit) {
Ok(Duration::seconds(nums))
} else {
Err(Error::Config("Invalid jwt_max_age".to_string()))
}
}
|