From cd774827dd14f68d8405c45d2d9da30b3fab050e Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Sat, 28 Sep 2024 00:54:46 -0500 Subject: feat: refactor into pub-sub and impl SSE --- src/api.rs | 14 ++-- src/api/services.rs | 33 +++++---- src/lib.rs | 36 ++++++---- src/main.rs | 20 ++++-- src/service.rs | 106 +++++++++------------------- src/service/http.rs | 183 ++++++++++++++++++++++++++++++++++++++++++------- src/service/systemd.rs | 31 ++++++++- src/service/tcp.rs | 79 +++++++++++++++++++-- src/sse.rs | 30 ++++++++ 9 files changed, 390 insertions(+), 142 deletions(-) create mode 100644 src/sse.rs (limited to 'src') diff --git a/src/api.rs b/src/api.rs index 8dfd2ca..1489c21 100644 --- a/src/api.rs +++ b/src/api.rs @@ -3,11 +3,11 @@ use std::collections::HashMap; use axum::{extract::State, response::IntoResponse, Json}; use serde::{Deserialize, Serialize}; -use crate::{service::Services, Status}; +use crate::{service::ServiceHandles, Status}; pub mod services; -pub fn router() -> axum::Router { +pub fn router() -> axum::Router { use axum::routing::get; axum::Router::new() @@ -37,11 +37,11 @@ impl IntoResponse for Health { } } -pub async fn healthcheck(State(services): State) -> Health { - let checks = match services.check().await { - Ok(c) => c, - Err(err) => return err.into(), - }; +pub async fn healthcheck(State(services): State) -> Health { + let checks = services + .iter() + .map(|(name, srv)| (name.clone(), srv.status())) + .collect::>(); let status = match checks .values() diff --git a/src/api/services.rs b/src/api/services.rs index 63018f1..132ecb1 100644 --- a/src/api/services.rs +++ b/src/api/services.rs @@ -7,7 +7,7 @@ use axum::{ use axum_extra::routing::Resource; use serde::{Deserialize, Serialize}; -use crate::{service::Services, Error, Status}; +use crate::{service::ServiceHandles, Error, Status}; #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ServiceQuery { @@ -15,26 +15,35 @@ pub struct ServiceQuery { pub state: Option, } -pub fn router() -> Router { +pub fn router() -> Router { Resource::named("services").index(index).show(show).into() } pub async fn index( Query(query): Query, - State(services): State, -) -> Result>, Error> { - services - .check_filtered(|name| (!query.name.as_ref().is_some_and(|s| s != name))) - .await - .map(Json) + State(services): State, +) -> Json> { + let map = match query.name { + Some(n) => services + .iter() + .filter(|(name, _)| n == **name) + .map(|(name, srv)| (name.clone(), srv.status())) + .collect(), + None => services + .iter() + .map(|(name, srv)| (name.clone(), srv.status())) + .collect(), + }; + + Json(map) } pub async fn show( Path(name): Path, - State(services): State, + State(services): State, ) -> Result { services - .check_one(&name) - .await - .ok_or_else(|| Error::ServiceNotFound(name))? + .get(&name) + .map(|s| s.status()) + .ok_or_else(|| Error::ServiceNotFound(name)) } diff --git a/src/lib.rs b/src/lib.rs index 1ccecf7..d24f635 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,37 @@ -use std::path::PathBuf; - use serde::{Deserialize, Serialize}; -use service::Services; -use tower_http::services::ServeDir; +use service::ServiceHandles; pub use crate::error::{Error, Result}; pub mod api; pub mod error; pub mod service; +pub mod sse; + +pub fn router() -> axum::Router { + axum::Router::new() + .nest("/api", api::router()) + .nest("/sse", sse::router()) +} -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[serde(rename_all = "lowercase", tag = "status", content = "output")] pub enum Status { - #[default] Pass, - Fail(Option), Warn(Option), + Fail(Option), + #[default] + Unknown, +} + +impl Status { + pub fn update(&mut self, status: Status) -> bool { + let modif = *self != status; + if modif { + *self = status; + } + modif + } } impl From for Status { @@ -30,10 +45,3 @@ impl axum::response::IntoResponse for Status { axum::Json(self).into_response() } } - -pub fn router(root: PathBuf) -> axum::Router { - axum::Router::new() - .nest_service("/", ServeDir::new(root)) - .nest("/api", api::router()) - .layer(tower_http::trace::TraceLayer::new_for_http()) -} diff --git a/src/main.rs b/src/main.rs index 97ed111..99af338 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,9 @@ -use std::{fs::File, path::PathBuf}; +use std::{collections::HashMap, fs::File, path::PathBuf, sync::Arc}; +use tower_http::services::ServeDir; use tracing::level_filters::LevelFilter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -use statsrv::service::Services; +use statsrv::service::Service; #[cfg(not(debug_assertions))] const DEFAULT_CONFIG: &str = "/etc/statsrv.toml"; @@ -29,7 +30,16 @@ async fn main() -> Result<(), Box> { } }; - let router = statsrv::router(config.root).with_state(config.services); + let state = config + .services + .into_iter() + .map(|(name, service)| (name, service.into())) + .collect(); + + let router = statsrv::router() + .with_state(Arc::new(state)) + .nest_service("/", ServeDir::new(config.root)) + .layer(tower_http::trace::TraceLayer::new_for_http()); let listener = tokio::net::TcpListener::bind(config.address).await.unwrap(); tracing::info!("listening on {}", listener.local_addr().unwrap()); @@ -42,7 +52,7 @@ async fn main() -> Result<(), Box> { pub struct Config { pub root: PathBuf, pub address: String, - pub services: Services, + pub services: HashMap, } impl Config { @@ -63,7 +73,7 @@ impl Default for Config { Self { root: PathBuf::from("./"), address: String::from("127.0.0.1:8080"), - services: Services::new(Default::default()), + services: Default::default(), } } } diff --git a/src/service.rs b/src/service.rs index bae6867..c45fcb1 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,10 +1,15 @@ -use std::{collections::HashMap, fmt::Display}; +use std::{collections::HashMap, sync::Arc}; -use futures::{stream::FuturesOrdered, TryStreamExt}; +use futures::Stream; use http::Http; use serde::Deserialize; use systemd::Systemd; use tcp::Tcp; +use tokio::{ + sync::watch::{Receiver, Sender}, + task::JoinHandle, +}; +use tokio_stream::wrappers::WatchStream; use crate::{Error, Status}; @@ -12,67 +17,34 @@ pub mod http; pub mod systemd; pub mod tcp; -#[derive(Debug, Clone, Deserialize)] -pub struct Services { - #[serde(flatten)] - inner: HashMap, - #[serde(skip, default = "Services::default_client")] - client: reqwest::Client, -} +pub type ServiceHandles = Arc>; -impl Services { - pub fn new(services: HashMap) -> Self { - let client = reqwest::Client::new(); - Self { - inner: services, - client, - } - } - - fn default_client() -> reqwest::Client { - reqwest::Client::new() - } +pub trait ServiceSpawner { + fn spawn( + self, + tx: Sender, + ) -> impl std::future::Future> + std::marker::Send + 'static; +} - pub async fn check(&self) -> Result, Error> { - let checks = self - .inner - .values() - .map(|service| service.check(self.client.clone())) - .collect::>() - .try_collect::>() - .await?; +#[derive(Debug)] +pub struct ServiceHandle { + pub handle: JoinHandle>, + pub rx: Receiver, +} - Ok(self - .inner - .keys() - .cloned() - .zip(checks) - .collect::>()) +impl ServiceHandle { + pub fn new(service: impl ServiceSpawner) -> Self { + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + let handle = tokio::spawn(service.spawn(tx)); + Self { handle, rx } } - pub async fn check_one(&self, name: &str) -> Option> { - Some(self.inner.get(name)?.check(self.client.clone()).await) + pub fn status(&self) -> Status { + self.rx.borrow().clone() } - pub async fn check_filtered

(&self, mut predicate: P) -> Result, Error> - where - P: FnMut(&String) -> bool, - { - let checks = self - .inner - .iter() - .filter_map(|(s, service)| predicate(s).then_some(service)) - .map(|service| service.check(self.client.clone())) - .collect::>() - .try_collect::>() - .await?; - - Ok(self - .inner - .keys() - .cloned() - .zip(checks) - .collect::>()) + pub fn into_stream(&self) -> impl Stream { + WatchStream::new(self.rx.clone()) } } @@ -84,22 +56,12 @@ pub enum Service { Systemd(Systemd), } -impl Service { - pub async fn check(&self, client: reqwest::Client) -> Result { - match self { - Service::Http(http) => http.check(client).await, - Service::Tcp(tcp) => tcp.check().await, - Service::Systemd(systemd) => systemd.check().await, - } - } -} - -impl Display for Service { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Service::Http(http) => http.fmt(f), - Service::Tcp(tcp) => tcp.fmt(f), - Service::Systemd(systemd) => systemd.fmt(f), +impl From for ServiceHandle { + fn from(value: Service) -> Self { + match value { + Service::Http(s) => ServiceHandle::new(s), + Service::Tcp(s) => ServiceHandle::new(s), + Service::Systemd(s) => ServiceHandle::new(s), } } } diff --git a/src/service/http.rs b/src/service/http.rs index fb3ff13..6d21cb7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,46 +1,181 @@ -use std::fmt::Display; +use std::{fmt::Display, time::Duration}; +use axum::http::status::StatusCode; +use futures::Stream; use serde::Deserialize; +use tokio::sync::watch::Sender; +use tokio_stream::wrappers::WatchStream; +use url::Url; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Http { - pub url: String, - #[serde(default = "Http::default_method")] - pub method: String, - #[serde(default = "Http::default_code")] - pub status_code: u16, + pub url: Url, + #[serde(default)] + pub method: Method, + #[serde(default, with = "status_code")] + pub status_code: StatusCode, + #[serde(skip, default)] + pub client: Option, +} + +impl Display for Http { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.method, self.url) + } +} + +impl ServiceSpawner for Http { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let client = self.client.unwrap_or_default(); + let request = client.request(self.method.into(), self.url).build()?; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } } impl Http { - fn default_method() -> String { - "GET".to_string() + pub fn into_stream(self, client: reqwest::Client) -> impl Stream { + let request = client + .request(self.method.into(), self.url) + .build() + .expect("Url parsing should not fail"); + + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + }); + + WatchStream::new(rx) } +} - fn default_code() -> u16 { - 200 +#[derive(Debug, Clone, Copy, Default, Deserialize)] +pub enum Method { + #[serde(alias = "get", alias = "GET")] + #[default] + Get, + #[serde(alias = "post", alias = "POST")] + Post, +} + +impl From for reqwest::Method { + fn from(value: Method) -> Self { + match value { + Method::Get => reqwest::Method::GET, + Method::Post => reqwest::Method::POST, + } } } -impl Display for Http { +impl Display for Method { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {}", self.method, self.url) + match self { + Method::Get => write!(f, "GET"), + Method::Post => write!(f, "POST"), + } } } -impl Http { - pub async fn check(&self, client: reqwest::Client) -> Result { - let status_code = client - .request(self.method.parse().map_err(|_| Error::Method)?, &self.url) - .send() - .await? - .status() - .as_u16(); - - match status_code == self.status_code { - true => Ok(Status::Pass), - false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))), +pub mod status_code { + use axum::http::StatusCode; + use serde::{ + de::{self, Unexpected, Visitor}, + Deserializer, Serializer, + }; + use std::fmt; + + /// Implementation detail. Use derive annotations instead. + #[inline] + pub fn serialize(status: &StatusCode, ser: S) -> Result { + ser.serialize_u16(status.as_u16()) + } + + pub(crate) struct StatusVisitor; + + impl StatusVisitor { + #[inline(never)] + fn make(&self, val: u64) -> Result { + if (100..1000).contains(&val) { + if let Ok(s) = StatusCode::from_u16(val as u16) { + return Ok(s); + } + } + Err(de::Error::invalid_value(Unexpected::Unsigned(val), self)) + } + } + + impl<'de> Visitor<'de> for StatusVisitor { + type Value = StatusCode; + + #[inline] + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("status code") + } + + #[inline] + fn visit_some>( + self, + deserializer: D, + ) -> Result { + deserializer.deserialize_u16(self) + } + + #[inline] + fn visit_i64(self, val: i64) -> Result { + self.make(val as _) + } + + #[inline] + fn visit_u64(self, val: u64) -> Result { + self.make(val) } } + + /// Implementation detail. + #[inline] + pub fn deserialize<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + de.deserialize_u16(StatusVisitor) + } } diff --git a/src/service/systemd.rs b/src/service/systemd.rs index 45f3bf9..e3b4d1b 100644 --- a/src/service/systemd.rs +++ b/src/service/systemd.rs @@ -1,9 +1,12 @@ -use std::{fmt::Display, process::Command}; +use std::{fmt::Display, process::Command, time::Duration}; use serde::Deserialize; +use tokio::sync::watch::Sender; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Systemd { pub service: String, @@ -15,6 +18,32 @@ impl Display for Systemd { } } +impl ServiceSpawner for Systemd { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let mut command = Command::new("systemctl"); + command.arg("is-active").arg(&self.service); + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + + let status = match command.output() { + Ok(output) if output.status.success() => Status::Pass, + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + Status::Fail(Some(format!("Service state: {}", stdout))) + } + Err(err) => { + tracing::error!("Failed to spawn process: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } +} + impl Systemd { pub async fn check(&self) -> Result { let output = Command::new("systemctl") diff --git a/src/service/tcp.rs b/src/service/tcp.rs index 87e696a..5ec5f36 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -1,12 +1,17 @@ -use std::fmt::Display; +use std::{fmt::Display, net::SocketAddr, time::Duration}; +use futures::Stream; use serde::Deserialize; +use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender}; +use tokio_stream::wrappers::WatchStream; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Tcp { - pub address: String, + pub address: SocketAddr, } impl Display for Tcp { @@ -15,11 +20,71 @@ impl Display for Tcp { } } +impl ServiceSpawner for Tcp { + #[tracing::instrument(skip(tx))] + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + + loop { + interval.tick().await; + + let sock = TcpSocket::new_v4()?; + sock.set_keepalive(true)?; + + match sock.connect(self.address).await { + Ok(conn) => { + tracing::info!("Connected"); + tx.send_if_modified(|s| s.update(Status::Pass)); + conn.ready(Interest::ERROR).await?; + tx.send_replace(Status::Fail(Some("Disconnected".into()))); + tracing::info!("Disconnected"); + } + Err(err) => { + tracing::error!("Failed to connect"); + tx.send_if_modified(|s| s.update(err.into())); + } + }; + } + } +} + impl Tcp { - pub async fn check(&self) -> Result { - Ok(std::net::TcpStream::connect(&self.address) - .err() - .map(Into::into) - .unwrap_or_default()) + pub fn into_stream(self) -> impl Stream { + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + tokio::spawn(self.spawn(tx)); + WatchStream::new(rx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[tracing_test::traced_test] + #[ignore] + async fn test_tcp_watch() { + let (tx, mut rx) = tokio::sync::watch::channel(Status::default()); + + let tests = tokio::spawn(async move { + assert!(matches!(*rx.borrow_and_update(), Status::Unknown)); + + rx.changed().await.unwrap(); + assert!(matches!(*rx.borrow_and_update(), Status::Pass)); + + rx.changed().await.unwrap(); + assert_eq!( + *rx.borrow_and_update(), + Status::Fail(Some(String::from("Disconnected"))) + ); + }); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + tokio::spawn(async move { Tcp { address }.spawn(tx).await }); + listener.accept().await.unwrap(); + drop(listener); + + tests.await.unwrap() } } diff --git a/src/sse.rs b/src/sse.rs new file mode 100644 index 0000000..b4a8840 --- /dev/null +++ b/src/sse.rs @@ -0,0 +1,30 @@ +use axum::{ + extract::{Path, State}, + response::{ + sse::{Event, KeepAlive}, + Sse, + }, + routing::get, + Router, +}; +use futures::Stream; +use tokio_stream::StreamExt; + +use crate::{service::ServiceHandles, Error}; + +pub fn router() -> Router { + axum::Router::new().route("/:name", get(sse_handler)) +} + +pub async fn sse_handler( + Path(name): Path, + State(services): State, +) -> Result>>, Error> { + let stream = services + .get(&name) + .ok_or_else(|| Error::ServiceNotFound(name))? + .into_stream() + .map(|s| Event::default().json_data(s)); + + Ok(Sse::new(stream).keep_alive(KeepAlive::default())) +} -- cgit v1.2.3-70-g09d2