use std::{fmt::Display, net::SocketAddr, time::Duration}; use futures::Stream; use serde::Deserialize; use tokio::{io::Interest, net::TcpSocket, sync::watch::Sender}; use tokio_stream::wrappers::WatchStream; use crate::{Error, Status}; use super::ServiceSpawner; #[derive(Debug, Clone, Deserialize)] pub struct Tcp { pub address: SocketAddr, } impl Display for Tcp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "tcp://{}", self.address) } } impl ServiceSpawner for Tcp { #[tracing::instrument(skip(tx))] async fn spawn(self, tx: Sender) -> Result<(), Error> { let mut interval = tokio::time::interval(Duration::from_secs(5)); loop { interval.tick().await; let sock = TcpSocket::new_v4()?; sock.set_keepalive(true)?; match sock.connect(self.address).await { Ok(conn) => { tracing::info!("Connected"); tx.send_if_modified(|s| s.update(Status::Pass)); conn.ready(Interest::ERROR).await?; tx.send_replace(Status::Fail(Some("Disconnected".into()))); tracing::info!("Disconnected"); } Err(err) => { tracing::error!("Failed to connect"); tx.send_if_modified(|s| s.update(err.into())); } }; } } } impl Tcp { pub fn into_stream(self) -> impl Stream { let (tx, rx) = tokio::sync::watch::channel(Status::default()); tokio::spawn(self.spawn(tx)); WatchStream::new(rx) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] #[tracing_test::traced_test] #[ignore] async fn test_tcp_watch() { let (tx, mut rx) = tokio::sync::watch::channel(Status::default()); let tests = tokio::spawn(async move { assert!(matches!(*rx.borrow_and_update(), Status::Unknown)); rx.changed().await.unwrap(); assert!(matches!(*rx.borrow_and_update(), Status::Pass)); rx.changed().await.unwrap(); assert_eq!( *rx.borrow_and_update(), Status::Fail(Some(String::from("Disconnected"))) ); }); let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { Tcp { address }.spawn(tx).await }); listener.accept().await.unwrap(); drop(listener); tests.await.unwrap() } }