use std::{fmt::Display, time::Duration}; use async_stream::try_stream; use axum::http::status::StatusCode; use futures::Stream; use serde::Deserialize; use url::Url; use super::IntoService; #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Request error: {0}")] Reqwest(#[from] reqwest::Error), #[error("Bad status code: {0}")] StatusCode(u16), } #[derive(Debug, Clone, Deserialize)] pub struct Http { pub url: Url, #[serde(default)] pub method: Method, #[serde(default, with = "status_code")] pub status_code: StatusCode, #[serde(skip, default)] pub client: Option, #[serde(default = "super::default_interval")] pub interval: Duration, } impl IntoService for Http { type Error = Error; fn into_service(self) -> impl Stream> { let mut interval = tokio::time::interval(self.interval); try_stream! { let client = self.client.unwrap_or_default(); let req = client.request(self.method.into(), self.url).build()?; loop { interval.tick().await; let req = req .try_clone() .expect("Clone with no body should never fail"); let status_code = client.execute(req).await?.status().as_u16(); if status_code == self.status_code { yield (); } else { Err(Error::StatusCode(status_code))? } } } } } #[derive(Debug, Clone, Copy, Default, Deserialize)] pub enum Method { #[serde(alias = "get", alias = "GET")] #[default] Get, #[serde(alias = "post", alias = "POST")] Post, } impl From for reqwest::Method { fn from(value: Method) -> Self { match value { Method::Get => reqwest::Method::GET, Method::Post => reqwest::Method::POST, } } } impl Display for Method { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Method::Get => write!(f, "GET"), Method::Post => write!(f, "POST"), } } } pub mod status_code { use axum::http::StatusCode; use serde::{ de::{self, Unexpected, Visitor}, Deserializer, Serializer, }; use std::fmt; /// Implementation detail. Use derive annotations instead. #[inline] pub fn serialize(status: &StatusCode, ser: S) -> Result { ser.serialize_u16(status.as_u16()) } pub(crate) struct StatusVisitor; impl StatusVisitor { #[inline(never)] fn make(&self, val: u64) -> Result { if (100..1000).contains(&val) { if let Ok(s) = StatusCode::from_u16(val as u16) { return Ok(s); } } Err(de::Error::invalid_value(Unexpected::Unsigned(val), self)) } } impl<'de> Visitor<'de> for StatusVisitor { type Value = StatusCode; #[inline] fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("status code") } #[inline] fn visit_some>( self, deserializer: D, ) -> Result { deserializer.deserialize_u16(self) } #[inline] fn visit_i64(self, val: i64) -> Result { self.make(val as _) } #[inline] fn visit_u64(self, val: u64) -> Result { self.make(val) } } /// Implementation detail. #[inline] pub fn deserialize<'de, D>(de: D) -> Result where D: Deserializer<'de>, { de.deserialize_u16(StatusVisitor) } }