summaryrefslogtreecommitdiffstats
path: root/src/service/http.rs
diff options
context:
space:
mode:
authorToby Vincent <tobyv@tobyvin.dev>2024-09-28 00:54:46 -0500
committerToby Vincent <tobyv@tobyvin.dev>2024-09-28 00:58:45 -0500
commitcd774827dd14f68d8405c45d2d9da30b3fab050e (patch)
treea24e1cabb99170caa25edff53fc978111a1c9dd4 /src/service/http.rs
parent04c7f7609e5bc3fadf95c53b37a9e6e12c4e539c (diff)
feat: refactor into pub-sub and impl SSE
Diffstat (limited to 'src/service/http.rs')
-rw-r--r--src/service/http.rs183
1 files changed, 159 insertions, 24 deletions
diff --git a/src/service/http.rs b/src/service/http.rs
index fb3ff13..6d21cb7 100644
--- a/src/service/http.rs
+++ b/src/service/http.rs
@@ -1,46 +1,181 @@
-use std::fmt::Display;
+use std::{fmt::Display, time::Duration};
+use axum::http::status::StatusCode;
+use futures::Stream;
use serde::Deserialize;
+use tokio::sync::watch::Sender;
+use tokio_stream::wrappers::WatchStream;
+use url::Url;
use crate::{Error, Status};
+use super::ServiceSpawner;
+
#[derive(Debug, Clone, Deserialize)]
pub struct Http {
- pub url: String,
- #[serde(default = "Http::default_method")]
- pub method: String,
- #[serde(default = "Http::default_code")]
- pub status_code: u16,
+ pub url: Url,
+ #[serde(default)]
+ pub method: Method,
+ #[serde(default, with = "status_code")]
+ pub status_code: StatusCode,
+ #[serde(skip, default)]
+ pub client: Option<reqwest::Client>,
+}
+
+impl Display for Http {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{} {}", self.method, self.url)
+ }
+}
+
+impl ServiceSpawner for Http {
+ async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
+ let client = self.client.unwrap_or_default();
+ let request = client.request(self.method.into(), self.url).build()?;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+ let req = request
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let resp = client.execute(req).await;
+ let status = match resp.map(|r| r.status().as_u16()) {
+ Ok(code) if code == self.status_code => Status::Pass,
+ Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
+ Err(err) => {
+ tracing::error!("HTTP request error: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ }
}
impl Http {
- fn default_method() -> String {
- "GET".to_string()
+ pub fn into_stream(self, client: reqwest::Client) -> impl Stream<Item = Status> {
+ let request = client
+ .request(self.method.into(), self.url)
+ .build()
+ .expect("Url parsing should not fail");
+
+ let (tx, rx) = tokio::sync::watch::channel(Status::default());
+
+ tokio::spawn(async move {
+ let mut interval = tokio::time::interval(Duration::from_secs(5));
+ loop {
+ interval.tick().await;
+ let req = request
+ .try_clone()
+ .expect("Clone with no body should never fail");
+ let resp = client.execute(req).await;
+ let status = match resp.map(|r| r.status().as_u16()) {
+ Ok(code) if code == self.status_code => Status::Pass,
+ Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
+ Err(err) => {
+ tracing::error!("HTTP request error: {err}");
+ Status::Unknown
+ }
+ };
+
+ tx.send_if_modified(|s| s.update(status));
+ }
+ });
+
+ WatchStream::new(rx)
}
+}
- fn default_code() -> u16 {
- 200
+#[derive(Debug, Clone, Copy, Default, Deserialize)]
+pub enum Method {
+ #[serde(alias = "get", alias = "GET")]
+ #[default]
+ Get,
+ #[serde(alias = "post", alias = "POST")]
+ Post,
+}
+
+impl From<Method> for reqwest::Method {
+ fn from(value: Method) -> Self {
+ match value {
+ Method::Get => reqwest::Method::GET,
+ Method::Post => reqwest::Method::POST,
+ }
}
}
-impl Display for Http {
+impl Display for Method {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{} {}", self.method, self.url)
+ match self {
+ Method::Get => write!(f, "GET"),
+ Method::Post => write!(f, "POST"),
+ }
}
}
-impl Http {
- pub async fn check(&self, client: reqwest::Client) -> Result<Status, Error> {
- let status_code = client
- .request(self.method.parse().map_err(|_| Error::Method)?, &self.url)
- .send()
- .await?
- .status()
- .as_u16();
-
- match status_code == self.status_code {
- true => Ok(Status::Pass),
- false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))),
+pub mod status_code {
+ use axum::http::StatusCode;
+ use serde::{
+ de::{self, Unexpected, Visitor},
+ Deserializer, Serializer,
+ };
+ use std::fmt;
+
+ /// Implementation detail. Use derive annotations instead.
+ #[inline]
+ pub fn serialize<S: Serializer>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error> {
+ ser.serialize_u16(status.as_u16())
+ }
+
+ pub(crate) struct StatusVisitor;
+
+ impl StatusVisitor {
+ #[inline(never)]
+ fn make<E: de::Error>(&self, val: u64) -> Result<StatusCode, E> {
+ if (100..1000).contains(&val) {
+ if let Ok(s) = StatusCode::from_u16(val as u16) {
+ return Ok(s);
+ }
+ }
+ Err(de::Error::invalid_value(Unexpected::Unsigned(val), self))
+ }
+ }
+
+ impl<'de> Visitor<'de> for StatusVisitor {
+ type Value = StatusCode;
+
+ #[inline]
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("status code")
+ }
+
+ #[inline]
+ fn visit_some<D: Deserializer<'de>>(
+ self,
+ deserializer: D,
+ ) -> Result<Self::Value, D::Error> {
+ deserializer.deserialize_u16(self)
+ }
+
+ #[inline]
+ fn visit_i64<E: de::Error>(self, val: i64) -> Result<Self::Value, E> {
+ self.make(val as _)
+ }
+
+ #[inline]
+ fn visit_u64<E: de::Error>(self, val: u64) -> Result<Self::Value, E> {
+ self.make(val)
}
}
+
+ /// Implementation detail.
+ #[inline]
+ pub fn deserialize<'de, D>(de: D) -> Result<StatusCode, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ de.deserialize_u16(StatusVisitor)
+ }
}