From cd774827dd14f68d8405c45d2d9da30b3fab050e Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Sat, 28 Sep 2024 00:54:46 -0500 Subject: feat: refactor into pub-sub and impl SSE --- src/service/http.rs | 183 ++++++++++++++++++++++++++++++++++++++++++------- src/service/systemd.rs | 31 ++++++++- src/service/tcp.rs | 79 +++++++++++++++++++-- 3 files changed, 261 insertions(+), 32 deletions(-) (limited to 'src/service') diff --git a/src/service/http.rs b/src/service/http.rs index fb3ff13..6d21cb7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,46 +1,181 @@ -use std::fmt::Display; +use std::{fmt::Display, time::Duration}; +use axum::http::status::StatusCode; +use futures::Stream; use serde::Deserialize; +use tokio::sync::watch::Sender; +use tokio_stream::wrappers::WatchStream; +use url::Url; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Http { - pub url: String, - #[serde(default = "Http::default_method")] - pub method: String, - #[serde(default = "Http::default_code")] - pub status_code: u16, + pub url: Url, + #[serde(default)] + pub method: Method, + #[serde(default, with = "status_code")] + pub status_code: StatusCode, + #[serde(skip, default)] + pub client: Option, +} + +impl Display for Http { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.method, self.url) + } +} + +impl ServiceSpawner for Http { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let client = self.client.unwrap_or_default(); + let request = client.request(self.method.into(), self.url).build()?; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } } impl Http { - fn default_method() -> String { - "GET".to_string() + pub fn into_stream(self, client: reqwest::Client) -> impl Stream { + let request = client + .request(self.method.into(), self.url) + .build() + .expect("Url parsing should not fail"); + + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + }); + + WatchStream::new(rx) } +} - fn default_code() -> u16 { - 200 +#[derive(Debug, Clone, Copy, Default, Deserialize)] +pub enum Method { + #[serde(alias = "get", alias = "GET")] + #[default] + Get, + #[serde(alias = "post", alias = "POST")] + Post, +} + +impl From for reqwest::Method { + fn from(value: Method) -> Self { + match value { + Method::Get => reqwest::Method::GET, + Method::Post => reqwest::Method::POST, + } } } -impl Display for Http { +impl Display for Method { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {}", self.method, self.url) + match self { + Method::Get => write!(f, "GET"), + Method::Post => write!(f, "POST"), + } } } -impl Http { - pub async fn check(&self, client: reqwest::Client) -> Result { - let status_code = client - .request(self.method.parse().map_err(|_| Error::Method)?, &self.url) - .send() - .await? - .status() - .as_u16(); - - match status_code == self.status_code { - true => Ok(Status::Pass), - false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))), +pub mod status_code { + use axum::http::StatusCode; + use serde::{ + de::{self, Unexpected, Visitor}, + Deserializer, Serializer, + }; + use std::fmt; + + /// Implementation detail. Use derive annotations instead. + #[inline] + pub fn serialize(status: &StatusCode, ser: S) -> Result { + ser.serialize_u16(status.as_u16()) + } + + pub(crate) struct StatusVisitor; + + impl StatusVisitor { + #[inline(never)] + fn make(&self, val: u64) -> Result { + if (100..1000).contains(&val) { + if let Ok(s) = StatusCode::from_u16(val as u16) { + return Ok(s); + } + } + Err(de::Error::invalid_value(Unexpected::Unsigned(val), self)) + } + } + + impl<'de> Visitor<'de> for StatusVisitor { + type Value = StatusCode; + + #[inline] + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("status code") + } + + #[inline] + fn visit_some>( + self, + deserializer: D, + ) -> Result { + deserializer.deserialize_u16(self) + } + + #[inline] + fn visit_i64(self, val: i64) -> Result { + self.make(val as _) + } + + #[inline] + fn visit_u64(self, val: u64) -> Result { + self.make(val) } } + + /// Implementation detail. + #[inline] + pub fn deserialize<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + de.deserialize_u16(StatusVisitor) + } } diff --git a/src/service/systemd.rs b/src/service/systemd.rs index 45f3bf9..e3b4d1b 100644 --- a/src/service/systemd.rs +++ b/src/service/systemd.rs @@ -1,9 +1,12 @@ -use std::{fmt::Display, process::Command}; +use std::{fmt::Display, process::Command, time::Duration}; use serde::Deserialize; +use tokio::sync::watch::Sender; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Systemd { pub service: String, @@ -15,6 +18,32 @@ impl Display for Systemd { } } +impl ServiceSpawner for Systemd { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let mut command = Command::new("systemctl"); + command.arg("is-active").arg(&self.service); + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + + let status = match command.output() { + Ok(output) if output.status.success() => Status::Pass, + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + Status::Fail(Some(format!("Service state: {}", stdout))) + } + Err(err) => { + tracing::error!("Failed to spawn process: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } +} + impl Systemd { pub async fn check(&self) -> Result { let output = Command::new("systemctl") diff --git a/src/service/tcp.rs b/src/service/tcp.rs index 87e696a..5ec5f36 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -1,12 +1,17 @@ -use std::fmt::Display; +use std::{fmt::Display, net::SocketAddr, time::Duration}; +use futures::Stream; use serde::Deserialize; +use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender}; +use tokio_stream::wrappers::WatchStream; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Tcp { - pub address: String, + pub address: SocketAddr, } impl Display for Tcp { @@ -15,11 +20,71 @@ impl Display for Tcp { } } +impl ServiceSpawner for Tcp { + #[tracing::instrument(skip(tx))] + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + + loop { + interval.tick().await; + + let sock = TcpSocket::new_v4()?; + sock.set_keepalive(true)?; + + match sock.connect(self.address).await { + Ok(conn) => { + tracing::info!("Connected"); + tx.send_if_modified(|s| s.update(Status::Pass)); + conn.ready(Interest::ERROR).await?; + tx.send_replace(Status::Fail(Some("Disconnected".into()))); + tracing::info!("Disconnected"); + } + Err(err) => { + tracing::error!("Failed to connect"); + tx.send_if_modified(|s| s.update(err.into())); + } + }; + } + } +} + impl Tcp { - pub async fn check(&self) -> Result { - Ok(std::net::TcpStream::connect(&self.address) - .err() - .map(Into::into) - .unwrap_or_default()) + pub fn into_stream(self) -> impl Stream { + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + tokio::spawn(self.spawn(tx)); + WatchStream::new(rx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[tracing_test::traced_test] + #[ignore] + async fn test_tcp_watch() { + let (tx, mut rx) = tokio::sync::watch::channel(Status::default()); + + let tests = tokio::spawn(async move { + assert!(matches!(*rx.borrow_and_update(), Status::Unknown)); + + rx.changed().await.unwrap(); + assert!(matches!(*rx.borrow_and_update(), Status::Pass)); + + rx.changed().await.unwrap(); + assert_eq!( + *rx.borrow_and_update(), + Status::Fail(Some(String::from("Disconnected"))) + ); + }); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + tokio::spawn(async move { Tcp { address }.spawn(tx).await }); + listener.accept().await.unwrap(); + drop(listener); + + tests.await.unwrap() } } -- cgit v1.2.3-70-g09d2