From b94f8e694bf01f5dba9ce2c01f589463a3dfbc69 Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Wed, 9 Oct 2024 18:23:58 -0500 Subject: feat!: rewrite to use traits and streams --- src/api.rs | 57 ++-------------------- src/api/services.rs | 52 ++++++++++---------- src/api/sse.rs | 45 +++++++++++++++++ src/error.rs | 6 +++ src/lib.rs | 53 ++++---------------- src/main.rs | 39 +++++++-------- src/service.rs | 96 ++++++++++++++++++++----------------- src/service/command.rs | 128 +++++++++++++++++++++++++++++++------------------ src/service/http.rs | 54 +++++++++++++-------- src/service/systemd.rs | 36 -------------- src/service/tcp.rs | 75 +++++++++-------------------- src/sse.rs | 35 -------------- src/state.rs | 73 ++++++++++++++++++++++++++++ src/status.rs | 50 +++++++++++++++++++ 14 files changed, 419 insertions(+), 380 deletions(-) create mode 100644 src/api/sse.rs delete mode 100644 src/service/systemd.rs delete mode 100644 src/sse.rs create mode 100644 src/state.rs create mode 100644 src/status.rs (limited to 'src') diff --git a/src/api.rs b/src/api.rs index 5a8deb6..bab2043 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,57 +1,10 @@ -use std::collections::HashMap; - -use axum::{extract::State, response::IntoResponse, Json}; -use serde::{Deserialize, Serialize}; - -use crate::{service::ServiceHandles, Status}; +use crate::AppState; pub mod services; +pub mod sse; -pub fn router() -> axum::Router { - use axum::routing::get; - +pub fn router() -> axum::Router { axum::Router::new() - .route("/healthcheck", get(healthcheck)) - .merge(services::router()) -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct Health { - #[serde(flatten)] - pub status: Status, - pub checks: HashMap, -} - -impl From for Health { - fn from(value: T) -> Self { - Health { - status: value.into(), - ..Default::default() - } - } -} - -impl IntoResponse for Health { - fn into_response(self) -> axum::response::Response { - Json(self).into_response() - } -} - -pub async fn healthcheck(State(services): State) -> Health { - let checks = services - .iter() - .map(|(name, srv)| (name.clone(), srv.status())) - .collect::>(); - - let status = match checks - .values() - .filter(|s| !matches!(s, Status::Ok)) - .count() - { - 0 => Status::Ok, - 1 => Status::Error(Some("1 issue detected".to_string())), - n => Status::Error(Some(format!("{n} issues detected"))), - }; - - Health { status, checks } + .nest("/sse", sse::router()) + .nest("/status", services::router()) } diff --git a/src/api/services.rs b/src/api/services.rs index 132ecb1..aeca924 100644 --- a/src/api/services.rs +++ b/src/api/services.rs @@ -2,48 +2,50 @@ use std::collections::HashMap; use axum::{ extract::{Path, Query, State}, - Json, Router, + routing::get, + Json, }; -use axum_extra::routing::Resource; use serde::{Deserialize, Serialize}; -use crate::{service::ServiceHandles, Error, Status}; +use crate::{AppState, Error, Status}; + +pub fn router() -> axum::Router { + axum::Router::new() + .route("/", get(services)) + .route("/:name", get(service)) +} #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ServiceQuery { pub name: Option, - pub state: Option, + pub status: Option, } -pub fn router() -> Router { - Resource::named("services").index(index).show(show).into() +impl ServiceQuery { + pub fn filter(&self, value: &(String, Status)) -> bool { + !self.name.as_ref().is_some_and(|n| *n != value.0) + && !self.status.as_ref().is_some_and(|s| *s != value.1) + } } -pub async fn index( +pub async fn services( Query(query): Query, - State(services): State, + State(state): State, ) -> Json> { - let map = match query.name { - Some(n) => services - .iter() - .filter(|(name, _)| n == **name) - .map(|(name, srv)| (name.clone(), srv.status())) + Json( + state + .statuses() + .into_iter() + .filter(|item| query.filter(item)) .collect(), - None => services - .iter() - .map(|(name, srv)| (name.clone(), srv.status())) - .collect(), - }; - - Json(map) + ) } -pub async fn show( +pub async fn service( Path(name): Path, - State(services): State, + State(state): State, ) -> Result { - services - .get(&name) - .map(|s| s.status()) + state + .status(&name) .ok_or_else(|| Error::ServiceNotFound(name)) } diff --git a/src/api/sse.rs b/src/api/sse.rs new file mode 100644 index 0000000..5d913bb --- /dev/null +++ b/src/api/sse.rs @@ -0,0 +1,45 @@ +use std::convert::Infallible; + +use axum::{ + extract::{Path, State}, + response::{ + sse::{Event, KeepAlive}, + Sse, + }, + routing::get, + Router, +}; +use futures::Stream; +use tokio_stream::StreamExt; + +use crate::{AppState, Error}; + +pub fn router() -> Router { + axum::Router::new() + .route("/", get(events)) + .route("/:name", get(service_events)) +} + +pub async fn events( + State(state): State, +) -> Result>>, Error> { + let stream = state.streams().map(|(name, status)| { + let data = serde_json::to_string(&status)?; + Ok(Event::default().event(name).data(data)) + }); + + Ok(Sse::new(stream).keep_alive(KeepAlive::default())) +} + +pub async fn service_events( + Path(name): Path, + State(state): State, +) -> Result>>, Error> { + let stream = state + .stream(&name) + .ok_or_else(|| Error::ServiceNotFound(name))? + .map(Event::from) + .map(Ok); + + Ok(Sse::new(stream).keep_alive(KeepAlive::default())) +} diff --git a/src/error.rs b/src/error.rs index 109c944..8ed4dfa 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,12 +14,18 @@ pub enum Error { #[error("Invalid HTTP method")] Method, + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + #[error("Axum error: {0}")] Axum(#[from] axum::Error), #[error("Route not found: {0}")] RouteNotFound(axum::http::Uri), + #[error("Recv Error: {0}")] + Recv(#[from] tokio::sync::watch::error::RecvError), + #[error("Service not found: {0}")] ServiceNotFound(String), } diff --git a/src/lib.rs b/src/lib.rs index dc0efe7..6a64876 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,50 +1,15 @@ -use serde::{Deserialize, Serialize}; -use service::ServiceHandles; - -pub use crate::error::{Error, Result}; +pub use crate::{ + error::{Error, Result}, + state::AppState, + status::Status, +}; pub mod api; pub mod error; pub mod service; -pub mod sse; - -pub fn router() -> axum::Router { - axum::Router::new() - .nest("/api", api::router()) - .nest("/sse", sse::router()) -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -#[serde(rename_all = "lowercase", tag = "status", content = "output")] -pub enum Status { - Ok, - Error(Option), -} - -impl Default for Status { - fn default() -> Self { - Status::Error(Some("Unknown".to_string())) - } -} - -impl Status { - pub fn update(&mut self, status: Status) -> bool { - let modif = *self != status; - if modif { - *self = status; - } - modif - } -} - -impl From for Status { - fn from(value: T) -> Self { - Status::Error(Some(value.to_string())) - } -} +pub mod state; +pub mod status; -impl axum::response::IntoResponse for Status { - fn into_response(self) -> axum::response::Response { - axum::Json(self).into_response() - } +pub fn router() -> axum::Router { + axum::Router::new().nest("/api/v1", api::router()) } diff --git a/src/main.rs b/src/main.rs index fbf27cb..46adbfa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,10 @@ -use std::{collections::HashMap, fs::File, path::PathBuf, sync::Arc}; +use std::{collections::HashMap, fs::File, path::PathBuf}; + use tower_http::services::ServeDir; use tracing::level_filters::LevelFilter; use tracing_subscriber::EnvFilter; -use statsrv::service::Service; - -#[cfg(not(debug_assertions))] -const DEFAULT_CONFIG: &str = "/etc/statsrv.toml"; -#[cfg(debug_assertions)] -const DEFAULT_CONFIG: &str = "./config.toml"; +use statsrv::{service::ServiceConfig, AppState}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -20,20 +16,14 @@ async fn main() -> Result<(), Box> { let config = match Config::parse() { Ok(c) => c, Err(err) => { - tracing::debug!("Failed to read config file: `{err}`"); - tracing::debug!("Using default config values"); + tracing::error!("Failed to read config file, using defaults: `{err}`"); Default::default() } }; - let state = config - .services - .into_iter() - .map(|(name, service)| (name, service.into())) - .collect(); - + let state = AppState::spawn_services(config.services); let router = statsrv::router() - .with_state(Arc::new(state)) + .with_state(state) .nest_service("/", ServeDir::new(config.root)) .layer(tower_http::trace::TraceLayer::new_for_http()); @@ -48,17 +38,20 @@ async fn main() -> Result<(), Box> { pub struct Config { pub root: PathBuf, pub address: String, - pub services: HashMap, + pub services: HashMap, } impl Config { - fn parse() -> Result> { - let config_path = std::env::args().nth(1).unwrap_or_else(|| { - tracing::debug!("Falling back to default config location"); - DEFAULT_CONFIG.to_string() - }); + const DEFAULT_CONFIG: &str = "/etc/statsrv.toml"; - let config_file = File::open(&config_path)?; + fn parse() -> Result> { + let config_file = match std::env::args().nth(1) { + Some(p) => File::open(&p)?, + None => { + tracing::debug!("Falling back to default config location"); + File::open(Self::DEFAULT_CONFIG)? + } + }; let config_toml = std::io::read_to_string(config_file)?; toml::from_str(&config_toml).map_err(Into::into) } diff --git a/src/service.rs b/src/service.rs index 3e37503..b10385a 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,68 +1,76 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; -use futures::Stream; -use http::Http; +use futures::{StreamExt, TryStreamExt}; use serde::Deserialize; -use systemd::Systemd; -use tcp::Tcp; -use tokio::{ - sync::watch::{channel, Receiver, Sender}, - task::JoinHandle, -}; -use tokio_stream::wrappers::WatchStream; +use tokio_stream::{Stream, StreamMap}; -use crate::{Error, Status}; +use crate::Status; +pub mod command; pub mod http; -pub mod systemd; pub mod tcp; -pub mod command; -pub type ServiceHandles = Arc>; +pub type ServiceHandles = HashMap; + +pub trait IntoService { + type Error: std::error::Error + Sync + Send + Sized; -pub trait ServiceSpawner { - fn spawn( - self, - tx: Sender, - ) -> impl std::future::Future> + std::marker::Send + 'static; + fn into_service(self) -> impl Stream> + Send; } -#[derive(Debug)] -pub struct ServiceHandle { - pub handle: JoinHandle>, - pub rx: Receiver, +pub trait IntoServiceMap { + type Error: std::error::Error + Sync + Send + Sized; + + fn into_service_map(self) -> impl Stream)> + Send; } -impl ServiceHandle { - pub fn new(service: impl ServiceSpawner) -> Self { - let (tx, rx) = channel(Status::Error(None)); - let handle = tokio::spawn(service.spawn(tx)); - Self { handle, rx } - } +impl IntoServiceMap for T +where + T: IntoIterator, + V: IntoService, + K: std::hash::Hash + std::cmp::Eq + std::clone::Clone + std::marker::Unpin + std::marker::Send, +{ + type Error = V::Error; - pub fn status(&self) -> Status { - self.rx.borrow().clone() + fn into_service_map(self) -> impl Stream)> + Send { + let mut map = StreamMap::new(); + for (name, srv) in self.into_iter() { + map.insert(name, Box::pin(srv.into_service())); + } + map } +} - pub fn into_stream(&self) -> impl Stream { - WatchStream::new(self.rx.clone()) - } +pub fn default_interval() -> std::time::Duration { + std::time::Duration::from_secs(5) } #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] -pub enum Service { - Http(Http), - Tcp(Tcp), - Systemd(Systemd), +pub enum ServiceConfig { + Http(http::Http), + Tcp(tcp::Tcp), + Command(command::Command), } -impl From for ServiceHandle { - fn from(value: Service) -> Self { - match value { - Service::Http(s) => ServiceHandle::new(s), - Service::Tcp(s) => ServiceHandle::new(s), - Service::Systemd(s) => ServiceHandle::new(s), +#[derive(Debug, thiserror::Error)] +pub enum ServiceError { + #[error(transparent)] + Http(#[from] http::Error), + #[error(transparent)] + Tcp(#[from] tcp::Error), + #[error(transparent)] + Command(#[from] command::Error), +} + +impl IntoService for ServiceConfig { + type Error = ServiceError; + + fn into_service(self) -> impl Stream> + Send { + match self { + ServiceConfig::Http(h) => h.into_service().map_err(ServiceError::from).boxed(), + ServiceConfig::Tcp(t) => t.into_service().map_err(ServiceError::from).boxed(), + ServiceConfig::Command(c) => c.into_service().map_err(ServiceError::from).boxed(), } } } diff --git a/src/service/command.rs b/src/service/command.rs index 41a79b3..3535ee2 100644 --- a/src/service/command.rs +++ b/src/service/command.rs @@ -1,79 +1,113 @@ use std::{process::Stdio, time::Duration}; +use async_stream::stream; +use futures::{Stream, StreamExt}; use serde::Deserialize; -use tokio::{ - io::{AsyncBufReadExt, BufReader}, - sync::watch::Sender, -}; +use tokio::io::{AsyncBufReadExt, BufReader}; -use crate::{Error, Status}; +use super::IntoService; -use super::ServiceSpawner; +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Exited with status code: {code}\n{stderr}")] + ExitCode { code: i32, stderr: String }, + #[error("Process terminated by signal")] + Signal, + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::error::Error), + #[error("{0}")] + Stderr(String), + #[error("{0}")] + Output(String), + #[error("Exited with status code: {0}")] + PersistExitCode(i32), + #[error("Failed to get stderr of child process")] + NoStderr, + #[error("Failed to get stdout of child process")] + NoStdout, +} #[derive(Debug, Clone, Deserialize)] pub struct Command { pub command: String, pub args: Vec, - pub interval: Option, + #[serde(default)] + pub persist: bool, + #[serde(default = "super::default_interval")] + pub interval: Duration, } impl Command { - async fn spawn_interval( + #[tracing::instrument] + fn persist( + mut interval: tokio::time::Interval, mut command: tokio::process::Command, - period: Duration, - tx: Sender, - ) -> Result<(), Error> { - let mut interval = tokio::time::interval(period); - loop { - interval.tick().await; + ) -> impl Stream> { + stream! { + loop { + interval.tick().await; + + let mut child = command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; - let status = command.output().await.map_or_else(Into::into, |o| { - if o.status.success() { - Status::Ok - } else { - let stdout = String::from_utf8_lossy(&o.stdout).trim().to_string(); - Status::Error(Some(format!("Service state: {}", stdout))) + let mut stdout_reader = + BufReader::new(child.stdout.take().ok_or(Error::NoStdout)?).lines(); + + while let Some(line) = stdout_reader.next_line().await? { + if "Ok" == line { + yield Ok(()); + } else { + yield Err(Error::Output(line)) + } } - }); - tx.send_if_modified(|s| s.update(status)); + match child.wait().await?.code() { + Some(0) => yield Ok(()), + Some(code) => yield Err(Error::PersistExitCode(code)), + None => yield Err(Error::Signal), + }; + } } } - async fn spawn_persist( + #[tracing::instrument] + fn interval( + mut interval: tokio::time::Interval, mut command: tokio::process::Command, - tx: Sender, - ) -> Result<(), Error> { - let mut child = command.stdout(Stdio::piped()).spawn()?; - let mut stdout = BufReader::new(child.stdout.take().unwrap()).lines(); - - while let Some(line) = stdout.next_line().await? { - let status: Status = serde_json::from_str(&line) - .unwrap_or_else(|err| Status::Error(Some(format!("Serialization error: {err}")))); - tx.send_if_modified(|s| s.update(status)); + ) -> impl Stream> { + stream! { + loop { + interval.tick().await; + let output = command.output().await?; + match output.status.code() { + Some(0) => yield Ok(()), + Some(code) => { + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + yield Err(Error::ExitCode { code, stderr }) + } + None => yield Err(Error::Signal), + } + } } - - let exit_status = child.wait().await?; - let status = match exit_status.code() { - Some(0) => Status::Ok, - Some(code) => Status::Error(Some(format!("Exited with status code: {code}"))), - None => Status::Error(Some("Process terminated by signal".to_string())), - }; - - tx.send_if_modified(|s| s.update(status)); - Ok(()) } } -impl ServiceSpawner for Command { - async fn spawn(self, tx: Sender) -> Result<(), Error> { +impl IntoService for Command { + type Error = Error; + + fn into_service(self) -> impl Stream> { + let interval = tokio::time::interval(self.interval); let mut command = tokio::process::Command::new(self.command); command.args(self.args); - if let Some(period) = self.interval { - Self::spawn_interval(command, period, tx).await + if self.persist { + Self::persist(interval, command).boxed() } else { - Self::spawn_persist(command, tx).await + Self::interval(interval, command).boxed() } } } diff --git a/src/service/http.rs b/src/service/http.rs index 8950096..c4fcee7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,13 +1,20 @@ use std::{fmt::Display, time::Duration}; +use async_stream::try_stream; use axum::http::status::StatusCode; +use futures::Stream; use serde::Deserialize; -use tokio::sync::watch::Sender; use url::Url; -use crate::{Error, Status}; +use super::IntoService; -use super::ServiceSpawner; +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Request error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("Bad status code: {0}")] + StatusCode(u16), +} #[derive(Debug, Clone, Deserialize)] pub struct Http { @@ -18,26 +25,31 @@ pub struct Http { pub status_code: StatusCode, #[serde(skip, default)] pub client: Option, + #[serde(default = "super::default_interval")] + pub interval: Duration, } -impl ServiceSpawner for Http { - async fn spawn(self, tx: Sender) -> Result<(), Error> { - let client = self.client.unwrap_or_default(); - let request = client.request(self.method.into(), self.url).build()?; - - let mut interval = tokio::time::interval(Duration::from_secs(5)); - loop { - interval.tick().await; - let req = request - .try_clone() - .expect("Clone with no body should never fail"); - let resp = client.execute(req).await; - let status = resp.map_or_else(Into::into, |r| match r.status().as_u16() { - c if c == self.status_code => Status::Ok, - c => Status::Error(Some(format!("Status code: {c}"))), - }); - - tx.send_if_modified(|s| s.update(status)); +impl IntoService for Http { + type Error = Error; + + fn into_service(self) -> impl Stream> { + let mut interval = tokio::time::interval(self.interval); + + try_stream! { + let client = self.client.unwrap_or_default(); + let req = client.request(self.method.into(), self.url).build()?; + loop { + interval.tick().await; + let req = req + .try_clone() + .expect("Clone with no body should never fail"); + let status_code = client.execute(req).await?.status().as_u16(); + if status_code == self.status_code { + yield (); + } else { + Err(Error::StatusCode(status_code))? + } + } } } } diff --git a/src/service/systemd.rs b/src/service/systemd.rs deleted file mode 100644 index ee220b8..0000000 --- a/src/service/systemd.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::{process::Command, time::Duration}; - -use serde::Deserialize; -use tokio::sync::watch::Sender; - -use crate::{Error, Status}; - -use super::ServiceSpawner; - -#[derive(Debug, Clone, Deserialize)] -pub struct Systemd { - pub service: String, -} - -impl ServiceSpawner for Systemd { - async fn spawn(self, tx: Sender) -> Result<(), Error> { - let mut command = Command::new("systemctl"); - command.arg("is-active").arg(&self.service); - - let mut interval = tokio::time::interval(Duration::from_secs(5)); - loop { - interval.tick().await; - - let status = command.output().map_or_else(Into::into, |o| { - if o.status.success() { - Status::Ok - } else { - let stdout = String::from_utf8_lossy(&o.stdout).trim().to_string(); - Status::Error(Some(format!("Service state: {}", stdout))) - } - }); - - tx.send_if_modified(|s| s.update(status)); - } - } -} diff --git a/src/service/tcp.rs b/src/service/tcp.rs index 7b79afd..6556af0 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -1,15 +1,19 @@ use std::{fmt::Display, net::SocketAddr, time::Duration}; +use async_stream::try_stream; +use futures::Stream; use serde::Deserialize; -use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender}; +use tokio::{io::Interest, net::TcpSocket}; -use crate::{Error, Status}; +use super::IntoService; -use super::ServiceSpawner; +pub(crate) type Error = std::io::Error; #[derive(Debug, Clone, Deserialize)] pub struct Tcp { pub address: SocketAddr, + #[serde(default = "super::default_interval")] + pub interval: Duration, } impl Display for Tcp { @@ -18,59 +22,24 @@ impl Display for Tcp { } } -impl ServiceSpawner for Tcp { - async fn spawn(self, tx: Sender) -> Result<(), Error> { - let mut interval = tokio::time::interval(Duration::from_secs(5)); +impl IntoService for Tcp { + type Error = Error; - loop { - interval.tick().await; + fn into_service(self) -> impl Stream> { + let mut interval = tokio::time::interval(self.interval); - let sock = TcpSocket::new_v4()?; - sock.set_keepalive(true)?; + try_stream! { + loop { + interval.tick().await; - match sock.connect(self.address).await { - Ok(conn) => { - // TODO: figure out how to wait for connection to close - conn.ready(Interest::READABLE).await?; - tx.send_if_modified(|s| s.update(Status::Ok)); - } - Err(err) => { - tx.send_if_modified(|s| s.update(err.into())); - } - }; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[tracing_test::traced_test] - #[ignore] - async fn test_tcp_watch() { - let (tx, mut rx) = tokio::sync::watch::channel(Status::Error(None)); + let sock = TcpSocket::new_v4()?; + sock.set_keepalive(true)?; - let tests = tokio::spawn(async move { - assert!(matches!(*rx.borrow_and_update(), Status::Error(None))); - - rx.changed().await.unwrap(); - assert!(matches!(*rx.borrow_and_update(), Status::Ok)); - - rx.changed().await.unwrap(); - assert_eq!( - *rx.borrow_and_update(), - Status::Error(Some(String::from("Disconnected"))) - ); - }); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - tokio::spawn(async move { Tcp { address }.spawn(tx).await }); - listener.accept().await.unwrap(); - drop(listener); - - tests.await.unwrap() + let conn = sock.connect(self.address).await?; + // TODO: figure out how to wait for connection to close + conn.ready(Interest::READABLE).await?; + yield (); + } + } } } diff --git a/src/sse.rs b/src/sse.rs deleted file mode 100644 index 88befd1..0000000 --- a/src/sse.rs +++ /dev/null @@ -1,35 +0,0 @@ -use axum::{ - extract::{Path, State}, - response::{ - sse::{Event, KeepAlive}, - Sse, - }, - routing::get, - Router, -}; -use futures::Stream; -use tokio_stream::StreamExt; - -use crate::{service::ServiceHandles, Error}; - -pub fn router() -> Router { - axum::Router::new().route("/:name", get(service_events)) -} - -pub async fn service_events( - Path(name): Path, - State(services): State, -) -> Result>>, Error> { - let stream = services - .get(&name) - .ok_or_else(|| Error::ServiceNotFound(name))? - .into_stream() - .map(|s| match s { - crate::Status::Ok => Event::default().event("ok"), - crate::Status::Error(None) => Event::default().event("error"), - crate::Status::Error(Some(msg)) => Event::default().event("error").data(msg), - }) - .map(Ok); - - Ok(Sse::new(stream).keep_alive(KeepAlive::default())) -} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..bbe235b --- /dev/null +++ b/src/state.rs @@ -0,0 +1,73 @@ +use std::collections::HashMap; + +use futures::{Stream, StreamExt}; +use tokio::sync::watch::Receiver; +use tokio_stream::wrappers::WatchStream; + +use crate::{ + service::{IntoService, ServiceConfig}, + Status, +}; + +#[derive(Clone)] +pub struct AppState { + rx_map: HashMap>, +} + +impl AppState { + pub fn spawn_services(configs: HashMap) -> AppState { + let mut rx_map = HashMap::new(); + let mut tx_map = HashMap::new(); + for name in configs.keys() { + let (tx, rx) = tokio::sync::watch::channel(Default::default()); + rx_map.insert(name.clone(), rx); + tx_map.insert(name.clone(), tx); + } + + for (name, config) in configs { + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + rx_map.insert(name.clone(), rx); + tokio::spawn(async move { + let mut stream = config.into_service(); + while let Some(res) = stream.next().await { + let status = res.into(); + tx.send_if_modified(|s| { + if *s != status { + tracing::debug!(name, ?status, "Updated service status"); + *s = status; + true + } else { + false + } + }); + } + }); + } + + AppState { rx_map } + } + + pub fn status(&self, k: &str) -> Option { + self.rx_map.get(k).map(|rx| rx.borrow().clone()) + } + + pub fn statuses(&self) -> Vec<(String, Status)> { + self.rx_map + .iter() + .map(|(k, v)| (k.clone(), v.borrow().clone())) + .collect() + } + + pub fn stream(&self, k: &str) -> Option> { + self.rx_map.get(k).cloned().map(WatchStream::new) + } + + pub fn streams(&self) -> impl Stream { + let iter = + self.rx_map.clone().into_iter().map(|(name, rx)| { + WatchStream::new(rx).map(move |status| (name.to_owned(), status)) + }); + + futures::stream::select_all(iter) + } +} diff --git a/src/status.rs b/src/status.rs new file mode 100644 index 0000000..6e674c8 --- /dev/null +++ b/src/status.rs @@ -0,0 +1,50 @@ +use axum::response::sse::Event; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "lowercase", tag = "status", content = "output")] +pub enum Status { + Ok, + Error(Option), +} + +impl Default for Status { + fn default() -> Self { + Status::Error(None) + } +} + +impl Status { + pub fn update(&mut self, status: Status) -> bool { + let modif = *self != status; + if modif { + *self = status; + } + modif + } +} + +impl From> for Status { + fn from(value: Result) -> Self { + match value { + Ok(_) => Status::Ok, + Err(err) => Status::Error(Some(err.to_string())), + } + } +} + +impl axum::response::IntoResponse for Status { + fn into_response(self) -> axum::response::Response { + axum::Json(self).into_response() + } +} + +impl From for Event { + fn from(value: Status) -> Self { + match value { + Status::Ok => Event::default().event("ok"), + Status::Error(None) => Event::default().event("error"), + Status::Error(Some(msg)) => Event::default().event("error").data(msg), + } + } +} -- cgit v1.2.3-70-g09d2