diff options
author | Toby Vincent <tobyv@tobyvin.dev> | 2024-02-24 18:09:58 -0600 |
---|---|---|
committer | Toby Vincent <tobyv@tobyvin.dev> | 2024-02-24 18:09:58 -0600 |
commit | 04814c4996140871674d7ce5552f55d9ba07615a (patch) | |
tree | 804e57beed1a4d3c2b8908523c0453f4655b5bbd | |
parent | 12cc1358ad636c194a2464561939ff72fb8aaa9c (diff) |
feat!: remove tmux functionalilty and more async
-rw-r--r-- | Cargo.lock | 107 | ||||
-rw-r--r-- | Cargo.toml | 4 | ||||
-rw-r--r-- | src/config.rs | 104 | ||||
-rw-r--r-- | src/error.rs | 30 | ||||
-rw-r--r-- | src/lib.rs | 66 | ||||
-rw-r--r-- | src/main.rs | 89 | ||||
-rw-r--r-- | src/netlink.rs | 46 | ||||
-rw-r--r-- | src/tmux.rs | 53 |
8 files changed, 160 insertions, 339 deletions
@@ -94,12 +94,6 @@ dependencies = [ [[package]] name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" @@ -292,9 +286,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" +checksum = "bd5256b483761cd23699d0da46cc6fd2ee3be420bbe6d020ae4a091e70b7e9fd" [[package]] name = "libc" @@ -303,16 +297,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] -name = "lock_api" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] name = "log" version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -415,7 +399,7 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 2.4.2", + "bitflags", "cfg-if", "libc", ] @@ -440,29 +424,6 @@ dependencies = [ ] [[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets 0.48.5", -] - -[[package]] name = "paste" version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -499,15 +460,6 @@ dependencies = [ ] [[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] name = "rtnetlink" version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -532,41 +484,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "serde" -version = "1.0.196" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.196" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "signal-hook-registry" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" -dependencies = [ - "libc", -] - -[[package]] name = "slab" version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -576,12 +493,6 @@ dependencies = [ ] [[package]] -name = "smallvec" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" - -[[package]] name = "socket2" version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -601,7 +512,7 @@ dependencies = [ "futures", "netlink-packet-route", "rtnetlink", - "serde", + "thiserror", "tokio", ] @@ -624,18 +535,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", @@ -653,9 +564,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot", "pin-project-lite", - "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", @@ -10,9 +10,9 @@ repository = "https://git.sr.ht/~tobyvin/sshr" [dependencies] anyhow = "1.0.69" clap = { version = "4.1.9", features = ["derive"] } -serde = { version = "1.0.156", features = ["derive"] } rtnetlink = "0.14.0" futures = "0.3.30" -tokio = { version = "1.35.1", features = ["full"] } +tokio = { version = "1.35.1", features = ["rt-multi-thread", "macros", "sync"] } dns-lookup = "2.0.4" netlink-packet-route = "0.19.0" +thiserror = "1.0.57" diff --git a/src/config.rs b/src/config.rs index c5f6db5..44d500e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, str::FromStr, sync::atomic::AtomicBool}; +use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::atomic::AtomicBool}; use clap::Parser; #[derive(Debug, Clone, Parser)] @@ -11,72 +11,74 @@ pub struct Config { #[arg(short, long)] pub resolve: bool, - /// include host in output - #[arg(short, long)] - pub include: Vec<sshr::Host>, + /// include <NAME>. If <NAME> is a valid path or '-', hosts with be read from the file or + /// stdin, respectivly. + #[arg(short, long, id = "NAME")] + pub include: Vec<IncludeExclude>, - /// include lines from file, use '-' for stdin - #[arg(short = 'I', long)] - pub include_file: Vec<FileOrStdin>, + /// include <NAME>. If <NAME> is a valid path or '-', hosts with be read from the file or + /// stdin, respectivly. + #[arg(short, long, id = "HOST")] + pub exclude: Vec<IncludeExclude>, +} - /// exclude host from output - #[arg(short, long)] - pub exclude: Vec<sshr::Host>, +impl Config { + pub fn included(&self) -> std::io::Result<HashSet<String>> { + Self::collect_values(&self.include) + } - /// include lines from file, use '-' for stdin - #[arg(short = 'E', long)] - pub exclude_file: Vec<FileOrStdin>, + pub fn excluded(&self) -> std::io::Result<HashSet<String>> { + Self::collect_values(&self.exclude) + } + + fn collect_values(values: &[IncludeExclude]) -> std::io::Result<HashSet<String>> { + use std::io::BufRead; + values.iter().try_fold(HashSet::new(), |mut acc, item| { + match item { + IncludeExclude::Stdin => { + acc.extend(std::io::stdin().lock().lines().map_while(Result::ok)) + } + IncludeExclude::File(filepath) => acc.extend( + std::io::BufReader::new(std::fs::File::open(filepath)?) + .lines() + .map_while(Result::ok), + ), + IncludeExclude::Item(s) => { + acc.insert(s.to_owned()); + } + } + Ok(acc) + }) + } } static READ_STDIN: AtomicBool = AtomicBool::new(false); #[derive(Debug, Clone)] -pub enum FileOrStdin { +pub enum IncludeExclude { Stdin, + Item(String), File(PathBuf), } -impl FileOrStdin { - pub fn hosts(self) -> Result<Vec<sshr::Host>, anyhow::Error> { - use std::io::Read; - let mut buf = String::new(); - let _ = self.into_reader()?.read_to_string(&mut buf)?; - buf.lines() - .map(|s| s.trim_end().parse()) - .collect::<Result<Vec<_>, _>>() - .map_err(|e| anyhow::format_err!("{e}")) - } - - pub fn into_reader(&self) -> Result<impl std::io::Read, anyhow::Error> { - let input: Box<dyn std::io::Read + 'static> = match &self { - Self::Stdin => Box::new(std::io::stdin()), - Self::File(filepath) => { - let f = std::fs::File::open(filepath)?; - Box::new(f) - } - }; - Ok(input) - } -} - -impl FromStr for FileOrStdin { +impl FromStr for IncludeExclude { type Err = std::io::Error; fn from_str(s: &str) -> Result<Self, Self::Err> { - match s { - "-" => { - if READ_STDIN.load(std::sync::atomic::Ordering::Acquire) { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "stdin argument used more than once", - )); - } - READ_STDIN.store(true, std::sync::atomic::Ordering::SeqCst); - Ok(Self::Stdin) + if s == "-" { + if READ_STDIN.load(std::sync::atomic::Ordering::Acquire) { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "stdin argument used more than once", + )); + } + READ_STDIN.store(true, std::sync::atomic::Ordering::SeqCst); + Ok(Self::Stdin) + } else { + match PathBuf::from_str(s) { + Ok(path) if path.exists() => Ok(Self::File(path)), + _ => Ok(Self::Item(s.to_owned())), } - path => PathBuf::from_str(path) - .map(Self::File) - .map_err(|_| unreachable!()), } } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..9727a81 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,30 @@ +pub type Result<T, E = Error> = std::result::Result<T, E>; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Netlink error: {0:?}")] + Netlink(#[from] rtnetlink::Error), + + #[error("IO error: {0}")] + IO(#[from] std::io::Error), + + #[error("Send error: {0}")] + Send(String), + + #[error("Lock error: {0}")] + Lock(String), +} + +impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error { + fn from(err: tokio::sync::mpsc::error::SendError<T>) -> Self { + Self::Send(err.to_string()) + } +} + +impl<T> From<std::sync::PoisonError<T>> for Error { + fn from(err: std::sync::PoisonError<T>) -> Self { + Self::Lock(err.to_string()) + } +} @@ -1,66 +1,4 @@ -use std::{ - convert::Infallible, - fmt::Display, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, - str::FromStr, -}; +pub use error::{Error, Result}; +pub mod error; pub mod netlink; -pub mod tmux; - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Host { - IpAddr(IpAddr), - Hostname(String), -} - -impl Host { - pub fn resolve(&self) -> std::io::Result<Self> { - match self { - Self::IpAddr(ip) => dns_lookup::lookup_addr(ip).map(Self::Hostname), - Self::Hostname(h) => dns_lookup::lookup_host(h)? - .first() - .cloned() - .map(Self::IpAddr) - .ok_or(std::io::Error::new( - std::io::ErrorKind::NotFound, - "resolution not enabled", - )), - } - } -} - -impl FromStr for Host { - type Err = Infallible; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - Ok(IpAddr::from_str(s).map_or_else(|_| Self::Hostname(s.to_string()), Host::from)) - } -} - -impl From<IpAddr> for Host { - fn from(value: IpAddr) -> Self { - Self::IpAddr(value) - } -} - -impl From<Ipv4Addr> for Host { - fn from(value: Ipv4Addr) -> Self { - IpAddr::from(value).into() - } -} - -impl From<Ipv6Addr> for Host { - fn from(value: Ipv6Addr) -> Self { - IpAddr::from(value).into() - } -} - -impl Display for Host { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Host::IpAddr(i) => write!(f, "{}", i), - Host::Hostname(s) => write!(f, "{}", s), - } - } -} diff --git a/src/main.rs b/src/main.rs index 5fc8435..7dd9390 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,66 +1,75 @@ use std::{ - collections::HashSet, io::{stdout, Write}, + net::IpAddr, + str::FromStr, + sync::{Arc, Mutex}, }; use clap::Parser; use config::Config; -use tokio::sync::mpsc; +use dns_lookup::{lookup_addr, lookup_host}; +use tokio::{sync::mpsc, task::JoinSet}; mod config; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { let config = Config::parse(); + let mut join_set = JoinSet::new(); + let excluded = config.excluded()?; + + let cache = Arc::new(Mutex::new(excluded)); let (tx, mut rx) = mpsc::channel(100); - for host in config - .include_file - .into_iter() - .flat_map(|i| i.hosts()) - .flatten() - .chain(config.include) - { - tx.send(host).await?; - } + for host in config.included()? { + match IpAddr::from_str(&host) { + Ok(ip) => tx.send(ip).await?, + Err(_) => { + if cache.lock().unwrap().insert(host.clone()) { + let mut stdout = stdout().lock(); + writeln!(stdout, "{host}")?; + } - sshr::tmux::sessions(tx.clone(), &config.socket).await?; + if config.resolve { + let cache = cache.clone(); + join_set.spawn(async move { + for ip in lookup_host(&host)? { + cache.lock()?.insert(ip.to_string()); + } + Ok(()) + }); + } + } + } + } - tokio::spawn(sshr::netlink::neighbours(tx.clone())); + join_set.spawn(sshr::netlink::neighbours(tx.clone())); drop(tx); - let mut cache = HashSet::new(); - - for host in config - .exclude_file - .into_iter() - .flat_map(|i| i.hosts()) - .flatten() - .chain(config.exclude) - { - cache.insert(host); - } - - let mut stdout = stdout(); - while let Some(host) = rx.recv().await { - if !cache.insert(host.clone()) { - continue; - } + while let Some(ip_addr) = rx.recv().await { + join_set.spawn({ + let cache = cache.clone(); + async move { + let s = if config.resolve { + lookup_addr(&ip_addr).unwrap_or_else(|_| ip_addr.to_string()) + } else { + ip_addr.to_string() + }; - let resolved = config.resolve.then_some(host.resolve().ok()).flatten(); - - if let Some(res) = resolved.clone() { - if !cache.insert(res) { - continue; + if cache.lock()?.insert(s.clone()) { + let mut stdout = stdout().lock(); + writeln!(stdout, "{s}")?; + } + Ok(()) } - } + }); + } - if let Some(sshr::Host::Hostname(r)) = resolved { - writeln!(stdout, "{r}")?; - } else { - writeln!(stdout, "{host}")?; + while let Some(res) = join_set.join_next().await { + if let Err(err) = res { + eprintln!("{err}") } } diff --git a/src/netlink.rs b/src/netlink.rs index 550f84e..feaac90 100644 --- a/src/netlink.rs +++ b/src/netlink.rs @@ -1,4 +1,6 @@ -use futures::{stream::TryStreamExt, FutureExt}; +use std::net::IpAddr; + +use futures::stream::TryStreamExt; use netlink_packet_route::{ neighbour::{NeighbourAddress, NeighbourAttribute, NeighbourMessage}, route::RouteType, @@ -6,11 +8,11 @@ use netlink_packet_route::{ use rtnetlink::{new_connection, IpVersion}; use tokio::sync::mpsc::Sender; -use crate::Host; +use crate::{Error, Result}; pub struct Netlink; -pub async fn neighbours(tx: Sender<Host>) -> Result<(), rtnetlink::Error> { +pub async fn neighbours(tx: Sender<IpAddr>) -> Result<()> { let (connection, handle, _) = new_connection().unwrap(); tokio::spawn(connection); @@ -19,39 +21,23 @@ pub async fn neighbours(tx: Sender<Host>) -> Result<(), rtnetlink::Error> { .get() .set_family(IpVersion::V4) .execute() - .try_filter_map(|r| async move { Ok(filter(r)) }) - .try_for_each(|host| tx.send(host).then(|_| async { Ok(()) })) + .or_else(|res| async { Err(Error::from(res)) }) + .try_filter_map(|msg| async { Ok(filter(msg)) }) + .try_for_each(|host| { + let tx = tx.clone(); + async move { tx.send(host).await.map_err(Error::from) } + }) .await } -pub fn filter(route: NeighbourMessage) -> Option<Host> { - if route.header.kind != RouteType::Unicast { +pub fn filter(msg: NeighbourMessage) -> Option<IpAddr> { + if msg.header.kind != RouteType::Unicast { return None; } - route.attributes.into_iter().find_map(|attr| match attr { - NeighbourAttribute::Destination(NeighbourAddress::Inet(ip)) => Some(Host::from(ip)), - NeighbourAttribute::Destination(NeighbourAddress::Inet6(ip)) => Some(Host::from(ip)), + msg.attributes.into_iter().find_map(|attr| match attr { + NeighbourAttribute::Destination(NeighbourAddress::Inet(ip)) => Some(ip.into()), + NeighbourAttribute::Destination(NeighbourAddress::Inet6(ip)) => Some(ip.into()), _ => None, }) } - -#[cfg(test)] -mod tests { - use tokio::sync::mpsc; - - use super::*; - - #[tokio::test] - async fn test_dump_neighbours() -> Result<(), ()> { - let (tx, mut rx) = mpsc::channel::<Host>(100); - - tokio::spawn(neighbours(tx)); - - while let Some(res) = rx.recv().await { - println!("{res}"); - } - - Ok(()) - } -} diff --git a/src/tmux.rs b/src/tmux.rs deleted file mode 100644 index e318c07..0000000 --- a/src/tmux.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::{cmp::Reverse, collections::BTreeMap, ffi::OsStr, process::Command}; - -use tokio::sync::mpsc::Sender; - -use crate::Host; - -pub async fn sessions<S>(tx: Sender<Host>, socket: &S) -> Result<(), anyhow::Error> -where - S: AsRef<OsStr>, -{ - let stdout = Command::new("tmux") - .arg("-L") - .arg(socket) - .arg("list-sessions") - .arg("-F") - .arg("#{?session_last_attached,#{session_last_attached},#{session_created}}:#{s/_/./:session_name}") - .output()? - .stdout; - - let mut btree_map: BTreeMap<Reverse<usize>, String> = std::str::from_utf8(&stdout)? - .lines() - .flat_map(|s| { - let (t, s) = s.split_once(':')?; - Some((Reverse(t.parse().ok()?), s.to_string())) - }) - .collect(); - - let stdout = Command::new("tmux") - .arg("-L") - .arg("default") - .arg("list-sessions") - .arg("-F") - .arg("#{?session_last_attached,#{session_last_attached},#{session_created}}:#{host}") - .output()? - .stdout; - - if let Some((t, s)) = std::str::from_utf8(&stdout)? - .lines() - .flat_map(|s| { - let (t, s) = s.split_once(':')?; - Some((t.parse().ok()?, s.to_string())) - }) - .max_by_key(|t| t.0) - { - btree_map.insert(Reverse(t), s); - } - - for name in btree_map.into_values() { - tx.send(name.parse()?).await?; - } - - Ok(()) -} |