From cd774827dd14f68d8405c45d2d9da30b3fab050e Mon Sep 17 00:00:00 2001 From: Toby Vincent Date: Sat, 28 Sep 2024 00:54:46 -0500 Subject: feat: refactor into pub-sub and impl SSE --- src/service/http.rs | 183 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 159 insertions(+), 24 deletions(-) (limited to 'src/service/http.rs') diff --git a/src/service/http.rs b/src/service/http.rs index fb3ff13..6d21cb7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,46 +1,181 @@ -use std::fmt::Display; +use std::{fmt::Display, time::Duration}; +use axum::http::status::StatusCode; +use futures::Stream; use serde::Deserialize; +use tokio::sync::watch::Sender; +use tokio_stream::wrappers::WatchStream; +use url::Url; use crate::{Error, Status}; +use super::ServiceSpawner; + #[derive(Debug, Clone, Deserialize)] pub struct Http { - pub url: String, - #[serde(default = "Http::default_method")] - pub method: String, - #[serde(default = "Http::default_code")] - pub status_code: u16, + pub url: Url, + #[serde(default)] + pub method: Method, + #[serde(default, with = "status_code")] + pub status_code: StatusCode, + #[serde(skip, default)] + pub client: Option, +} + +impl Display for Http { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.method, self.url) + } +} + +impl ServiceSpawner for Http { + async fn spawn(self, tx: Sender) -> Result<(), Error> { + let client = self.client.unwrap_or_default(); + let request = client.request(self.method.into(), self.url).build()?; + + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + } } impl Http { - fn default_method() -> String { - "GET".to_string() + pub fn into_stream(self, client: reqwest::Client) -> impl Stream { + let request = client + .request(self.method.into(), self.url) + .build() + .expect("Url parsing should not fail"); + + let (tx, rx) = tokio::sync::watch::channel(Status::default()); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + loop { + interval.tick().await; + let req = request + .try_clone() + .expect("Clone with no body should never fail"); + let resp = client.execute(req).await; + let status = match resp.map(|r| r.status().as_u16()) { + Ok(code) if code == self.status_code => Status::Pass, + Ok(code) => Status::Fail(Some(format!("Status code: {code}"))), + Err(err) => { + tracing::error!("HTTP request error: {err}"); + Status::Unknown + } + }; + + tx.send_if_modified(|s| s.update(status)); + } + }); + + WatchStream::new(rx) } +} - fn default_code() -> u16 { - 200 +#[derive(Debug, Clone, Copy, Default, Deserialize)] +pub enum Method { + #[serde(alias = "get", alias = "GET")] + #[default] + Get, + #[serde(alias = "post", alias = "POST")] + Post, +} + +impl From for reqwest::Method { + fn from(value: Method) -> Self { + match value { + Method::Get => reqwest::Method::GET, + Method::Post => reqwest::Method::POST, + } } } -impl Display for Http { +impl Display for Method { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {}", self.method, self.url) + match self { + Method::Get => write!(f, "GET"), + Method::Post => write!(f, "POST"), + } } } -impl Http { - pub async fn check(&self, client: reqwest::Client) -> Result { - let status_code = client - .request(self.method.parse().map_err(|_| Error::Method)?, &self.url) - .send() - .await? - .status() - .as_u16(); - - match status_code == self.status_code { - true => Ok(Status::Pass), - false => Ok(Status::Fail(Some(format!("Status code: {status_code}")))), +pub mod status_code { + use axum::http::StatusCode; + use serde::{ + de::{self, Unexpected, Visitor}, + Deserializer, Serializer, + }; + use std::fmt; + + /// Implementation detail. Use derive annotations instead. + #[inline] + pub fn serialize(status: &StatusCode, ser: S) -> Result { + ser.serialize_u16(status.as_u16()) + } + + pub(crate) struct StatusVisitor; + + impl StatusVisitor { + #[inline(never)] + fn make(&self, val: u64) -> Result { + if (100..1000).contains(&val) { + if let Ok(s) = StatusCode::from_u16(val as u16) { + return Ok(s); + } + } + Err(de::Error::invalid_value(Unexpected::Unsigned(val), self)) + } + } + + impl<'de> Visitor<'de> for StatusVisitor { + type Value = StatusCode; + + #[inline] + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("status code") + } + + #[inline] + fn visit_some>( + self, + deserializer: D, + ) -> Result { + deserializer.deserialize_u16(self) + } + + #[inline] + fn visit_i64(self, val: i64) -> Result { + self.make(val as _) + } + + #[inline] + fn visit_u64(self, val: u64) -> Result { + self.make(val) } } + + /// Implementation detail. + #[inline] + pub fn deserialize<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + de.deserialize_u16(StatusVisitor) + } } -- cgit v1.2.3-70-g09d2