summaryrefslogtreecommitdiffstats
path: root/src/service/command.rs
blob: 3535ee2e3c6d06a12da2de218942d0008f2f2bff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use std::{process::Stdio, time::Duration};

use async_stream::stream;
use futures::{Stream, StreamExt};
use serde::Deserialize;
use tokio::io::{AsyncBufReadExt, BufReader};

use super::IntoService;

/// Errors yielded by the command-backed service streams in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An I/O failure while spawning the command, reading its stdout, or waiting on it.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Interval mode: the command exited with a non-zero `code`; `stderr` is its
    /// captured stderr, decoded lossily as UTF-8.
    #[error("Exited with status code: {code}\n{stderr}")]
    ExitCode { code: i32, stderr: String },
    /// The process finished without an exit code (per std, on Unix this means it
    /// was terminated by a signal).
    #[error("Process terminated by signal")]
    Signal,
    /// A `serde_json` error. Not constructed in this module; presumably kept for
    /// callers or other code paths. TODO confirm it is still needed.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::error::Error),
    /// Raw stderr text. Not constructed in this module. TODO confirm usage.
    #[error("{0}")]
    Stderr(String),
    /// Persist mode: a stdout line from the child that was not exactly `Ok`.
    #[error("{0}")]
    Output(String),
    /// Persist mode: the long-lived child exited with a non-zero code.
    #[error("Exited with status code: {0}")]
    PersistExitCode(i32),
    /// The child's stderr handle was unavailable. Not constructed in this module.
    #[error("Failed to get stderr of child process")]
    NoStderr,
    /// The child's stdout handle was unavailable (it was not piped or was already taken).
    #[error("Failed to get stdout of child process")]
    NoStdout,
}

/// Deserializable configuration for a service backed by an external command.
#[derive(Debug, Clone, Deserialize)]
pub struct Command {
    /// Program to execute (passed to `tokio::process::Command::new`).
    pub command: String,
    /// Arguments passed to the program. Required: this field has no serde default.
    pub args: Vec<String>,
    /// When `true`, the command is kept running and each stdout line is treated
    /// as a status report. When `false` (the default), the command is run to
    /// completion once per `interval`.
    #[serde(default)]
    pub persist: bool,
    /// Tick period between runs (or restarts, in persist mode). Falls back to
    /// `super::default_interval()` when absent. Uses serde's `Duration` impl.
    #[serde(default = "super::default_interval")]
    pub interval: Duration,
}

impl Command {
    #[tracing::instrument]
    fn persist(
        mut interval: tokio::time::Interval,
        mut command: tokio::process::Command,
    ) -> impl Stream<Item = Result<(), Error>> {
        stream! {
            loop {
                interval.tick().await;

                let mut child = command
                    .stdout(Stdio::piped())
                    .stderr(Stdio::piped())
                    .spawn()?;

                let mut stdout_reader =
                    BufReader::new(child.stdout.take().ok_or(Error::NoStdout)?).lines();

                while let Some(line) = stdout_reader.next_line().await? {
                    if "Ok" == line {
                        yield Ok(());
                    } else {
                        yield Err(Error::Output(line))
                    }
                }

                match child.wait().await?.code() {
                    Some(0) => yield Ok(()),
                    Some(code) => yield Err(Error::PersistExitCode(code)),
                    None => yield Err(Error::Signal),
                };
            }
        }
    }

    #[tracing::instrument]
    fn interval(
        mut interval: tokio::time::Interval,
        mut command: tokio::process::Command,
    ) -> impl Stream<Item = Result<(), Error>> {
        stream! {
            loop {
                interval.tick().await;
                let output = command.output().await?;
                match output.status.code() {
                    Some(0) => yield Ok(()),
                    Some(code) => {
                        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
                        yield Err(Error::ExitCode { code, stderr })
                    }
                    None => yield Err(Error::Signal),
                }
            }
        }
    }
}

impl IntoService for Command {
    type Error = Error;

    /// Builds the stream for this command.
    ///
    /// When `persist` is set, the command runs as a long-lived process whose
    /// stdout lines report status (see `Command::persist`). Otherwise it is
    /// re-run to completion on every tick (see `Command::interval`).
    fn into_service(self) -> impl Stream<Item = Result<(), Self::Error>> {
        let mut interval = tokio::time::interval(self.interval);
        // Tokio's default `Burst` behaviour fires every missed tick back-to-back
        // once a run overruns the period. For example, after a persistent process
        // has been up for a long time and then exits, the command would be
        // restarted repeatedly with no delay. `Delay` fires the overdue tick once
        // and then waits a full period again.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        let mut command = tokio::process::Command::new(self.command);
        command.args(self.args);

        if self.persist {
            Self::persist(interval, command).boxed()
        } else {
            Self::interval(interval, command).boxed()
        }
    }
}