// path: root/src/service/http.rs
// blob: 7c875b9c3e41e5f23ae55078ca447bf488c58558

use std::{fmt::Display, time::Duration};

use axum::http::status::StatusCode;
use serde::Deserialize;
use tokio::sync::watch::Sender;
use url::Url;

use crate::{Error, Status};

use super::ServiceSpawner;

/// Configuration for a service check that probes an HTTP endpoint.
///
/// Deserialized from configuration; every field except `url` has a default.
#[derive(Debug, Clone, Deserialize)]
pub struct Http {
    /// URL the probe request is sent to.
    pub url: Url,
    /// HTTP method used for the probe; falls back to `Method::default()` when omitted.
    #[serde(default)]
    pub method: Method,
    /// Status code the response must carry for the check to pass.
    ///
    /// Parsed from a number via the `status_code` module; falls back to
    /// `StatusCode::default()` when omitted.
    #[serde(default, with = "status_code")]
    pub status_code: StatusCode,
    /// Optional pre-built client. Never read from configuration (`serde(skip)`);
    /// when `None`, `spawn` uses a default `reqwest::Client`.
    #[serde(skip, default)]
    pub client: Option<reqwest::Client>,
}

impl ServiceSpawner for Http {
    /// Probes the configured URL every five seconds and publishes the outcome on `tx`.
    ///
    /// Each probe is mapped to a [`Status`]:
    /// - a response whose status code equals `self.status_code` gives `Status::Pass`;
    /// - any other response gives `Status::Fail` with the received code in the message;
    /// - a failed request is logged and gives `Status::Unknown`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built from the configured method and URL.
    /// Once polling has started the loop does not terminate on its own.
    async fn spawn(self, tx: Sender<Status>) -> Result<(), Error> {
        /// Time between two consecutive probes.
        const POLL_INTERVAL: Duration = Duration::from_secs(5);

        let client = self.client.unwrap_or_default();
        // Build the request once; each iteration sends a clone of it.
        let request = client.request(self.method.into(), self.url).build()?;

        let mut interval = tokio::time::interval(POLL_INTERVAL);
        // The default `Burst` behaviour replays every tick that was missed while a slow
        // request was in flight, which would fire several probes back to back. `Delay`
        // schedules the next tick a full period after the late one instead.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        loop {
            // The first tick completes immediately, so the first probe is sent right away.
            interval.tick().await;
            let req = request
                .try_clone()
                .expect("Clone with no body should never fail");
            let resp = client.execute(req).await;
            let status = match resp.map(|r| r.status().as_u16()) {
                Ok(code) if code == self.status_code => Status::Pass,
                Ok(code) => Status::Fail(Some(format!("Status code: {code}"))),
                Err(err) => {
                    tracing::error!("HTTP request error: {err}");
                    Status::Unknown
                }
            };

            // Receivers are only notified when `update` reports that the value changed.
            tx.send_if_modified(|s| s.update(status));
        }
    }
}

/// HTTP methods a check can be configured to use.
///
/// Each variant deserializes from its own name as well as the listed aliases,
/// so `Get`, `get` and `GET` are all accepted (likewise for `Post`).
#[derive(Debug, Clone, Copy, Default, Deserialize)]
pub enum Method {
    /// `GET` request; used when no method is configured.
    #[serde(alias = "get", alias = "GET")]
    #[default]
    Get,
    /// `POST` request.
    #[serde(alias = "post", alias = "POST")]
    Post,
}

impl From<Method> for reqwest::Method {
    /// Translates the configuration-level method into the matching `reqwest` constant.
    fn from(value: Method) -> Self {
        match value {
            Method::Post => Self::POST,
            Method::Get => Self::GET,
        }
    }
}

impl Display for Method {
    /// Writes the method as its upper-case HTTP verb.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let verb = match *self {
            Method::Post => "POST",
            Method::Get => "GET",
        };
        f.write_str(verb)
    }
}

/// Serde adapter that (de)serializes a `StatusCode` as its numeric value.
///
/// Meant to be referenced from a field attribute such as
/// `#[serde(with = "status_code")]`.
pub mod status_code {
    use axum::http::StatusCode;
    use serde::{
        de::{self, Unexpected, Visitor},
        Deserializer, Serializer,
    };
    use std::fmt;

    /// Implementation detail. Use derive annotations instead.
    ///
    /// Serializes the status as its `u16` code.
    #[inline]
    pub fn serialize<S: Serializer>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error> {
        ser.serialize_u16(status.as_u16())
    }

    /// Visitor that turns an integer into a `StatusCode`.
    pub(crate) struct StatusVisitor;

    impl StatusVisitor {
        /// Converts `val` into a `StatusCode`.
        ///
        /// Only values in `100..1000` that `StatusCode::from_u16` accepts succeed;
        /// anything else is reported as an invalid unsigned value.
        #[inline(never)]
        fn make<E: de::Error>(&self, val: u64) -> Result<StatusCode, E> {
            if (100..1000).contains(&val) {
                // The range check above guarantees the value fits in a `u16`.
                if let Ok(s) = StatusCode::from_u16(val as u16) {
                    return Ok(s);
                }
            }
            Err(de::Error::invalid_value(Unexpected::Unsigned(val), self))
        }
    }

    impl<'de> Visitor<'de> for StatusVisitor {
        type Value = StatusCode;

        #[inline]
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("status code")
        }

        /// Handles a present optional value by deserializing the inner integer.
        #[inline]
        fn visit_some<D: Deserializer<'de>>(
            self,
            deserializer: D,
        ) -> Result<Self::Value, D::Error> {
            deserializer.deserialize_u16(self)
        }

        /// Accepts signed integers, rejecting negative ones outright.
        #[inline]
        fn visit_i64<E: de::Error>(self, val: i64) -> Result<Self::Value, E> {
            if val.is_negative() {
                // Report the value as the caller wrote it; casting to `u64` first would
                // surface a huge wrapped number in the error message instead.
                return Err(de::Error::invalid_value(Unexpected::Signed(val), &self));
            }
            // Non-negative, so `unsigned_abs` is a lossless conversion to `u64`.
            self.make(val.unsigned_abs())
        }

        /// Accepts unsigned integers.
        #[inline]
        fn visit_u64<E: de::Error>(self, val: u64) -> Result<Self::Value, E> {
            self.make(val)
        }
    }

    /// Implementation detail.
    ///
    /// Deserializes a `StatusCode` from an integer via [`StatusVisitor`].
    #[inline]
    pub fn deserialize<'de, D>(de: D) -> Result<StatusCode, D::Error>
    where
        D: Deserializer<'de>,
    {
        de.deserialize_u16(StatusVisitor)
    }
}