summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-09-28 00:54:46 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-09-28 00:58:45 -0500
commitcd774827dd14f68d8405c45d2d9da30b3fab050e (patch)
treea24e1cabb99170caa25edff53fc978111a1c9dd4 /src
parent04c7f7609e5bc3fadf95c53b37a9e6e12c4e539c (diff)
feat: refactor into pub-sub and impl SSE
Diffstat (limited to 'src')
-rw-r--r--src/api.rs14
-rw-r--r--src/api/services.rs33
-rw-r--r--src/lib.rs36
-rw-r--r--src/main.rs20
-rw-r--r--src/service.rs106
-rw-r--r--src/service/http.rs183
-rw-r--r--src/service/systemd.rs31
-rw-r--r--src/service/tcp.rs79
-rw-r--r--src/sse.rs30
9 files changed, 390 insertions, 142 deletions
diff --git a/src/api.rs b/src/api.rs
index 8dfd2ca..1489c21 100644
--- a/src/api.rs
+++ b/src/api.rs
@@ -3,11 +3,11 @@ use std::collections::HashMap;
use axum::{extract::State, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
-use crate::{service::Services, Status};
+use crate::{service::ServiceHandles, Status};
pub mod services;
-pub fn router() -> axum::Router<Services> {
+pub fn router() -> axum::Router<ServiceHandles> {
use axum::routing::get;
axum::Router::new()
@@ -37,11 +37,11 @@ impl IntoResponse for Health {
}
}
-pub async fn healthcheck(State(services): State<Services>) -> Health {
- let checks = match services.check().await {
- Ok(c) => c,
- Err(err) => return err.into(),
- };
+pub async fn healthcheck(State(services): State<ServiceHandles>) -> Health {
+ let checks = services
+ .iter()
+ .map(|(name, srv)| (name.clone(), srv.status()))
+ .collect::<HashMap<String, Status>>();
let status = match checks
.values()
diff --git a/src/api/services.rs b/src/api/services.rs
index 63018f1..132ecb1 100644
--- a/src/api/services.rs
+++ b/src/api/services.rs
@@ -7,7 +7,7 @@ use axum::{
use axum_extra::routing::Resource;
use serde::{Deserialize, Serialize};
-use crate::{service::Services, Error, Status};
+use crate::{service::ServiceHandles, Error, Status};
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ServiceQuery {
@@ -15,26 +15,35 @@ pub struct ServiceQuery {
pub state: Option<Status>,
}
-pub fn router() -> Router<Services> {
+pub fn router() -> Router<ServiceHandles> {
Resource::named("services").index(index).show(show).into()
}
pub async fn index(
Query(query): Query<ServiceQuery>,
- State(services): State<Services>,
-) -> Result<Json<HashMap<String, Status>>, Error> {
- services
- .check_filtered(|name| (!query.name.as_ref().is_some_and(|s| s != name)))
- .await
- .map(Json)
+ State(services): State<ServiceHandles>,
+) -> Json<HashMap<String, Status>> {
+ let map = match query.name {
+ Some(n) => services
+ .iter()
+ .filter(|(name, _)| n == **name)
+ .map(|(name, srv)| (name.clone(), srv.status()))
+ .collect(),
+ None => services
+ .iter()
+ .map(|(name, srv)| (name.clone(), srv.status()))
+ .collect(),
+ };
+
+ Json(map)
}
pub async fn show(
Path(name): Path<String>,
- State(services): State<Services>,
+ State(services): State<ServiceHandles>,
) -> Result<Status, Error> {
services
- .check_one(&name)
- .await
- .ok_or_else(|| Error::ServiceNotFound(name))?
+ .get(&name)
+ .map(|s| s.status())
+ .ok_or_else(|| Error::ServiceNotFound(name))
}
diff --git a/src/lib.rs b/src/lib.rs
index 1ccecf7..d24f635 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,22 +1,37 @@
-use std::path::PathBuf;
-
use serde::{Deserialize, Serialize};
-use service::Services;
-use tower_http::services::ServeDir;
+use service::ServiceHandles;
pub use crate::error::{Error, Result};
pub mod api;
pub mod error;
pub mod service;
+pub mod sse;
+
+pub fn router() -> axum::Router<ServiceHandles> {
+ axum::Router::new()
+ .nest("/api", api::router())
+ .nest("/sse", sse::router())
+}
-#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase", tag = "status", content = "output")]
pub enum Status {
- #[default]
Pass,
- Fail(Option<String>),
Warn(Option<String>),
+ Fail(Option<String>),
+ #[default]
+ Unknown,
+}
+
+impl Status {
+ pub fn update(&mut self, status: Status) -> bool {
+ let modif = *self != status;
+ if modif {
+ *self = status;
+ }
+ modif
+ }
}
impl<T: std::error::Error> From<T> for Status {
@@ -30,10 +45,3 @@ impl axum::response::IntoResponse for Status {
axum::Json(self).into_response()
}
}
-
-pub fn router(root: PathBuf) -> axum::Router<Services> {
- axum::Router::new()
- .nest_service("/", ServeDir::new(root))
- .nest("/api", api::router())
- .layer(tower_http::trace::TraceLayer::new_for_http())
-}
diff --git a/src/main.rs b/src/main.rs
index 97ed111..99af338 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,8 +1,9 @@
-use std::{fs::File, path::PathBuf};
+use std::{collections::HashMap, fs::File, path::PathBuf, sync::Arc};
+use tower_http::services::ServeDir;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
-use statsrv::service::Services;
+use statsrv::service::Service;
#[cfg(not(debug_assertions))]
const DEFAULT_CONFIG: &str = "/etc/statsrv.toml";
@@ -29,7 +30,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
};
- let router = statsrv::router(config.root).with_state(config.services);
+ let state = config
+ .services
+ .into_iter()
+ .map(|(name, service)| (name, service.into()))
+ .collect();
+
+ let router = statsrv::router()
+ .with_state(Arc::new(state))
+ .nest_service("/", ServeDir::new(config.root))
+ .layer(tower_http::trace::TraceLayer::new_for_http());
let listener = tokio::net::TcpListener::bind(config.address).await.unwrap();
tracing::info!("listening on {}", listener.local_addr().unwrap());
@@ -42,7 +52,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
pub struct Config {
pub root: PathBuf,
pub address: String,
- pub services: Services,
+ pub services: HashMap<String, Service>,
}
impl Config {
@@ -63,7 +73,7 @@ impl Default for Config {
Self {
root: PathBuf::from("./"),
address: String::from("127.0.0.1:8080"),
- services: Services::new(Default::default()),
+ services: Default::default(),
}
}
}
diff --git a/src/service.rs b/src/service.rs
index bae6867..c45fcb1 100644
--- a/src/service.rs
+++ b/src/service.rs
@@ -1,10 +1,15 @@
-use std::{collections::HashMap, fmt::Display};
+use std::{collections::HashMap, sync::Arc};
-use futures::{stream::FuturesOrdered, TryStreamExt};
+use futures::Stream;
use http::Http;
use serde::Deserialize;
use systemd::Systemd;
use tcp::Tcp;
+use tokio::{
+ sync::watch::{Receiver, Sender},
+ task::JoinHandle,
+};
+use tokio_stream::wrappers::WatchStream;
use crate::{Error, Status};
@@ -12,67 +17,34 @@ pub mod http;
pub mod systemd;
pub mod tcp;
-#[derive(Debug, Clone, Deserialize)]
-pub struct Services {
- #[serde(flatten)]
- inner: HashMap<String, Service>,
- #[serde(skip, default = "Services::default_client")]
- client: reqwest::Client,
-}
+pub type ServiceHandles = Arc<HashMap<String, ServiceHandle>>;
-impl Services {
- pub fn new(services: HashMap<String, Service>) -> Self {
- let client = reqwest::Client::new();
- Self {
- inner: services,
- client,
- }
- }
-
- fn default_client() -> reqwest::Client {
- reqwest::Client::new()
- }
+pub trait ServiceSpawner {
+ fn spawn(
+ self,
+ tx: Sender<Status>,
+ ) -> impl std::future::Future<Output = Result<(), Error>> + std::marker::Send + 'static;
+}
- pub async fn check(&self) -> Result<HashMap<String, Status>, Error> {
- let checks = self
- .inner
- .values()
- .map(|service| service.check(self.client.clone()))
- .collect::<FuturesOrdered<_>>()
- .try_collect::<Vec<_>>()
- .await?;
+#[derive(Debug)]
+pub struct ServiceHandle {
+ pub handle: JoinHandle<Result<(), Error>>,
+ pub rx: Receiver<Status>,
+}
- Ok(self
- .inner
- .keys()
- .cloned()
- .zip(checks)
- .collect::<HashMap<_, _>>())
+impl ServiceHandle {
+ pub fn new(service: impl ServiceSpawner) -> Self {
+ let (tx, rx) = tokio::sync::watch::channel(Status::default());
+ let handle = tokio::spawn(service.spawn(tx));
+ Self { handle, rx }
}
- pub async fn check_one(&self, name: &str) -> Option<Result<Status, Error>> {
- Some(self.inner.get(name)?.check(self.client.clone()).await)
+ pub fn status(&self) -> Status {
+ self.rx.borrow().clone()
}
- pub async fn check_filtered<P>(&self, mut predicate: P) -> Result<HashMap<String, Status>, Error>
- where
- P: FnMut(&String) -> bool,
- {
- let checks = self
- .inner
- .iter()
- .filter_map(|(s, service)| predicate(s).then_some(service))
- .map(|service| service.check(self.client.clone()))
- .collect::<FuturesOrdered<_>>()
- .try_collect::<Vec<_>>()
- .await?;
-
- Ok(self
- .inner
- .keys()
- .cloned()
- .zip(checks)
- .collect::<HashMap<_, _>>())
+ pub fn into_stream(&self) -> impl Stream<Item = Status> {
+ WatchStream::new(self.rx.clone())
}
}
@@ -84,22 +56,12 @@ pub enum Service {
Systemd(Systemd),
}
-impl Service {
- pub async fn check(&self, client: reqwest::Client) -> Result<Status, Error> {
- match self {
- Service::Http(http) => http.check(client).await,
- Service::Tcp(tcp) => tcp.check().await,
- Service::Systemd(systemd) => systemd.check().await,
- }
- }
-}
-
-impl Display for Service {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Service::Http(http) => http.fmt(f),
- Service::Tcp(tcp) => tcp.fmt(f),
- Service::Systemd(systemd) => systemd.fmt(f),
+impl From<Service> for ServiceHandle {
+ fn from(value: Service) -> Self {
+ match value {
+ Service::Http(s) => ServiceHandle::new(s),
+ Service::Tcp(s) => ServiceHandle::new(s),
+ Service::Systemd(s) => ServiceHandle::new(s),
}
}
}
diff --git a/src/service/http.rs b/src/service/http.rs
index fb3ff13..6d21cb7 100644
--- a/src/service/http.rs
+++ b/src/service/http.rs
@@ -1,46 +1,181 @@
-use std::fmt::Display;
+use std::{fmt::Display, time::Duration};
+use axum::http::status::StatusCode;
+use futures::Stream;
use serde::Deserialize;
+use tokio::sync::watch::Sender;
+use tokio_stream::wrappers::WatchStream;
+use url::Url;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Http {
- pub url: String,
- #[serde(default = "Http::default_method")]
- pub method: String,
- #[serde(default = "Http::default_code")]
- pub status_code: u16,
+ pub url: Url,
+ #[serde(default)]
+ pub method: Method,
+ #[serde(default, with = "status_code")]
+ pub status_code: StatusCode,
+ #[serde(skip, default)]
+ pub client: Option<reqwest::Client>,
+}
+
+impl Display for Http {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{} {}", self.method, self.url)
+ }
+}
+
+impl ServiceSpawner for Http {
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let client = self.client.unwrap_or_default();
+ let request = client.request(self.method.into(), self.url).build()?;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+ let req = request
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let resp = client.execute(req).await;
+ let status = match resp.map(|r| r.status().as_u16()) {
+ Ok(code) if code == self.status_code => Status::Pass,
+ Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
+ Err(err) => {
+ tracing::error!("HTTP request error: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ }
}
impl Http {
- fn default_method() -> String {
- "GET".to_string()
+ pub fn into_stream(self, client: reqwest::Client) -> impl Stream<Item = Status> {
+ let request = client
+ .request(self.method.into(), self.url)
+ .build()
+ .expect("Url parsing should not fail");
+
+ let (tx, rx) = tokio::sync::watch::channel(Status::default());
+
+ tokio::spawn(async move {
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+ let req = request
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let resp = client.execute(req).await;
+ let status = match resp.map(|r| r.status().as_u16()) {
+ Ok(code) if code == self.status_code => Status::Pass,
+ Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
+ Err(err) => {
+ tracing::error!("HTTP request error: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ });
+
+ WatchStream::new(rx)
}
+}
- fn default_code() -> u16 {
- 200
+#[derive(Debug, Clone, Copy, Default, Deserialize)]
+pub enum Method {
+ #[serde(alias = "get", alias = "GET")]
+ #[default]
+ Get,
+ #[serde(alias = "post", alias = "POST")]
+ Post,
+}
+
+impl From<Method> for reqwest::Method {
+ fn from(value: Method) -> Self {
+ match value {
+ Method::Get => reqwest::Method::GET,
+ Method::Post => reqwest::Method::POST,
+ }
}
}
-impl Display for Http {
+impl Display for Method {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{} {}", self.method, self.url)
+ match self {
+ Method::Get => write!(f, "GET"),
+ Method::Post => write!(f, "POST"),
+ }
}
}
-impl Http {
- pub async fn check(&self, client: reqwest::Client) -> Result<Status, Error> {
- let status_code = client
- .request(self.method.parse().map_err(|_| Error::Method)?, &self.url)
- .send()
- .await?
- .status()
- .as_u16();
-
- match status_code == self.status_code {
- true => Ok(Status::Pass),
- false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))),
+pub mod status_code {
+ use axum::http::StatusCode;
+ use serde::{
+ de::{self, Unexpected, Visitor},
+ Deserializer, Serializer,
+ };
+ use std::fmt;
+
+ /// Implementation detail. Use derive annotations instead.
+ #[inline]
+ pub fn serialize<S: Serializer>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error> {
+ ser.serialize_u16(status.as_u16())
+ }
+
+ pub(crate) struct StatusVisitor;
+
+ impl StatusVisitor {
+ #[inline(never)]
+ fn make<E: de::Error>(&self, val: u64) -> Result<StatusCode, E> {
+ if (100..1000).contains(&val) {
+ if let Ok(s) = StatusCode::from_u16(val as u16) {
+ return Ok(s);
+ }
+ }
+ Err(de::Error::invalid_value(Unexpected::Unsigned(val), self))
+ }
+ }
+
+ impl<'de> Visitor<'de> for StatusVisitor {
+ type Value = StatusCode;
+
+ #[inline]
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("status code")
+ }
+
+ #[inline]
+ fn visit_some<D: Deserializer<'de>>(
+ self,
+ deserializer: D,
+ ) -> Result<Self::Value, D::Error> {
+ deserializer.deserialize_u16(self)
+ }
+
+ #[inline]
+ fn visit_i64<E: de::Error>(self, val: i64) -> Result<Self::Value, E> {
+ self.make(val as _)
+ }
+
+ #[inline]
+ fn visit_u64<E: de::Error>(self, val: u64) -> Result<Self::Value, E> {
+ self.make(val)
}
}
+
+ /// Implementation detail.
+ #[inline]
+ pub fn deserialize<'de, D>(de: D) -> Result<StatusCode, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ de.deserialize_u16(StatusVisitor)
+ }
}
diff --git a/src/service/systemd.rs b/src/service/systemd.rs
index 45f3bf9..e3b4d1b 100644
--- a/src/service/systemd.rs
+++ b/src/service/systemd.rs
@@ -1,9 +1,12 @@
-use std::{fmt::Display, process::Command};
+use std::{fmt::Display, process::Command, time::Duration};
use serde::Deserialize;
+use tokio::sync::watch::Sender;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Systemd {
pub service: String,
@@ -15,6 +18,32 @@ impl Display for Systemd {
}
}
+impl ServiceSpawner for Systemd {
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let mut command = Command::new("systemctl");
+ command.arg("is-active").arg(&self.service);
+
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+
+ let status = match command.output() {
+ Ok(output) if output.status.success() => Status::Pass,
+ Ok(output) => {
+ let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
+ Status::Fail(Some(format!("Service state: {}", stdout)))
+ }
+ Err(err) => {
+ tracing::error!("Failed to spawn process: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ }
+}
+
impl Systemd {
pub async fn check(&self) -> Result<Status, Error> {
let output = Command::new("systemctl")
diff --git a/src/service/tcp.rs b/src/service/tcp.rs
index 87e696a..5ec5f36 100644
--- a/src/service/tcp.rs
+++ b/src/service/tcp.rs
@@ -1,12 +1,17 @@
-use std::fmt::Display;
+use std::{fmt::Display, net::SocketAddr, time::Duration};
+use futures::Stream;
use serde::Deserialize;
+use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender};
+use tokio_stream::wrappers::WatchStream;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Tcp {
- pub address: String,
+ pub address: SocketAddr,
}
impl Display for Tcp {
@@ -15,11 +20,71 @@ impl Display for Tcp {
}
}
+impl ServiceSpawner for Tcp {
+ #[tracing::instrument(skip(tx))]
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+
+ loop {
+ interval.tick().await;
+
+ let sock = TcpSocket::new_v4()?;
+ sock.set_keepalive(true)?;
+
+ match sock.connect(self.address).await {
+ Ok(conn) => {
+ tracing::info!("Connected");
+ tx.send_if_modified(|s| s.update(Status::Pass));
+ conn.ready(Interest::ERROR).await?;
+ tx.send_replace(Status::Fail(Some("Disconnected".into())));
+ tracing::info!("Disconnected");
+ }
+ Err(err) => {
+ tracing::error!("Failed to connect");
+ tx.send_if_modified(|s| s.update(err.into()));
+ }
+ };
+ }
+ }
+}
+
impl Tcp {
- pub async fn check(&self) -> Result<Status, Error> {
- Ok(std::net::TcpStream::connect(&self.address)
- .err()
- .map(Into::into)
- .unwrap_or_default())
+ pub fn into_stream(self) -> impl Stream<Item = Status> {
+ let (tx, rx) = tokio::sync::watch::channel(Status::default());
+ tokio::spawn(self.spawn(tx));
+ WatchStream::new(rx)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[tokio::test]
+ #[tracing_test::traced_test]
+ #[ignore]
+ async fn test_tcp_watch() {
+ let (tx, mut rx) = tokio::sync::watch::channel(Status::default());
+
+ let tests = tokio::spawn(async move {
+ assert!(matches!(*rx.borrow_and_update(), Status::Unknown));
+
+ rx.changed().await.unwrap();
+ assert!(matches!(*rx.borrow_and_update(), Status::Pass));
+
+ rx.changed().await.unwrap();
+ assert_eq!(
+ *rx.borrow_and_update(),
+ Status::Fail(Some(String::from("Disconnected")))
+ );
+ });
+
+ let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let address = listener.local_addr().unwrap();
+ tokio::spawn(async move { Tcp { address }.spawn(tx).await });
+ listener.accept().await.unwrap();
+ drop(listener);
+
+ tests.await.unwrap()
}
}
diff --git a/src/sse.rs b/src/sse.rs
new file mode 100644
index 0000000..b4a8840
--- /dev/null
+++ b/src/sse.rs
@@ -0,0 +1,30 @@
+use axum::{
+ extract::{Path, State},
+ response::{
+ sse::{Event, KeepAlive},
+ Sse,
+ },
+ routing::get,
+ Router,
+};
+use futures::Stream;
+use tokio_stream::StreamExt;
+
+use crate::{service::ServiceHandles, Error};
+
+pub fn router() -> Router<ServiceHandles> {
+ axum::Router::new().route("/:name", get(sse_handler))
+}
+
+pub async fn sse_handler(
+ Path(name): Path<String>,
+ State(services): State<ServiceHandles>,
+) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, Error> {
+ let stream = services
+ .get(&name)
+ .ok_or_else(|| Error::ServiceNotFound(name))?
+ .into_stream()
+ .map(|s| Event::default().json_data(s));
+
+ Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
+}