summaryrefslogtreecommitdiffstats
path: root/src/service/http.rs
blob: 5b0fdd5012f887e3861dcdd8773cbce56fa7ac79 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
use std::{fmt::Display, time::Duration};

use axum::http::status::StatusCode;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::status::Sender;

use super::IntoService;

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("Request error: {0}")]
    Reqwest(#[from] reqwest::Error),
    #[error("Bad status code: {0}")]
    StatusCode(u16),
}

/// A service that is monitored by periodically issuing an HTTP request.
///
/// A check sends `method` to `url` and is considered healthy only when the
/// response status equals `status_code`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Http {
    /// Endpoint that is requested on every check.
    pub url: Url,
    /// HTTP verb to use; falls back to `Method::default()` when omitted.
    #[serde(default)]
    pub method: Method,
    /// Expected response status, (de)serialized as a plain integer; falls back to
    /// `StatusCode::default()` (200 OK per the `http` crate) when omitted.
    #[serde(default, with = "status_code")]
    pub status_code: StatusCode,
    /// Optional pre-built client. Never read from or written to the config
    /// (serde `skip` substitutes `None`), so it can only be set programmatically.
    #[serde(skip)]
    pub client: Option<reqwest::Client>,
    /// Period between checks; defaults to `super::default_interval()`.
    #[serde(default = "super::default_interval")]
    pub interval: Duration,
}

impl Http {
    async fn check(&self) -> Result<(), Error> {
        let client = match self.client.as_ref() {
            Some(client) => client,
            None => &Client::new(),
        };
        let req = client
            .request(self.method.into(), self.url.clone())
            .build()?;
        let status_code = client.execute(req).await?.status().as_u16();
        (status_code == self.status_code)
            .then_some(())
            .ok_or_else(|| Error::StatusCode(status_code))
    }
}

impl IntoService for Http {
    /// Runs the check loop forever, publishing every check result through `tx`.
    ///
    /// The first tick of a tokio interval completes immediately, so the endpoint
    /// is checked right away and then once per `self.interval`.
    ///
    /// NOTE(review): per tokio's docs, `tokio::time::interval` panics when the
    /// period is zero, so a zero `interval` in the config aborts this task.
    async fn into_service(self, tx: Sender) {
        let mut interval = tokio::time::interval(self.interval);
        // A single check can take longer than the interval (slow or hanging
        // endpoint). With the default `Burst` behaviour all missed ticks would then
        // fire back-to-back and hit the target repeatedly; `Delay` instead waits a
        // full period after the late tick before checking again.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            interval.tick().await;
            let res = self.check().await;
            // Presumably `Sender` wraps tokio's `watch::Sender`, in which case the
            // bool returned by `update` decides whether receivers are notified.
            tx.send_if_modified(|s| s.update(res.into()));
        }
    }
}

/// HTTP verb used for a check request.
///
/// Deserializes from the variant name (`Get` / `Post`) as well as its
/// all-lowercase and all-uppercase spellings; serializes as the variant name.
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
pub enum Method {
    /// `GET`; this is the `Default` variant.
    #[default]
    #[serde(alias = "GET", alias = "get")]
    Get,
    /// `POST`.
    #[serde(alias = "POST", alias = "post")]
    Post,
}

impl From<Method> for reqwest::Method {
    /// Maps the config-level [`Method`] onto the matching `reqwest` verb.
    fn from(verb: Method) -> Self {
        match verb {
            Method::Post => Self::POST,
            Method::Get => Self::GET,
        }
    }
}

impl Display for Method {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Method::Get => write!(f, "GET"),
            Method::Post => write!(f, "POST"),
        }
    }
}

/// Serde adapter that (de)serializes an `axum::http::StatusCode` as a bare
/// integer, for use with `#[serde(with = "status_code")]`.
pub mod status_code {
    use axum::http::StatusCode;
    use serde::{
        de::{self, Unexpected, Visitor},
        Deserializer, Serializer,
    };
    use std::fmt;

    /// Serializes `status` as its numeric `u16` value.
    ///
    /// Implementation detail of `#[serde(with = "status_code")]`; use the derive
    /// annotation instead of calling this directly.
    #[inline]
    pub fn serialize<S: Serializer>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error> {
        ser.serialize_u16(status.as_u16())
    }

    /// Visitor that accepts an integer and turns it into a [`StatusCode`].
    pub(crate) struct StatusVisitor;

    impl StatusVisitor {
        /// Converts `val` into a [`StatusCode`], rejecting anything that is not a
        /// valid code (`StatusCode::from_u16` only accepts `100..1000`).
        #[inline(never)]
        fn make<E: de::Error>(&self, val: u64) -> Result<StatusCode, E> {
            // `u16::try_from` instead of `as u16`: a cast would silently truncate
            // e.g. 65_736 to 200 and accept it.
            u16::try_from(val)
                .ok()
                .and_then(|code| StatusCode::from_u16(code).ok())
                .ok_or_else(|| de::Error::invalid_value(Unexpected::Unsigned(val), self))
        }
    }

    impl<'de> Visitor<'de> for StatusVisitor {
        type Value = StatusCode;

        #[inline]
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("status code")
        }

        #[inline]
        fn visit_some<D: Deserializer<'de>>(
            self,
            deserializer: D,
        ) -> Result<Self::Value, D::Error> {
            deserializer.deserialize_u16(self)
        }

        #[inline]
        fn visit_i64<E: de::Error>(self, val: i64) -> Result<Self::Value, E> {
            // `val as u64` would wrap a negative input into a huge unsigned number
            // and report that bogus value in the error; report the signed input.
            match u64::try_from(val) {
                Ok(unsigned) => self.make(unsigned),
                Err(_) => Err(de::Error::invalid_value(Unexpected::Signed(val), &self)),
            }
        }

        #[inline]
        fn visit_u64<E: de::Error>(self, val: u64) -> Result<Self::Value, E> {
            self.make(val)
        }
    }

    /// Deserializes a [`StatusCode`] from an integer.
    ///
    /// Implementation detail of `#[serde(with = "status_code")]`.
    ///
    /// # Errors
    ///
    /// Fails with an `invalid_value` error when the integer is negative or is not
    /// a valid status code.
    #[inline]
    pub fn deserialize<'de, D>(de: D) -> Result<StatusCode, D::Error>
    where
        D: Deserializer<'de>,
    {
        de.deserialize_u16(StatusVisitor)
    }
}