summaryrefslogtreecommitdiffstats
path: root/src/service
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-10-09 18:23:58 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-10-09 18:23:58 -0500
commitb94f8e694bf01f5dba9ce2c01f589463a3dfbc69 (patch)
treec787530e63fb510db31533166edf1b9ff54be62a /src/service
parent117d33fc478bf529094850b1fe40c558f04c9865 (diff)
feat!: rewrite to use traits and streams
Diffstat (limited to 'src/service')
-rw-r--r--src/service/command.rs128
-rw-r--r--src/service/http.rs54
-rw-r--r--src/service/systemd.rs36
-rw-r--r--src/service/tcp.rs75
4 files changed, 136 insertions, 157 deletions
diff --git a/src/service/command.rs b/src/service/command.rs
index 41a79b3..3535ee2 100644
--- a/src/service/command.rs
+++ b/src/service/command.rs
@@ -1,79 +1,113 @@
use std::{process::Stdio, time::Duration};
+use async_stream::stream;
+use futures::{Stream, StreamExt};
use serde::Deserialize;
-use tokio::{
- io::{AsyncBufReadExt, BufReader},
- sync::watch::Sender,
-};
+use tokio::io::{AsyncBufReadExt, BufReader};
-use crate::{Error, Status};
+use super::IntoService;
-use super::ServiceSpawner;
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("IO error: {0}")]
+ Io(#[from] std::io::Error),
+ #[error("Exited with status code: {code}\n{stderr}")]
+ ExitCode { code: i32, stderr: String },
+ #[error("Process terminated by signal")]
+ Signal,
+ #[error("Serialization error: {0}")]
+ Serialization(#[from] serde_json::error::Error),
+ #[error("{0}")]
+ Stderr(String),
+ #[error("{0}")]
+ Output(String),
+ #[error("Exited with status code: {0}")]
+ PersistExitCode(i32),
+ #[error("Failed to get stderr of child process")]
+ NoStderr,
+ #[error("Failed to get stdout of child process")]
+ NoStdout,
+}
#[derive(Debug, Clone, Deserialize)]
pub struct Command {
pub command: String,
pub args: Vec<String>,
- pub interval: Option<Duration>,
+ #[serde(default)]
+ pub persist: bool,
+ #[serde(default = "super::default_interval")]
+ pub interval: Duration,
}
impl Command {
- async fn spawn_interval(
+ #[tracing::instrument]
+ fn persist(
+ mut interval: tokio::time::Interval,
mut command: tokio::process::Command,
- period: Duration,
- tx: Sender<Status>,
- ) -> Result<(), Error> {
- let mut interval = tokio::time::interval(period);
- loop {
- interval.tick().await;
+ ) -> impl Stream<Item = Result<(), Error>> {
+ stream! {
+ loop {
+ interval.tick().await;
+
+ let mut child = command
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .spawn()?;
- let status = command.output().await.map_or_else(Into::into, |o| {
- if o.status.success() {
- Status::Ok
- } else {
- let stdout = String::from_utf8_lossy(&o.stdout).trim().to_string();
- Status::Error(Some(format!("Service state: {}", stdout)))
+ let mut stdout_reader =
+ BufReader::new(child.stdout.take().ok_or(Error::NoStdout)?).lines();
+
+ while let Some(line) = stdout_reader.next_line().await? {
+ if "Ok" == line {
+ yield Ok(());
+ } else {
+ yield Err(Error::Output(line))
+ }
}
- });
- tx.send_if_modified(|s| s.update(status));
+ match child.wait().await?.code() {
+ Some(0) => yield Ok(()),
+ Some(code) => yield Err(Error::PersistExitCode(code)),
+ None => yield Err(Error::Signal),
+ };
+ }
}
}
- async fn spawn_persist(
+ #[tracing::instrument]
+ fn interval(
+ mut interval: tokio::time::Interval,
mut command: tokio::process::Command,
- tx: Sender<Status>,
- ) -> Result<(), Error> {
- let mut child = command.stdout(Stdio::piped()).spawn()?;
- let mut stdout = BufReader::new(child.stdout.take().unwrap()).lines();
-
- while let Some(line) = stdout.next_line().await? {
- let status: Status = serde_json::from_str(&line)
- .unwrap_or_else(|err| Status::Error(Some(format!("Serialization error: {err}"))));
- tx.send_if_modified(|s| s.update(status));
+ ) -> impl Stream<Item = Result<(), Error>> {
+ stream! {
+ loop {
+ interval.tick().await;
+ let output = command.output().await?;
+ match output.status.code() {
+ Some(0) => yield Ok(()),
+ Some(code) => {
+ let stderr = String::from_utf8_lossy(&output.stderr).to_string();
+ yield Err(Error::ExitCode { code, stderr })
+ }
+ None => yield Err(Error::Signal),
+ }
+ }
}
-
- let exit_status = child.wait().await?;
- let status = match exit_status.code() {
- Some(0) => Status::Ok,
- Some(code) => Status::Error(Some(format!("Exited with status code: {code}"))),
- None => Status::Error(Some("Process terminated by signal".to_string())),
- };
-
- tx.send_if_modified(|s| s.update(status));
- Ok(())
}
}
-impl ServiceSpawner for Command {
- async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+impl IntoService for Command {
+ type Error = Error;
+
+ fn into_service(self) -> impl Stream<Item = Result<(), Self::Error>> {
+ let interval = tokio::time::interval(self.interval);
let mut command = tokio::process::Command::new(self.command);
command.args(self.args);
- if let Some(period) = self.interval {
- Self::spawn_interval(command, period, tx).await
+ if self.persist {
+ Self::persist(interval, command).boxed()
} else {
- Self::spawn_persist(command, tx).await
+ Self::interval(interval, command).boxed()
}
}
}
diff --git a/src/service/http.rs b/src/service/http.rs
index 8950096..c4fcee7 100644
--- a/src/service/http.rs
+++ b/src/service/http.rs
@@ -1,13 +1,20 @@
use std::{fmt::Display, time::Duration};
+use async_stream::try_stream;
use axum::http::status::StatusCode;
+use futures::Stream;
use serde::Deserialize;
-use tokio::sync::watch::Sender;
use url::Url;
-use crate::{Error, Status};
+use super::IntoService;
-use super::ServiceSpawner;
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("Request error: {0}")]
+ Reqwest(#[from] reqwest::Error),
+ #[error("Bad status code: {0}")]
+ StatusCode(u16),
+}
#[derive(Debug, Clone, Deserialize)]
pub struct Http {
@@ -18,26 +25,31 @@ pub struct Http {
pub status_code: StatusCode,
#[serde(skip, default)]
pub client: Option<reqwest::Client>,
+ #[serde(default = "super::default_interval")]
+ pub interval: Duration,
}
-impl ServiceSpawner for Http {
- async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
- let client = self.client.unwrap_or_default();
- let request = client.request(self.method.into(), self.url).build()?;
-
- let mut interval = tokio::time::interval(Duration::from_secs(5));
- loop {
- interval.tick().await;
- let req = request
- .try_clone()
- .expect("Clone with no body should never fail");
- let resp = client.execute(req).await;
- let status = resp.map_or_else(Into::into, |r| match r.status().as_u16() {
- c if c == self.status_code => Status::Ok,
- c => Status::Error(Some(format!("Status code: {c}"))),
- });
-
- tx.send_if_modified(|s| s.update(status));
+impl IntoService for Http {
+ type Error = Error;
+
+ fn into_service(self) -> impl Stream<Item = Result<(), Self::Error>> {
+ let mut interval = tokio::time::interval(self.interval);
+
+ try_stream! {
+ let client = self.client.unwrap_or_default();
+ let req = client.request(self.method.into(), self.url).build()?;
+ loop {
+ interval.tick().await;
+ let req = req
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let status_code = client.execute(req).await?.status().as_u16();
+ if status_code == self.status_code {
+ yield ();
+ } else {
+ Err(Error::StatusCode(status_code))?
+ }
+ }
}
}
}
diff --git a/src/service/systemd.rs b/src/service/systemd.rs
deleted file mode 100644
index ee220b8..0000000
--- a/src/service/systemd.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-use std::{process::Command, time::Duration};
-
-use serde::Deserialize;
-use tokio::sync::watch::Sender;
-
-use crate::{Error, Status};
-
-use super::ServiceSpawner;
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct Systemd {
- pub service: String,
-}
-
-impl ServiceSpawner for Systemd {
- async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
- let mut command = Command::new("systemctl");
- command.arg("is-active").arg(&self.service);
-
- let mut interval = tokio::time::interval(Duration::from_secs(5));
- loop {
- interval.tick().await;
-
- let status = command.output().map_or_else(Into::into, |o| {
- if o.status.success() {
- Status::Ok
- } else {
- let stdout = String::from_utf8_lossy(&o.stdout).trim().to_string();
- Status::Error(Some(format!("Service state: {}", stdout)))
- }
- });
-
- tx.send_if_modified(|s| s.update(status));
- }
- }
-}
diff --git a/src/service/tcp.rs b/src/service/tcp.rs
index 7b79afd..6556af0 100644
--- a/src/service/tcp.rs
+++ b/src/service/tcp.rs
@@ -1,15 +1,19 @@
use std::{fmt::Display, net::SocketAddr, time::Duration};
+use async_stream::try_stream;
+use futures::Stream;
use serde::Deserialize;
-use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender};
+use tokio::{io::Interest, net::TcpSocket};
-use crate::{Error, Status};
+use super::IntoService;
-use super::ServiceSpawner;
+pub(crate) type Error = std::io::Error;
#[derive(Debug, Clone, Deserialize)]
pub struct Tcp {
pub address: SocketAddr,
+ #[serde(default = "super::default_interval")]
+ pub interval: Duration,
}
impl Display for Tcp {
@@ -18,59 +22,24 @@ impl Display for Tcp {
}
}
-impl ServiceSpawner for Tcp {
- async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
- let mut interval = tokio::time::interval(Duration::from_secs(5));
+impl IntoService for Tcp {
+ type Error = Error;
- loop {
- interval.tick().await;
+ fn into_service(self) -> impl Stream<Item = Result<(), Self::Error>> {
+ let mut interval = tokio::time::interval(self.interval);
- let sock = TcpSocket::new_v4()?;
- sock.set_keepalive(true)?;
+ try_stream! {
+ loop {
+ interval.tick().await;
- match sock.connect(self.address).await {
- Ok(conn) => {
- // TODO: figure out how to wait for connection to close
- conn.ready(Interest::READABLE).await?;
- tx.send_if_modified(|s| s.update(Status::Ok));
- }
- Err(err) => {
- tx.send_if_modified(|s| s.update(err.into()));
- }
- };
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- #[tracing_test::traced_test]
- #[ignore]
- async fn test_tcp_watch() {
- let (tx, mut rx) = tokio::sync::watch::channel(Status::Error(None));
+ let sock = TcpSocket::new_v4()?;
+ sock.set_keepalive(true)?;
- let tests = tokio::spawn(async move {
- assert!(matches!(*rx.borrow_and_update(), Status::Error(None)));
-
- rx.changed().await.unwrap();
- assert!(matches!(*rx.borrow_and_update(), Status::Ok));
-
- rx.changed().await.unwrap();
- assert_eq!(
- *rx.borrow_and_update(),
- Status::Error(Some(String::from("Disconnected")))
- );
- });
-
- let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
- let address = listener.local_addr().unwrap();
- tokio::spawn(async move { Tcp { address }.spawn(tx).await });
- listener.accept().await.unwrap();
- drop(listener);
-
- tests.await.unwrap()
+ let conn = sock.connect(self.address).await?;
+ // TODO: figure out how to wait for connection to close
+ conn.ready(Interest::READABLE).await?;
+ yield ();
+ }
+ }
}
}