From b94f8e694bf01f5dba9ce2c01f589463a3dfbc69 Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Wed, 9 Oct 2024 18:23:58 -0500 Subject: feat!: rewrite to use traits and streams --- src/service/command.rs | 128 +++++++++++++++++++++++++++++++------------------ src/service/http.rs | 54 +++++++++++++-------- src/service/systemd.rs | 36 -------------- src/service/tcp.rs | 75 +++++++++-------------------- 4 files changed, 136 insertions(+), 157 deletions(-) delete mode 100644 src/service/systemd.rs (limited to 'src/service') diff --git a/src/service/command.rs b/src/service/command.rs index 41a79b3..3535ee2 100644 --- a/src/service/command.rs +++ b/src/service/command.rs @@ -1,79 +1,113 @@ use std::{process::Stdio, time::Duration}; +use async_stream::stream; +use futures::{Stream, StreamExt}; use serde::Deserialize; -use tokio::{ - io::{AsyncBufReadExt, BufReader}, - sync::watch::Sender, -}; +use tokio::io::{AsyncBufReadExt, BufReader}; -use crate::{Error, Status}; +use super::IntoService; -use super::ServiceSpawner; +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Exited with status code: {code}\n{stderr}")] + ExitCode { code: i32, stderr: String }, + #[error("Process terminated by signal")] + Signal, + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::error::Error), + #[error("{0}")] + Stderr(String), + #[error("{0}")] + Output(String), + #[error("Exited with status code: {0}")] + PersistExitCode(i32), + #[error("Failed to get stderr of child process")] + NoStderr, + #[error("Failed to get stdout of child process")] + NoStdout, +} #[derive(Debug, Clone, Deserialize)] pub struct Command { pub command: String, pub args: Vec, - pub interval: Option, + #[serde(default)] + pub persist: bool, + #[serde(default = "super::default_interval")] + pub interval: Duration, } impl Command { - async fn spawn_interval( + #[tracing::instrument] + fn persist( + mut interval: tokio::time::Interval, mut command: tokio::process::Command, - period: Duration, - tx: Sender, - ) -> Result<(), Error> { - let mut interval = tokio::time::interval(period); - loop { - interval.tick().await; + ) -> impl Stream> { + stream! { + loop { + interval.tick().await; + + let mut child = command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; - let status = command.output().await.map_or_else(Into::into, |o| { - if o.status.success() { - Status::Ok - } else { - let stdout = String::from_utf8_lossy(&o.stdout).trim().to_string(); - Status::Error(Some(format!("Service state: {}", stdout))) + let mut stdout_reader = + BufReader::new(child.stdout.take().ok_or(Error::NoStdout)?).lines(); + + while let Some(line) = stdout_reader.next_line().await? { + if "Ok" == line { + yield Ok(()); + } else { + yield Err(Error::Output(line)) + } } - }); - tx.send_if_modified(|s| s.update(status)); + match child.wait().await?.code() { + Some(0) => yield Ok(()), + Some(code) => yield Err(Error::PersistExitCode(code)), + None => yield Err(Error::Signal), + }; + } } } - async fn spawn_persist( + #[tracing::instrument] + fn interval( + mut interval: tokio::time::Interval, mut command: tokio::process::Command, - tx: Sender, - ) -> Result<(), Error> { - let mut child = command.stdout(Stdio::piped()).spawn()?; - let mut stdout = BufReader::new(child.stdout.take().unwrap()).lines(); - - while let Some(line) = stdout.next_line().await? { - let status: Status = serde_json::from_str(&line) - .unwrap_or_else(|err| Status::Error(Some(format!("Serialization error: {err}")))); - tx.send_if_modified(|s| s.update(status)); + ) -> impl Stream> { + stream! { + loop { + interval.tick().await; + let output = command.output().await?; + match output.status.code() { + Some(0) => yield Ok(()), + Some(code) => { + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + yield Err(Error::ExitCode { code, stderr }) + } + None => yield Err(Error::Signal), + } + } } - - let exit_status = child.wait().await?; - let status = match exit_status.code() { - Some(0) => Status::Ok, - Some(code) => Status::Error(Some(format!("Exited with status code: {code}"))), - None => Status::Error(Some("Process terminated by signal".to_string())), - }; - - tx.send_if_modified(|s| s.update(status)); - Ok(()) } } -impl ServiceSpawner for Command { - async fn spawn(self, tx: Sender) -> Result<(), Error> { +impl IntoService for Command { + type Error = Error; + + fn into_service(self) -> impl Stream> { + let interval = tokio::time::interval(self.interval); let mut command = tokio::process::Command::new(self.command); command.args(self.args); - if let Some(period) = self.interval { - Self::spawn_interval(command, period, tx).await + if self.persist { + Self::persist(interval, command).boxed() } else { - Self::spawn_persist(command, tx).await + Self::interval(interval, command).boxed() } } } diff --git a/src/service/http.rs b/src/service/http.rs index 8950096..c4fcee7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,13 +1,20 @@ use std::{fmt::Display, time::Duration}; +use async_stream::try_stream; use axum::http::status::StatusCode; +use futures::Stream; use serde::Deserialize; -use tokio::sync::watch::Sender; use url::Url; -use crate::{Error, Status}; +use super::IntoService; -use super::ServiceSpawner; +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Request error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("Bad status code: {0}")] + StatusCode(u16), +} #[derive(Debug, Clone, Deserialize)] pub struct Http { @@ -18,26 +25,31 @@ pub struct Http { pub status_code: StatusCode, #[serde(skip, default)] pub client: Option, + #[serde(default = "super::default_interval")] + pub interval: Duration, } -impl ServiceSpawner for Http { - async fn spawn(self, tx: Sender) -> Result<(), Error> { - let client = self.client.unwrap_or_default(); - let request = client.request(self.method.into(), self.url).build()?; - - let mut interval = tokio::time::interval(Duration::from_secs(5)); - loop { - interval.tick().await; - let req = request - .try_clone() - .expect("Clone with no body should never fail"); - let resp = client.execute(req).await; - let status = resp.map_or_else(Into::into, |r| match r.status().as_u16() { - c if c == self.status_code => Status::Ok, - c => Status::Error(Some(format!("Status code: {c}"))), - }); - - tx.send_if_modified(|s| s.update(status)); +impl IntoService for Http { + type Error = Error; + + fn into_service(self) -> impl Stream> { + let mut interval = tokio::time::interval(self.interval); + + try_stream! { + let client = self.client.unwrap_or_default(); + let req = client.request(self.method.into(), self.url).build()?; + loop { + interval.tick().await; + let req = req + .try_clone() + .expect("Clone with no body should never fail"); + let status_code = client.execute(req).await?.status().as_u16(); + if status_code == self.status_code { + yield (); + } else { + Err(Error::StatusCode(status_code))? + } + } } } } diff --git a/src/service/systemd.rs b/src/service/systemd.rs deleted file mode 100644 index ee220b8..0000000 --- a/src/service/systemd.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::{process::Command, time::Duration}; - -use serde::Deserialize; -use tokio::sync::watch::Sender; - -use crate::{Error, Status}; - -use super::ServiceSpawner; - -#[derive(Debug, Clone, Deserialize)] -pub struct Systemd { - pub service: String, -} - -impl ServiceSpawner for Systemd { - async fn spawn(self, tx: Sender) -> Result<(), Error> { - let mut command = Command::new("systemctl"); - command.arg("is-active").arg(&self.service); - - let mut interval = tokio::time::interval(Duration::from_secs(5)); - loop { - interval.tick().await; - - let status = command.output().map_or_else(Into::into, |o| { - if o.status.success() { - Status::Ok - } else { - let stdout = String::from_utf8_lossy(&o.stdout).trim().to_string(); - Status::Error(Some(format!("Service state: {}", stdout))) - } - }); - - tx.send_if_modified(|s| s.update(status)); - } - } -} diff --git a/src/service/tcp.rs b/src/service/tcp.rs index 7b79afd..6556af0 100644 --- a/src/service/tcp.rs +++ b/src/service/tcp.rs @@ -1,15 +1,19 @@ use std::{fmt::Display, net::SocketAddr, time::Duration}; +use async_stream::try_stream; +use futures::Stream; use serde::Deserialize; -use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender}; +use tokio::{io::Interest, net::TcpSocket}; -use crate::{Error, Status}; +use super::IntoService; -use super::ServiceSpawner; +pub(crate) type Error = std::io::Error; #[derive(Debug, Clone, Deserialize)] pub struct Tcp { pub address: SocketAddr, + #[serde(default = "super::default_interval")] + pub interval: Duration, } impl Display for Tcp { @@ -18,59 +22,24 @@ impl Display for Tcp { } } -impl ServiceSpawner for Tcp { - async fn spawn(self, tx: Sender) -> Result<(), Error> { - let mut interval = tokio::time::interval(Duration::from_secs(5)); +impl IntoService for Tcp { + type Error = Error; - loop { - interval.tick().await; + fn into_service(self) -> impl Stream> { + let mut interval = tokio::time::interval(self.interval); - let sock = TcpSocket::new_v4()?; - sock.set_keepalive(true)?; + try_stream! { + loop { + interval.tick().await; - match sock.connect(self.address).await { - Ok(conn) => { - // TODO: figure out how to wait for connection to close - conn.ready(Interest::READABLE).await?; - tx.send_if_modified(|s| s.update(Status::Ok)); - } - Err(err) => { - tx.send_if_modified(|s| s.update(err.into())); - } - }; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[tracing_test::traced_test] - #[ignore] - async fn test_tcp_watch() { - let (tx, mut rx) = tokio::sync::watch::channel(Status::Error(None)); + let sock = TcpSocket::new_v4()?; + sock.set_keepalive(true)?; - let tests = tokio::spawn(async move { - assert!(matches!(*rx.borrow_and_update(), Status::Error(None))); - - rx.changed().await.unwrap(); - assert!(matches!(*rx.borrow_and_update(), Status::Ok)); - - rx.changed().await.unwrap(); - assert_eq!( - *rx.borrow_and_update(), - Status::Error(Some(String::from("Disconnected"))) - ); - }); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - tokio::spawn(async move { Tcp { address }.spawn(tx).await }); - listener.accept().await.unwrap(); - drop(listener); - - tests.await.unwrap() + let conn = sock.connect(self.address).await?; + // TODO: figure out how to wait for connection to close + conn.ready(Interest::READABLE).await?; + yield (); + } + } } } -- cgit v1.2.3-70-g09d2