diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 104 | ||||
-rw-r--r-- | src/error.rs | 30 | ||||
-rw-r--r-- | src/lib.rs | 66 | ||||
-rw-r--r-- | src/main.rs | 89 | ||||
-rw-r--r-- | src/netlink.rs | 46 | ||||
-rw-r--r-- | src/tmux.rs | 53 |
6 files changed, 150 insertions, 238 deletions
diff --git a/src/config.rs b/src/config.rs index c5f6db5..44d500e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, str::FromStr, sync::atomic::AtomicBool}; +use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::atomic::AtomicBool}; use clap::Parser; #[derive(Debug, Clone, Parser)] @@ -11,72 +11,74 @@ pub struct Config { #[arg(short, long)] pub resolve: bool, - /// include host in output - #[arg(short, long)] - pub include: Vec<sshr::Host>, + /// include <NAME>. If <NAME> is a valid path or '-', hosts with be read from the file or + /// stdin, respectivly. + #[arg(short, long, id = "NAME")] + pub include: Vec<IncludeExclude>, - /// include lines from file, use '-' for stdin - #[arg(short = 'I', long)] - pub include_file: Vec<FileOrStdin>, + /// include <NAME>. If <NAME> is a valid path or '-', hosts with be read from the file or + /// stdin, respectivly. + #[arg(short, long, id = "HOST")] + pub exclude: Vec<IncludeExclude>, +} - /// exclude host from output - #[arg(short, long)] - pub exclude: Vec<sshr::Host>, +impl Config { + pub fn included(&self) -> std::io::Result<HashSet<String>> { + Self::collect_values(&self.include) + } - /// include lines from file, use '-' for stdin - #[arg(short = 'E', long)] - pub exclude_file: Vec<FileOrStdin>, + pub fn excluded(&self) -> std::io::Result<HashSet<String>> { + Self::collect_values(&self.exclude) + } + + fn collect_values(values: &[IncludeExclude]) -> std::io::Result<HashSet<String>> { + use std::io::BufRead; + values.iter().try_fold(HashSet::new(), |mut acc, item| { + match item { + IncludeExclude::Stdin => { + acc.extend(std::io::stdin().lock().lines().map_while(Result::ok)) + } + IncludeExclude::File(filepath) => acc.extend( + std::io::BufReader::new(std::fs::File::open(filepath)?) + .lines() + .map_while(Result::ok), + ), + IncludeExclude::Item(s) => { + acc.insert(s.to_owned()); + } + } + Ok(acc) + }) + } } static READ_STDIN: AtomicBool = AtomicBool::new(false); #[derive(Debug, Clone)] -pub enum FileOrStdin { +pub enum IncludeExclude { Stdin, + Item(String), File(PathBuf), } -impl FileOrStdin { - pub fn hosts(self) -> Result<Vec<sshr::Host>, anyhow::Error> { - use std::io::Read; - let mut buf = String::new(); - let _ = self.into_reader()?.read_to_string(&mut buf)?; - buf.lines() - .map(|s| s.trim_end().parse()) - .collect::<Result<Vec<_>, _>>() - .map_err(|e| anyhow::format_err!("{e}")) - } - - pub fn into_reader(&self) -> Result<impl std::io::Read, anyhow::Error> { - let input: Box<dyn std::io::Read + 'static> = match &self { - Self::Stdin => Box::new(std::io::stdin()), - Self::File(filepath) => { - let f = std::fs::File::open(filepath)?; - Box::new(f) - } - }; - Ok(input) - } -} - -impl FromStr for FileOrStdin { +impl FromStr for IncludeExclude { type Err = std::io::Error; fn from_str(s: &str) -> Result<Self, Self::Err> { - match s { - "-" => { - if READ_STDIN.load(std::sync::atomic::Ordering::Acquire) { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "stdin argument used more than once", - )); - } - READ_STDIN.store(true, std::sync::atomic::Ordering::SeqCst); - Ok(Self::Stdin) + if s == "-" { + if READ_STDIN.load(std::sync::atomic::Ordering::Acquire) { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "stdin argument used more than once", + )); + } + READ_STDIN.store(true, std::sync::atomic::Ordering::SeqCst); + Ok(Self::Stdin) + } else { + match PathBuf::from_str(s) { + Ok(path) if path.exists() => Ok(Self::File(path)), + _ => Ok(Self::Item(s.to_owned())), } - path => PathBuf::from_str(path) - .map(Self::File) - .map_err(|_| unreachable!()), } } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..9727a81 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,30 @@ +pub type Result<T, E = Error> = std::result::Result<T, E>; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Netlink error: {0:?}")] + Netlink(#[from] rtnetlink::Error), + + #[error("IO error: {0}")] + IO(#[from] std::io::Error), + + #[error("Send error: {0}")] + Send(String), + + #[error("Lock error: {0}")] + Lock(String), +} + +impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error { + fn from(err: tokio::sync::mpsc::error::SendError<T>) -> Self { + Self::Send(err.to_string()) + } +} + +impl<T> From<std::sync::PoisonError<T>> for Error { + fn from(err: std::sync::PoisonError<T>) -> Self { + Self::Lock(err.to_string()) + } +} @@ -1,66 +1,4 @@ -use std::{ - convert::Infallible, - fmt::Display, - net::{IpAddr, Ipv4Addr, Ipv6Addr}, - str::FromStr, -}; +pub use error::{Error, Result}; +pub mod error; pub mod netlink; -pub mod tmux; - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Host { - IpAddr(IpAddr), - Hostname(String), -} - -impl Host { - pub fn resolve(&self) -> std::io::Result<Self> { - match self { - Self::IpAddr(ip) => dns_lookup::lookup_addr(ip).map(Self::Hostname), - Self::Hostname(h) => dns_lookup::lookup_host(h)? - .first() - .cloned() - .map(Self::IpAddr) - .ok_or(std::io::Error::new( - std::io::ErrorKind::NotFound, - "resolution not enabled", - )), - } - } -} - -impl FromStr for Host { - type Err = Infallible; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - Ok(IpAddr::from_str(s).map_or_else(|_| Self::Hostname(s.to_string()), Host::from)) - } -} - -impl From<IpAddr> for Host { - fn from(value: IpAddr) -> Self { - Self::IpAddr(value) - } -} - -impl From<Ipv4Addr> for Host { - fn from(value: Ipv4Addr) -> Self { - IpAddr::from(value).into() - } -} - -impl From<Ipv6Addr> for Host { - fn from(value: Ipv6Addr) -> Self { - IpAddr::from(value).into() - } -} - -impl Display for Host { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Host::IpAddr(i) => write!(f, "{}", i), - Host::Hostname(s) => write!(f, "{}", s), - } - } -} diff --git a/src/main.rs b/src/main.rs index 5fc8435..7dd9390 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,66 +1,75 @@ use std::{ - collections::HashSet, io::{stdout, Write}, + net::IpAddr, + str::FromStr, + sync::{Arc, Mutex}, }; use clap::Parser; use config::Config; -use tokio::sync::mpsc; +use dns_lookup::{lookup_addr, lookup_host}; +use tokio::{sync::mpsc, task::JoinSet}; mod config; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { let config = Config::parse(); + let mut join_set = JoinSet::new(); + let excluded = config.excluded()?; + + let cache = Arc::new(Mutex::new(excluded)); let (tx, mut rx) = mpsc::channel(100); - for host in config - .include_file - .into_iter() - .flat_map(|i| i.hosts()) - .flatten() - .chain(config.include) - { - tx.send(host).await?; - } + for host in config.included()? { + match IpAddr::from_str(&host) { + Ok(ip) => tx.send(ip).await?, + Err(_) => { + if cache.lock().unwrap().insert(host.clone()) { + let mut stdout = stdout().lock(); + writeln!(stdout, "{host}")?; + } - sshr::tmux::sessions(tx.clone(), &config.socket).await?; + if config.resolve { + let cache = cache.clone(); + join_set.spawn(async move { + for ip in lookup_host(&host)? { + cache.lock()?.insert(ip.to_string()); + } + Ok(()) + }); + } + } + } + } - tokio::spawn(sshr::netlink::neighbours(tx.clone())); + join_set.spawn(sshr::netlink::neighbours(tx.clone())); drop(tx); - let mut cache = HashSet::new(); - - for host in config - .exclude_file - .into_iter() - .flat_map(|i| i.hosts()) - .flatten() - .chain(config.exclude) - { - cache.insert(host); - } - - let mut stdout = stdout(); - while let Some(host) = rx.recv().await { - if !cache.insert(host.clone()) { - continue; - } + while let Some(ip_addr) = rx.recv().await { + join_set.spawn({ + let cache = cache.clone(); + async move { + let s = if config.resolve { + lookup_addr(&ip_addr).unwrap_or_else(|_| ip_addr.to_string()) + } else { + ip_addr.to_string() + }; - let resolved = config.resolve.then_some(host.resolve().ok()).flatten(); - - if let Some(res) = resolved.clone() { - if !cache.insert(res) { - continue; + if cache.lock()?.insert(s.clone()) { + let mut stdout = stdout().lock(); + writeln!(stdout, "{s}")?; + } + Ok(()) } - } + }); + } - if let Some(sshr::Host::Hostname(r)) = resolved { - writeln!(stdout, "{r}")?; - } else { - writeln!(stdout, "{host}")?; + while let Some(res) = join_set.join_next().await { + if let Err(err) = res { + eprintln!("{err}") } } diff --git a/src/netlink.rs b/src/netlink.rs index 550f84e..feaac90 100644 --- a/src/netlink.rs +++ b/src/netlink.rs @@ -1,4 +1,6 @@ -use futures::{stream::TryStreamExt, FutureExt}; +use std::net::IpAddr; + +use futures::stream::TryStreamExt; use netlink_packet_route::{ neighbour::{NeighbourAddress, NeighbourAttribute, NeighbourMessage}, route::RouteType, @@ -6,11 +8,11 @@ use netlink_packet_route::{ use rtnetlink::{new_connection, IpVersion}; use tokio::sync::mpsc::Sender; -use crate::Host; +use crate::{Error, Result}; pub struct Netlink; -pub async fn neighbours(tx: Sender<Host>) -> Result<(), rtnetlink::Error> { +pub async fn neighbours(tx: Sender<IpAddr>) -> Result<()> { let (connection, handle, _) = new_connection().unwrap(); tokio::spawn(connection); @@ -19,39 +21,23 @@ pub async fn neighbours(tx: Sender<Host>) -> Result<(), rtnetlink::Error> { .get() .set_family(IpVersion::V4) .execute() - .try_filter_map(|r| async move { Ok(filter(r)) }) - .try_for_each(|host| tx.send(host).then(|_| async { Ok(()) })) + .or_else(|res| async { Err(Error::from(res)) }) + .try_filter_map(|msg| async { Ok(filter(msg)) }) + .try_for_each(|host| { + let tx = tx.clone(); + async move { tx.send(host).await.map_err(Error::from) } + }) .await } -pub fn filter(route: NeighbourMessage) -> Option<Host> { - if route.header.kind != RouteType::Unicast { +pub fn filter(msg: NeighbourMessage) -> Option<IpAddr> { + if msg.header.kind != RouteType::Unicast { return None; } - route.attributes.into_iter().find_map(|attr| match attr { - NeighbourAttribute::Destination(NeighbourAddress::Inet(ip)) => Some(Host::from(ip)), - NeighbourAttribute::Destination(NeighbourAddress::Inet6(ip)) => Some(Host::from(ip)), + msg.attributes.into_iter().find_map(|attr| match attr { + NeighbourAttribute::Destination(NeighbourAddress::Inet(ip)) => Some(ip.into()), + NeighbourAttribute::Destination(NeighbourAddress::Inet6(ip)) => Some(ip.into()), _ => None, }) } - -#[cfg(test)] -mod tests { - use tokio::sync::mpsc; - - use super::*; - - #[tokio::test] - async fn test_dump_neighbours() -> Result<(), ()> { - let (tx, mut rx) = mpsc::channel::<Host>(100); - - tokio::spawn(neighbours(tx)); - - while let Some(res) = rx.recv().await { - println!("{res}"); - } - - Ok(()) - } -} diff --git a/src/tmux.rs b/src/tmux.rs deleted file mode 100644 index e318c07..0000000 --- a/src/tmux.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::{cmp::Reverse, collections::BTreeMap, ffi::OsStr, process::Command}; - -use tokio::sync::mpsc::Sender; - -use crate::Host; - -pub async fn sessions<S>(tx: Sender<Host>, socket: &S) -> Result<(), anyhow::Error> -where - S: AsRef<OsStr>, -{ - let stdout = Command::new("tmux") - .arg("-L") - .arg(socket) - .arg("list-sessions") - .arg("-F") - .arg("#{?session_last_attached,#{session_last_attached},#{session_created}}:#{s/_/./:session_name}") - .output()? - .stdout; - - let mut btree_map: BTreeMap<Reverse<usize>, String> = std::str::from_utf8(&stdout)? - .lines() - .flat_map(|s| { - let (t, s) = s.split_once(':')?; - Some((Reverse(t.parse().ok()?), s.to_string())) - }) - .collect(); - - let stdout = Command::new("tmux") - .arg("-L") - .arg("default") - .arg("list-sessions") - .arg("-F") - .arg("#{?session_last_attached,#{session_last_attached},#{session_created}}:#{host}") - .output()? - .stdout; - - if let Some((t, s)) = std::str::from_utf8(&stdout)? - .lines() - .flat_map(|s| { - let (t, s) = s.split_once(':')?; - Some((t.parse().ok()?, s.to_string())) - }) - .max_by_key(|t| t.0) - { - btree_map.insert(Reverse(t), s); - } - - for name in btree_map.into_values() { - tx.send(name.parse()?).await?; - } - - Ok(()) -} |