From cd774827dd14f68d8405c45d2d9da30b3fab050e Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Sat, 28 Sep 2024 00:54:46 -0500 Subject: feat: refactor into pub-sub and impl SSE --- Cargo.lock | 60 ++++++++++++++++ Cargo.toml | 8 ++- assets/index.js | 92 ++++++++++--------------- src/api.rs | 14 ++-- src/api/services.rs | 33 +++++---- src/lib.rs | 36 ++++++---- src/main.rs | 20 ++++-- src/service.rs | 106 +++++++++------------------- src/service/http.rs | 183 ++++++++++++++++++++++++++++++++++++++++++------- src/service/systemd.rs | 31 ++++++++- src/service/tcp.rs | 79 +++++++++++++++++++-- src/sse.rs | 30 ++++++++ 12 files changed, 494 insertions(+), 198 deletions(-) create mode 100644 src/sse.rs diff --git a/Cargo.lock b/Cargo.lock index db02f18..b072987 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,28 @@ dependencies = [ "memchr", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.83" @@ -1101,6 +1123,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" name = "statsrv" version = "0.1.0" dependencies = [ + "async-stream", "axum", "axum-extra", "futures", @@ -1112,10 +1135,13 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-stream", "toml", "tower-http", "tracing", "tracing-subscriber", + "tracing-test", + "url", ] [[package]] @@ -1277,6 +1303,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-util" version = "0.7.12" @@ -1439,6 +1477,27 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tracing-test" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -1490,6 +1549,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 69f7eeb..1cb7655 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +async-stream = "0.3.5" axum = "0.7.6" axum-extra = "0.9.4" futures = "0.3.30" @@ -14,8 +15,13 @@ reqwest = { version = "0.12.7", features = ["blocking"] } serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0.128" thiserror = "1.0.63" -tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "net"] } +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "net", "time"] } +tokio-stream = { version = "0.1.16", features = ["sync"] } toml = "0.8.19" tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +url = { version = "2.5.2", features = ["serde"] } + +[dev-dependencies] +tracing-test = "0.2.5" diff --git a/assets/index.js b/assets/index.js index 246364f..e65369c 100644 --- a/assets/index.js +++ b/assets/index.js @@ -1,16 +1,7 @@ -/** - * @typedef {Object} Check - * @property {String} status - 'pass'|'fail'|'warn' - * @property {String} output - Details. Not present if 'pass' - */ +const serviceMap = new Map(); -/** - * @typedef {Check} HealthCheck - * @property {Map} checks - */ - -async function getHealthCheck() { - const url = "api/healthcheck"; +async function getServices() { + const url = "api/services"; try { const response = await fetch(url); if (!response.ok) { @@ -24,64 +15,55 @@ async function getHealthCheck() { } } -function updateStatus(check) { +function updateStatus() { const statusElm = document.getElementById("status"); const issuesElm = document.getElementById("issues"); - switch (check.status) { + const issues = [...serviceMap.values()].filter((s) => !s).length; + issuesElm.textContent = `${issues} issue(s) detected`; + if (issues) { + statusElm.setAttribute("class", "error"); + } else { + statusElm.setAttribute("class", "ok"); + } +} + +function updateService(name, node, status) { + switch (status.status) { case "pass": - issuesElm.textContent = "No issues detected"; - statusElm.setAttribute("class", "ok"); + node.textContent = "Operational"; + node.setAttribute("class", "ok"); break; case "fail": - issuesElm.textContent = check.output; - statusElm.setAttribute("class", "error"); + node.textContent = "Down"; + node.title = status.output; + node.setAttribute("class", "error"); break; case "warn": - issuesElm.textContent = check.output; - statusElm.setAttribute("class", "warning"); + node.textContent = "Warning"; + node.title = status.output; + node.setAttribute("class", "warning"); break; - default: - issuesElm.textContent = "Unknown"; - statusElm.setAttribute("class", "warning"); + case "unknown": + node.textContent = "Unknown"; + node.setAttribute("class", "warning"); } -} -getHealthCheck().then((healthCheck) => { - const table = document.getElementById("services"); - const evtSource = new EventSource("sse"); - updateStatus(healthCheck); + serviceMap.set(name, status.status === "pass"); + updateStatus(); +} - for (const [service, check] of Object.entries(healthCheck.checks)) { +getServices().then((services) => { + for (const [service] of Object.entries(services)) { + const table = document.getElementById("services"); const row = table.insertRow(); - const nameNode = row.insertCell(); nameNode.textContent = service; + const node = row.insertCell(); - const stateNode = row.insertCell(); - switch (check.status) { - case "pass": - stateNode.textContent = "Operational"; - stateNode.setAttribute("class", "ok"); - break; - case "fail": - stateNode.textContent = "Down"; - stateNode.title = check.output; - stateNode.setAttribute("class", "error"); - break; - case "warn": - stateNode.textContent = "Warning"; - stateNode.title = check.output; - stateNode.setAttribute("class", "warning"); - break; - default: - stateNode.textContent = "Unknown"; - statusElm.setAttribute("class", "warning"); - } - - evtSource.addEventListener(service, (event) => { + const evtSource = new EventSource(`sse/${service}`); + evtSource.onmessage = (event) => { const status = JSON.parse(event.data); - stateNode.textContent = status.state; - stateNode.title = status.output; - }); + updateService(service, node, status); + }; } }); diff --git a/src/api.rs b/src/api.rs index 8dfd2ca..1489c21 100644 --- a/src/api.rs +++ b/src/api.rs @@ -3,11 +3,11 @@ use std::collections::HashMap; use axum::{extract::State, response::IntoResponse, Json}; use serde::{Deserialize, Serialize}; -use crate::{service::Services, Status}; +use crate::{service::ServiceHandles, Status}; pub mod services; -pub fn router() -> axum::Router { +pub fn router() -> axum::Router { use axum::routing::get; axum::Router::new() @@ -37,11 +37,11 @@ impl IntoResponse for Health { } } -pub async fn healthcheck(State(services): State) -> Health { - let checks = match services.check().await { - Ok(c) => c, - Err(err) => return err.into(), - }; +pub async fn healthcheck(State(services): State) -> Health { + let checks = services + .iter() + .map(|(name, srv)| (name.clone(), srv.status())) + .collect::>(); let status = match checks .values() diff --git a/src/api/services.rs b/src/api/services.rs index 63018f1..132ecb1 100644 --- a/src/api/services.rs +++ b/src/api/services.rs @@ -7,7 +7,7 @@ use axum::{ use axum_extra::routing::Resource; use serde::{Deserialize, Serialize}; -use crate::{service::Services, Error, Status}; +use crate::{service::ServiceHandles, Error, Status}; #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ServiceQuery { @@ -15,26 +15,35 @@ pub struct ServiceQuery { pub state: Option, } -pub fn router() -> Router { +pub fn router() -> Router { Resource::named("services").index(index).show(show).into() } pub async fn index( Query(query): Query, - State(services): State, -) -> Result>, Error> { - services - .check_filtered(|name| (!query.name.as_ref().is_some_and(|s| s != name))) - .await - .map(Json) + State(services): State, +) -> Json> { + let map = match query.name { + Some(n) => services + .iter() + .filter(|(name, _)| n == **name) + .map(|(name, srv)| (name.clone(), srv.status())) + .collect(), + None => services + .iter() + .map(|(name, srv)| (name.clone(), srv.status())) + .collect(), + }; + + Json(map) } pub async fn show( Path(name): Path, - State(services): State, + State(services): State, ) -> Result { services - .check_one(&name) - .await - .ok_or_else(|| Error::ServiceNotFound(name))? + .get(&name) + .map(|s| s.status()) + .ok_or_else(|| Error::ServiceNotFound(name)) } diff --git a/src/lib.rs b/src/lib.rs index 1ccecf7..d24f635 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,37 @@ -use std::path::PathBuf; - use serde::{Deserialize, Serialize}; -use service::Services; -use tower_http::services::ServeDir; +use service::ServiceHandles; pub use crate::error::{Error, Result}; pub mod api; pub mod error; pub mod service; +pub mod sse; + +pub fn router() -> axum::Router { + axum::Router::new() + .nest("/api", api::router()) + .nest("/sse", sse::router()) +} -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[serde(rename_all = "lowercase", tag = "status", content = "output")] pub enum Status { - #[default] Pass, - Fail(Option), Warn(Option), + Fail(Option), + #[default] + Unknown, +} + +impl Status { + pub fn update(&mut self, status: Status) -> bool { + let modif = *self != status; + if modif { + *self = status; + } + modif + } } impl From for Status { @@ -30,10 +45,3 @@ impl axum::response::IntoResponse for Status { axum::Json(self).into_response() } } - -pub fn router(root: PathBuf) -> axum::Router { - axum::Router::new() - .nest_service("/", ServeDir::new(root)) - .nest("/api", api::router()) - .layer(tower_http::trace::TraceLayer::new_for_http()) -} diff --git a/src/main.rs b/src/main.rs index 97ed111..99af338 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,9 @@ -use std::{fs::File, path::PathBuf}; +use std::{collections::HashMap, fs::File, path::PathBuf, sync::Arc}; +use tower_http::services::ServeDir; use tracing::level_filters::LevelFilter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -use statsrv::service::Services; +use statsrv::service::Service; #[cfg(not(debug_assertions))] const DEFAULT_CONFIG: &str = "/etc/statsrv.toml"; @@ -29,7 +30,16 @@ async fn main() -> Result<(), Box> { } }; - let router = statsrv::router(config.root).with_state(config.services); + let state = config + .services + .into_iter() + .map(|(name, service)| (name, service.into())) + .collect(); + + let router = statsrv::router() + .with_state(Arc::new(state)) + .nest_service("/", ServeDir::new(config.root)) + .layer(tower_http::trace::TraceLayer::new_for_http()); let listener = tokio::net::TcpListener::bind(config.address).await.unwrap(); tracing::info!("listening on {}", listener.local_addr().unwrap()); @@ -42,7 +52,7 @@ async fn main() -> Result<(), Box> { pub struct Config { pub root: PathBuf, pub address: String, - pub services: Services, + pub services: HashMap, } impl Config { @@ -63,7 +73,7 @@ impl Default for Config { Self { root: PathBuf::from("./"), address: String::from("127.0.0.1:8080"), - services: Services::new(Default::default()), + services: Default::default(), } } } diff --git a/src/service.rs b/src/service.rs index bae6867..c45fcb1 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,10 +1,15 @@ -use std::{collections::HashMap, fmt::Display}; +use std::{collections::HashMap, sync::Arc}; -use futures::{stream::FuturesOrdered, TryStreamExt}; +use futures::Stream; use http::Http; use serde::Deserialize; use systemd::Systemd; use tcp::Tcp; +use tokio::{ + sync::watch::{Receiver, Sender}, + task::JoinHandle, +}; +use tokio_stream::wrappers::WatchStream; use crate::{Error, Status}; @@ -12,67 +17,34 @@ pub mod http; pub mod systemd; pub mod tcp; -#[derive(Debug, Clone, Deserialize)] -pub struct Services { - #[serde(flatten)] - inner: HashMap, - #[serde(skip, default = "Services::default_client")] - client: reqwest::Client, -} +pub type ServiceHandles = Arc>; -impl Services { - pub fn new(services: HashMap) -> Self { - let client = reqwest::Client::new(); - Self { - inner: services, - client, - } - } - - fn default_client() -> reqwest::Client { - reqwest::Client::new() - } +pub trait ServiceSpawner { + fn spawn( + self, + tx: Sender, + ) -> impl std::future::Future> + std::marker::Send + 'static; +} - pub async fn check(&self) -> Result, Error> { - let checks = self - .inner - .values() - .map(|service| service.check(self.client.clone())) - .collect::>() - .try_collect::>() - .await?; +#[derive(Debug)] +pub struct ServiceHandle { + pub handle: JoinHandle>, + pub rx: Receiver, +} - Ok(self - .inner - .keys() - .cloned() - .zip(checks) - .collect::>()) +impl ServiceHandle { + pub fn new(service: impl ServiceSpawner) -> Self { + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + let handle = tokio::spawn(service.spawn(tx)); + Self { handle, rx } } - pub async fn check_one(&self, name: &str) -> Option> { - Some(self.inner.get(name)?.check(self.client.clone()).await) + pub fn status(&self) -> Status { + self.rx.borrow().clone() } - pub async fn check_filtered

(&self, mut predicate: P) -> Result, Error> - where - P: FnMut(&String) -> bool, - { - let checks = self - .inner - .iter() - .filter_map(|(s, service)| predicate(s).then_some(service)) - .map(|service| service.check(self.client.clone())) - .collect::>() - .try_collect::>() - .await?; - - Ok(self - .inner - .keys() - .cloned() - .zip(checks) - .collect::>()) + pub fn into_stream(&self) -> impl Stream { + WatchStream::new(self.rx.clone()) } } @@ -84,22 +56,12 @@ pub enum Service { Systemd(Systemd), } -impl Service { - pub async fn check(&self, client: reqwest::Client) -> Result { - match self { - Service::Http(http) => http.check(client).await, - Service::Tcp(tcp) => tcp.check().await, - Service::Systemd(systemd) => systemd.check().await, - } - } -} - -impl Display for Service { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Service::Http(http) => http.fmt(f), - Service::Tcp(tcp) => tcp.fmt(f), - Service::Systemd(systemd) => systemd.fmt(f), +impl From for ServiceHandle { + fn from(value: Service) -> Self { + match value { + Service::Http(s) => ServiceHandle::new(s), + Service::Tcp(s) => ServiceHandle::new(s), + Service::Systemd(s) => ServiceHandle::new(s), } } } diff --git a/src/service/http.rs b/src/service/http.rs index fb3ff13..6d21cb7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,46 +1,181 @@ -use std::fmt::Display; +use std::{fmt::Display, time::Duration}; +use axum::http::status::StatusCode; +use futures::Stream; use serde::Deserialize; +use tokio::sync::watch::Sender; +use tokio_stream::wrappers::WatchStream; +use url::Url; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Http { - pub url: String, - #[serde(default = "Http::default_method")] - pub method: String, - #[serde(default = "Http::default_code")] - pub status_code: u16, + pub url: Url, + #[serde(default)] + pub method: Method, + #[serde(default, with = "status_code")] + pub status_code: StatusCode, + #[serde(skip, default)] + pub client: Option, +} + +impl Display for Http { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.method, self.url) + } +} + +impl ServiceSpawner for Http { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let client = self.client.unwrap_or_default(); + let request = client.request(self.method.into(), self.url).build()?; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } } impl Http { - fn default_method() -> String { - "GET".to_string() + pub fn into_stream(self, client: reqwest::Client) -> impl Stream { + let request = client + .request(self.method.into(), self.url) + .build() + .expect("Url parsing should not fail"); + + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + }); + + WatchStream::new(rx) } +} - fn default_code() -> u16 { - 200 +#[derive(Debug, Clone, Copy, Default, Deserialize)] +pub enum Method { + #[serde(alias = "get", alias = "GET")] + #[default] + Get, + #[serde(alias = "post", alias = "POST")] + Post, +} + +impl From for reqwest::Method { + fn from(value: Method) -> Self { + match value { + Method::Get => reqwest::Method::GET, + Method::Post => reqwest::Method::POST, + } } } -impl Display for Http { +impl Display for Method { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {}", self.method, self.url) + match self { + Method::Get => write!(f, "GET"), + Method::Post => write!(f, "POST"), + } } } -impl Http { - pub async fn check(&self, client: reqwest::Client) -> Result { - let status_code = client - .request(self.method.parse().map_err(|_| Error::Method)?, &self.url) - .send() - .await? - .status() - .as_u16(); - - match status_code == self.status_code { - true => Ok(Status::Pass), - false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))), +pub mod status_code { + use axum::http::StatusCode; + use serde::{ + de::{self, Unexpected, Visitor}, + Deserializer, Serializer, + }; + use std::fmt; + + /// Implementation detail. Use derive annotations instead. + #[inline] + pub fn serialize(status: &StatusCode, ser: S) -> Result { + ser.serialize_u16(status.as_u16()) + } + + pub(crate) struct StatusVisitor; + + impl StatusVisitor { + #[inline(never)] + fn make(&self, val: u64) -> Result { + if (100..1000).contains(&val) { + if let Ok(s) = StatusCode::from_u16(val as u16) { + return Ok(s); + } + } + Err(de::Error::invalid_value(Unexpected::Unsigned(val), self)) + } + } + + impl<'de> Visitor<'de> for StatusVisitor { + type Value = StatusCode; + + #[inline] + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("status code") + } + + #[inline] + fn visit_some>( + self, + deserializer: D, + ) -> Result { + deserializer.deserialize_u16(self) + } + + #[inline] + fn visit_i64(self, val: i64) -> Result { + self.make(val as _) + } + + #[inline] + fn visit_u64(self, val: u64) -> Result { + self.make(val) } } + + /// Implementation detail. + #[inline] + pub fn deserialize<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + de.deserialize_u16(StatusVisitor) + } } diff --git a/src/service/systemd.rs b/src/service/systemd.rs index 45f3bf9..e3b4d1b 100644 --- a/src/service/systemd.rs +++ b/src/service/systemd.rs @@ -1,9 +1,12 @@ -use std::{fmt::Display, process::Command}; +use std::{fmt::Display, process::Command, time::Duration}; use serde::Deserialize; +use tokio::sync::watch::Sender; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Systemd { pub service: String, @@ -15,6 +18,32 @@ impl Display for Systemd { } } +impl ServiceSpawner for Systemd { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let mut command = Command::new("systemctl"); + command.arg("is-active").arg(&self.service); + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + + let status = match command.output() { + Ok(output) if output.status.success() => Status::Pass, + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + Status::Fail(Some(format!("Service state: {}", stdout))) + } + Err(err) => { + tracing::error!("Failed to spawn process: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } +} + impl Systemd { pub async fn check(&self) -> Result { let output = Command::new("systemctl") diff --git a/src/service/tcp.rs b/src/service/tcp.rs index 87e696a..5ec5f36 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -1,12 +1,17 @@ -use std::fmt::Display; +use std::{fmt::Display, net::SocketAddr, time::Duration}; +use futures::Stream; use serde::Deserialize; +use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender}; +use tokio_stream::wrappers::WatchStream; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Tcp { - pub address: String, + pub address: SocketAddr, } impl Display for Tcp { @@ -15,11 +20,71 @@ impl Display for Tcp { } } +impl ServiceSpawner for Tcp { + #[tracing::instrument(skip(tx))] + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + + loop { + interval.tick().await; + + let sock = TcpSocket::new_v4()?; + sock.set_keepalive(true)?; + + match sock.connect(self.address).await { + Ok(conn) => { + tracing::info!("Connected"); + tx.send_if_modified(|s| s.update(Status::Pass)); + conn.ready(Interest::ERROR).await?; + tx.send_replace(Status::Fail(Some("Disconnected".into()))); + tracing::info!("Disconnected"); + } + Err(err) => { + tracing::error!("Failed to connect"); + tx.send_if_modified(|s| s.update(err.into())); + } + }; + } + } +} + impl Tcp { - pub async fn check(&self) -> Result { - Ok(std::net::TcpStream::connect(&self.address) - .err() - .map(Into::into) - .unwrap_or_default()) + pub fn into_stream(self) -> impl Stream { + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + tokio::spawn(self.spawn(tx)); + WatchStream::new(rx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[tracing_test::traced_test] + #[ignore] + async fn test_tcp_watch() { + let (tx, mut rx) = tokio::sync::watch::channel(Status::default()); + + let tests = tokio::spawn(async move { + assert!(matches!(*rx.borrow_and_update(), Status::Unknown)); + + rx.changed().await.unwrap(); + assert!(matches!(*rx.borrow_and_update(), Status::Pass)); + + rx.changed().await.unwrap(); + assert_eq!( + *rx.borrow_and_update(), + Status::Fail(Some(String::from("Disconnected"))) + ); + }); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + tokio::spawn(async move { Tcp { address }.spawn(tx).await }); + listener.accept().await.unwrap(); + drop(listener); + + tests.await.unwrap() } } diff --git a/src/sse.rs b/src/sse.rs new file mode 100644 index 0000000..b4a8840 --- /dev/null +++ b/src/sse.rs @@ -0,0 +1,30 @@ +use axum::{ + extract::{Path, State}, + response::{ + sse::{Event, KeepAlive}, + Sse, + }, + routing::get, + Router, +}; +use futures::Stream; +use tokio_stream::StreamExt; + +use crate::{service::ServiceHandles, Error}; + +pub fn router() -> Router { + axum::Router::new().route("/:name", get(sse_handler)) +} + +pub async fn sse_handler( + Path(name): Path, + State(services): State, +) -> Result>>, Error> { + let stream = services + .get(&name) + .ok_or_else(|| Error::ServiceNotFound(name))? + .into_stream() + .map(|s| Event::default().json_data(s)); + + Ok(Sse::new(stream).keep_alive(KeepAlive::default())) +} -- cgit v1.2.3-70-g09d2