summaryrefslogtreecommitdiffstats
path: root/src/service
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-09-28 00:54:46 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-09-28 00:58:45 -0500
commitcd774827dd14f68d8405c45d2d9da30b3fab050e (patch)
treea24e1cabb99170caa25edff53fc978111a1c9dd4 /src/service
parent04c7f7609e5bc3fadf95c53b37a9e6e12c4e539c (diff)
feat: refactor into pub-sub and impl SSE
Diffstat (limited to 'src/service')
-rw-r--r--src/service/http.rs183
-rw-r--r--src/service/systemd.rs31
-rw-r--r--src/service/tcp.rs79
3 files changed, 261 insertions, 32 deletions
diff --git a/src/service/http.rs b/src/service/http.rs
index fb3ff13..6d21cb7 100644
--- a/src/service/http.rs
+++ b/src/service/http.rs
@@ -1,46 +1,181 @@
-use std::fmt::Display;
+use std::{fmt::Display, time::Duration};
+use axum::http::status::StatusCode;
+use futures::Stream;
use serde::Deserialize;
+use tokio::sync::watch::Sender;
+use tokio_stream::wrappers::WatchStream;
+use url::Url;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Http {
- pub url: String,
- #[serde(default = "Http::default_method")]
- pub method: String,
- #[serde(default = "Http::default_code")]
- pub status_code: u16,
+ pub url: Url,
+ #[serde(default)]
+ pub method: Method,
+ #[serde(default, with = "status_code")]
+ pub status_code: StatusCode,
+ #[serde(skip, default)]
+ pub client: Option<reqwest::Client>,
+}
+
+impl Display for Http {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{} {}", self.method, self.url)
+ }
+}
+
+impl ServiceSpawner for Http {
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let client = self.client.unwrap_or_default();
+ let request = client.request(self.method.into(), self.url).build()?;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+ let req = request
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let resp = client.execute(req).await;
+ let status = match resp.map(|r| r.status().as_u16()) {
+ Ok(code) if code == self.status_code => Status::Pass,
+ Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
+ Err(err) => {
+ tracing::error!("HTTP request error: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ }
}
impl Http {
- fn default_method() -> String {
- "GET".to_string()
+ pub fn into_stream(self, client: reqwest::Client) -> impl Stream<Item = Status> {
+ let request = client
+ .request(self.method.into(), self.url)
+ .build()
+ .expect("Url parsing should not fail");
+
+ let (tx, rx) = tokio::sync::watch::channel(Status::default());
+
+ tokio::spawn(async move {
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+ let req = request
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let resp = client.execute(req).await;
+ let status = match resp.map(|r| r.status().as_u16()) {
+ Ok(code) if code == self.status_code => Status::Pass,
+ Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
+ Err(err) => {
+ tracing::error!("HTTP request error: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ });
+
+ WatchStream::new(rx)
}
+}
- fn default_code() -> u16 {
- 200
+#[derive(Debug, Clone, Copy, Default, Deserialize)]
+pub enum Method {
+ #[serde(alias = "get", alias = "GET")]
+ #[default]
+ Get,
+ #[serde(alias = "post", alias = "POST")]
+ Post,
+}
+
+impl From<Method> for reqwest::Method {
+ fn from(value: Method) -> Self {
+ match value {
+ Method::Get => reqwest::Method::GET,
+ Method::Post => reqwest::Method::POST,
+ }
}
}
-impl Display for Http {
+impl Display for Method {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{} {}", self.method, self.url)
+ match self {
+ Method::Get => write!(f, "GET"),
+ Method::Post => write!(f, "POST"),
+ }
}
}
-impl Http {
- pub async fn check(&self, client: reqwest::Client) -> Result<Status, Error> {
- let status_code = client
- .request(self.method.parse().map_err(|_| Error::Method)?, &self.url)
- .send()
- .await?
- .status()
- .as_u16();
-
- match status_code == self.status_code {
- true => Ok(Status::Pass),
- false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))),
+pub mod status_code {
+ use axum::http::StatusCode;
+ use serde::{
+ de::{self, Unexpected, Visitor},
+ Deserializer, Serializer,
+ };
+ use std::fmt;
+
+ /// Implementation detail. Use derive annotations instead.
+ #[inline]
+ pub fn serialize<S: Serializer>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error> {
+ ser.serialize_u16(status.as_u16())
+ }
+
+ pub(crate) struct StatusVisitor;
+
+ impl StatusVisitor {
+ #[inline(never)]
+ fn make<E: de::Error>(&self, val: u64) -> Result<StatusCode, E> {
+ if (100..1000).contains(&val) {
+ if let Ok(s) = StatusCode::from_u16(val as u16) {
+ return Ok(s);
+ }
+ }
+ Err(de::Error::invalid_value(Unexpected::Unsigned(val), self))
+ }
+ }
+
+ impl<'de> Visitor<'de> for StatusVisitor {
+ type Value = StatusCode;
+
+ #[inline]
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("status code")
+ }
+
+ #[inline]
+ fn visit_some<D: Deserializer<'de>>(
+ self,
+ deserializer: D,
+ ) -> Result<Self::Value, D::Error> {
+ deserializer.deserialize_u16(self)
+ }
+
+ #[inline]
+ fn visit_i64<E: de::Error>(self, val: i64) -> Result<Self::Value, E> {
+ self.make(val as _)
+ }
+
+ #[inline]
+ fn visit_u64<E: de::Error>(self, val: u64) -> Result<Self::Value, E> {
+ self.make(val)
}
}
+
+ /// Implementation detail.
+ #[inline]
+ pub fn deserialize<'de, D>(de: D) -> Result<StatusCode, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ de.deserialize_u16(StatusVisitor)
+ }
}
diff --git a/src/service/systemd.rs b/src/service/systemd.rs
index 45f3bf9..e3b4d1b 100644
--- a/src/service/systemd.rs
+++ b/src/service/systemd.rs
@@ -1,9 +1,12 @@
-use std::{fmt::Display, process::Command};
+use std::{fmt::Display, process::Command, time::Duration};
use serde::Deserialize;
+use tokio::sync::watch::Sender;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Systemd {
pub service: String,
@@ -15,6 +18,32 @@ impl Display for Systemd {
}
}
+impl ServiceSpawner for Systemd {
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let mut command = Command::new("systemctl");
+ command.arg("is-active").arg(&self.service);
+
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+
+ let status = match command.output() {
+ Ok(output) if output.status.success() => Status::Pass,
+ Ok(output) => {
+ let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
+ Status::Fail(Some(format!("Service state: {}", stdout)))
+ }
+ Err(err) => {
+ tracing::error!("Failed to spawn process: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ }
+}
+
impl Systemd {
pub async fn check(&self) -> Result<Status, Error> {
let output = Command::new("systemctl")
diff --git a/src/service/tcp.rs b/src/service/tcp.rs
index 87e696a..5ec5f36 100644
--- a/src/service/tcp.rs
+++ b/src/service/tcp.rs
@@ -1,12 +1,17 @@
-use std::fmt::Display;
+use std::{fmt::Display, net::SocketAddr, time::Duration};
+use futures::Stream;
use serde::Deserialize;
+use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender};
+use tokio_stream::wrappers::WatchStream;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Tcp {
- pub address: String,
+ pub address: SocketAddr,
}
impl Display for Tcp {
@@ -15,11 +20,71 @@ impl Display for Tcp {
}
}
+impl ServiceSpawner for Tcp {
+ #[tracing::instrument(skip(tx))]
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+
+ loop {
+ interval.tick().await;
+
+ let sock = TcpSocket::new_v4()?;
+ sock.set_keepalive(true)?;
+
+ match sock.connect(self.address).await {
+ Ok(conn) => {
+ tracing::info!("Connected");
+ tx.send_if_modified(|s| s.update(Status::Pass));
+ conn.ready(Interest::ERROR).await?;
+ tx.send_replace(Status::Fail(Some("Disconnected".into())));
+ tracing::info!("Disconnected");
+ }
+ Err(err) => {
+ tracing::error!("Failed to connect");
+ tx.send_if_modified(|s| s.update(err.into()));
+ }
+ };
+ }
+ }
+}
+
impl Tcp {
- pub async fn check(&self) -> Result<Status, Error> {
- Ok(std::net::TcpStream::connect(&self.address)
- .err()
- .map(Into::into)
- .unwrap_or_default())
+ pub fn into_stream(self) -> impl Stream<Item = Status> {
+ let (tx, rx) = tokio::sync::watch::channel(Status::default());
+ tokio::spawn(self.spawn(tx));
+ WatchStream::new(rx)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[tokio::test]
+ #[tracing_test::traced_test]
+ #[ignore]
+ async fn test_tcp_watch() {
+ let (tx, mut rx) = tokio::sync::watch::channel(Status::default());
+
+ let tests = tokio::spawn(async move {
+ assert!(matches!(*rx.borrow_and_update(), Status::Unknown));
+
+ rx.changed().await.unwrap();
+ assert!(matches!(*rx.borrow_and_update(), Status::Pass));
+
+ rx.changed().await.unwrap();
+ assert_eq!(
+ *rx.borrow_and_update(),
+ Status::Fail(Some(String::from("Disconnected")))
+ );
+ });
+
+ let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let address = listener.local_addr().unwrap();
+ tokio::spawn(async move { Tcp { address }.spawn(tx).await });
+ listener.accept().await.unwrap();
+ drop(listener);
+
+ tests.await.unwrap()
}
}