summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/config.rs104
-rw-r--r--src/error.rs30
-rw-r--r--src/lib.rs66
-rw-r--r--src/main.rs89
-rw-r--r--src/netlink.rs46
-rw-r--r--src/tmux.rs53
6 files changed, 150 insertions, 238 deletions
diff --git a/src/config.rs b/src/config.rs
index c5f6db5..44d500e 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,4 +1,4 @@
-use std::{path::PathBuf, str::FromStr, sync::atomic::AtomicBool};
+use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::atomic::AtomicBool};
use clap::Parser;
#[derive(Debug, Clone, Parser)]
@@ -11,72 +11,74 @@ pub struct Config {
#[arg(short, long)]
pub resolve: bool,
- /// include host in output
- #[arg(short, long)]
- pub include: Vec<sshr::Host>,
+ /// include <NAME>. If <NAME> is a valid path or '-', hosts with be read from the file or
+ /// stdin, respectivly.
+ #[arg(short, long, id = "NAME")]
+ pub include: Vec<IncludeExclude>,
- /// include lines from file, use '-' for stdin
- #[arg(short = 'I', long)]
- pub include_file: Vec<FileOrStdin>,
+ /// include <NAME>. If <NAME> is a valid path or '-', hosts with be read from the file or
+ /// stdin, respectivly.
+ #[arg(short, long, id = "HOST")]
+ pub exclude: Vec<IncludeExclude>,
+}
- /// exclude host from output
- #[arg(short, long)]
- pub exclude: Vec<sshr::Host>,
+impl Config {
+ pub fn included(&self) -> std::io::Result<HashSet<String>> {
+ Self::collect_values(&self.include)
+ }
- /// include lines from file, use '-' for stdin
- #[arg(short = 'E', long)]
- pub exclude_file: Vec<FileOrStdin>,
+ pub fn excluded(&self) -> std::io::Result<HashSet<String>> {
+ Self::collect_values(&self.exclude)
+ }
+
+ fn collect_values(values: &[IncludeExclude]) -> std::io::Result<HashSet<String>> {
+ use std::io::BufRead;
+ values.iter().try_fold(HashSet::new(), |mut acc, item| {
+ match item {
+ IncludeExclude::Stdin => {
+ acc.extend(std::io::stdin().lock().lines().map_while(Result::ok))
+ }
+ IncludeExclude::File(filepath) => acc.extend(
+ std::io::BufReader::new(std::fs::File::open(filepath)?)
+ .lines()
+ .map_while(Result::ok),
+ ),
+ IncludeExclude::Item(s) => {
+ acc.insert(s.to_owned());
+ }
+ }
+ Ok(acc)
+ })
+ }
}
static READ_STDIN: AtomicBool = AtomicBool::new(false);
#[derive(Debug, Clone)]
-pub enum FileOrStdin {
+pub enum IncludeExclude {
Stdin,
+ Item(String),
File(PathBuf),
}
-impl FileOrStdin {
- pub fn hosts(self) -> Result<Vec<sshr::Host>, anyhow::Error> {
- use std::io::Read;
- let mut buf = String::new();
- let _ = self.into_reader()?.read_to_string(&mut buf)?;
- buf.lines()
- .map(|s| s.trim_end().parse())
- .collect::<Result<Vec<_>, _>>()
- .map_err(|e| anyhow::format_err!("{e}"))
- }
-
- pub fn into_reader(&self) -> Result<impl std::io::Read, anyhow::Error> {
- let input: Box<dyn std::io::Read + 'static> = match &self {
- Self::Stdin => Box::new(std::io::stdin()),
- Self::File(filepath) => {
- let f = std::fs::File::open(filepath)?;
- Box::new(f)
- }
- };
- Ok(input)
- }
-}
-
-impl FromStr for FileOrStdin {
+impl FromStr for IncludeExclude {
type Err = std::io::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
- match s {
- "-" => {
- if READ_STDIN.load(std::sync::atomic::Ordering::Acquire) {
- return Err(std::io::Error::new(
- std::io::ErrorKind::Other,
- "stdin argument used more than once",
- ));
- }
- READ_STDIN.store(true, std::sync::atomic::Ordering::SeqCst);
- Ok(Self::Stdin)
+ if s == "-" {
+ if READ_STDIN.load(std::sync::atomic::Ordering::Acquire) {
+ return Err(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "stdin argument used more than once",
+ ));
+ }
+ READ_STDIN.store(true, std::sync::atomic::Ordering::SeqCst);
+ Ok(Self::Stdin)
+ } else {
+ match PathBuf::from_str(s) {
+ Ok(path) if path.exists() => Ok(Self::File(path)),
+ _ => Ok(Self::Item(s.to_owned())),
}
- path => PathBuf::from_str(path)
- .map(Self::File)
- .map_err(|_| unreachable!()),
}
}
}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..9727a81
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,30 @@
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum Error {
+ #[error("Netlink error: {0:?}")]
+ Netlink(#[from] rtnetlink::Error),
+
+ #[error("IO error: {0}")]
+ IO(#[from] std::io::Error),
+
+ #[error("Send error: {0}")]
+ Send(String),
+
+ #[error("Lock error: {0}")]
+ Lock(String),
+}
+
+impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
+ fn from(err: tokio::sync::mpsc::error::SendError<T>) -> Self {
+ Self::Send(err.to_string())
+ }
+}
+
+impl<T> From<std::sync::PoisonError<T>> for Error {
+ fn from(err: std::sync::PoisonError<T>) -> Self {
+ Self::Lock(err.to_string())
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 9417b71..bf1ff82 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,66 +1,4 @@
-use std::{
- convert::Infallible,
- fmt::Display,
- net::{IpAddr, Ipv4Addr, Ipv6Addr},
- str::FromStr,
-};
+pub use error::{Error, Result};
+pub mod error;
pub mod netlink;
-pub mod tmux;
-
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
-pub enum Host {
- IpAddr(IpAddr),
- Hostname(String),
-}
-
-impl Host {
- pub fn resolve(&self) -> std::io::Result<Self> {
- match self {
- Self::IpAddr(ip) => dns_lookup::lookup_addr(ip).map(Self::Hostname),
- Self::Hostname(h) => dns_lookup::lookup_host(h)?
- .first()
- .cloned()
- .map(Self::IpAddr)
- .ok_or(std::io::Error::new(
- std::io::ErrorKind::NotFound,
- "resolution not enabled",
- )),
- }
- }
-}
-
-impl FromStr for Host {
- type Err = Infallible;
-
- fn from_str(s: &str) -> Result<Self, Self::Err> {
- Ok(IpAddr::from_str(s).map_or_else(|_| Self::Hostname(s.to_string()), Host::from))
- }
-}
-
-impl From<IpAddr> for Host {
- fn from(value: IpAddr) -> Self {
- Self::IpAddr(value)
- }
-}
-
-impl From<Ipv4Addr> for Host {
- fn from(value: Ipv4Addr) -> Self {
- IpAddr::from(value).into()
- }
-}
-
-impl From<Ipv6Addr> for Host {
- fn from(value: Ipv6Addr) -> Self {
- IpAddr::from(value).into()
- }
-}
-
-impl Display for Host {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Host::IpAddr(i) => write!(f, "{}", i),
- Host::Hostname(s) => write!(f, "{}", s),
- }
- }
-}
diff --git a/src/main.rs b/src/main.rs
index 5fc8435..7dd9390 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,66 +1,75 @@
use std::{
- collections::HashSet,
io::{stdout, Write},
+ net::IpAddr,
+ str::FromStr,
+ sync::{Arc, Mutex},
};
use clap::Parser;
use config::Config;
-use tokio::sync::mpsc;
+use dns_lookup::{lookup_addr, lookup_host};
+use tokio::{sync::mpsc, task::JoinSet};
mod config;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let config = Config::parse();
+ let mut join_set = JoinSet::new();
+ let excluded = config.excluded()?;
+
+ let cache = Arc::new(Mutex::new(excluded));
let (tx, mut rx) = mpsc::channel(100);
- for host in config
- .include_file
- .into_iter()
- .flat_map(|i| i.hosts())
- .flatten()
- .chain(config.include)
- {
- tx.send(host).await?;
- }
+ for host in config.included()? {
+ match IpAddr::from_str(&host) {
+ Ok(ip) => tx.send(ip).await?,
+ Err(_) => {
+ if cache.lock().unwrap().insert(host.clone()) {
+ let mut stdout = stdout().lock();
+ writeln!(stdout, "{host}")?;
+ }
- sshr::tmux::sessions(tx.clone(), &config.socket).await?;
+ if config.resolve {
+ let cache = cache.clone();
+ join_set.spawn(async move {
+ for ip in lookup_host(&host)? {
+ cache.lock()?.insert(ip.to_string());
+ }
+ Ok(())
+ });
+ }
+ }
+ }
+ }
- tokio::spawn(sshr::netlink::neighbours(tx.clone()));
+ join_set.spawn(sshr::netlink::neighbours(tx.clone()));
drop(tx);
- let mut cache = HashSet::new();
-
- for host in config
- .exclude_file
- .into_iter()
- .flat_map(|i| i.hosts())
- .flatten()
- .chain(config.exclude)
- {
- cache.insert(host);
- }
-
- let mut stdout = stdout();
- while let Some(host) = rx.recv().await {
- if !cache.insert(host.clone()) {
- continue;
- }
+ while let Some(ip_addr) = rx.recv().await {
+ join_set.spawn({
+ let cache = cache.clone();
+ async move {
+ let s = if config.resolve {
+ lookup_addr(&ip_addr).unwrap_or_else(|_| ip_addr.to_string())
+ } else {
+ ip_addr.to_string()
+ };
- let resolved = config.resolve.then_some(host.resolve().ok()).flatten();
-
- if let Some(res) = resolved.clone() {
- if !cache.insert(res) {
- continue;
+ if cache.lock()?.insert(s.clone()) {
+ let mut stdout = stdout().lock();
+ writeln!(stdout, "{s}")?;
+ }
+ Ok(())
}
- }
+ });
+ }
- if let Some(sshr::Host::Hostname(r)) = resolved {
- writeln!(stdout, "{r}")?;
- } else {
- writeln!(stdout, "{host}")?;
+ while let Some(res) = join_set.join_next().await {
+ if let Err(err) = res {
+ eprintln!("{err}")
}
}
diff --git a/src/netlink.rs b/src/netlink.rs
index 550f84e..feaac90 100644
--- a/src/netlink.rs
+++ b/src/netlink.rs
@@ -1,4 +1,6 @@
-use futures::{stream::TryStreamExt, FutureExt};
+use std::net::IpAddr;
+
+use futures::stream::TryStreamExt;
use netlink_packet_route::{
neighbour::{NeighbourAddress, NeighbourAttribute, NeighbourMessage},
route::RouteType,
@@ -6,11 +8,11 @@ use netlink_packet_route::{
use rtnetlink::{new_connection, IpVersion};
use tokio::sync::mpsc::Sender;
-use crate::Host;
+use crate::{Error, Result};
pub struct Netlink;
-pub async fn neighbours(tx: Sender<Host>) -> Result<(), rtnetlink::Error> {
+pub async fn neighbours(tx: Sender<IpAddr>) -> Result<()> {
let (connection, handle, _) = new_connection().unwrap();
tokio::spawn(connection);
@@ -19,39 +21,23 @@ pub async fn neighbours(tx: Sender<Host>) -> Result<(), rtnetlink::Error> {
.get()
.set_family(IpVersion::V4)
.execute()
- .try_filter_map(|r| async move { Ok(filter(r)) })
- .try_for_each(|host| tx.send(host).then(|_| async { Ok(()) }))
+ .or_else(|res| async { Err(Error::from(res)) })
+ .try_filter_map(|msg| async { Ok(filter(msg)) })
+ .try_for_each(|host| {
+ let tx = tx.clone();
+ async move { tx.send(host).await.map_err(Error::from) }
+ })
.await
}
-pub fn filter(route: NeighbourMessage) -> Option<Host> {
- if route.header.kind != RouteType::Unicast {
+pub fn filter(msg: NeighbourMessage) -> Option<IpAddr> {
+ if msg.header.kind != RouteType::Unicast {
return None;
}
- route.attributes.into_iter().find_map(|attr| match attr {
- NeighbourAttribute::Destination(NeighbourAddress::Inet(ip)) => Some(Host::from(ip)),
- NeighbourAttribute::Destination(NeighbourAddress::Inet6(ip)) => Some(Host::from(ip)),
+ msg.attributes.into_iter().find_map(|attr| match attr {
+ NeighbourAttribute::Destination(NeighbourAddress::Inet(ip)) => Some(ip.into()),
+ NeighbourAttribute::Destination(NeighbourAddress::Inet6(ip)) => Some(ip.into()),
_ => None,
})
}
-
-#[cfg(test)]
-mod tests {
- use tokio::sync::mpsc;
-
- use super::*;
-
- #[tokio::test]
- async fn test_dump_neighbours() -> Result<(), ()> {
- let (tx, mut rx) = mpsc::channel::<Host>(100);
-
- tokio::spawn(neighbours(tx));
-
- while let Some(res) = rx.recv().await {
- println!("{res}");
- }
-
- Ok(())
- }
-}
diff --git a/src/tmux.rs b/src/tmux.rs
deleted file mode 100644
index e318c07..0000000
--- a/src/tmux.rs
+++ /dev/null
@@ -1,53 +0,0 @@
-use std::{cmp::Reverse, collections::BTreeMap, ffi::OsStr, process::Command};
-
-use tokio::sync::mpsc::Sender;
-
-use crate::Host;
-
-pub async fn sessions<S>(tx: Sender<Host>, socket: &S) -> Result<(), anyhow::Error>
-where
- S: AsRef<OsStr>,
-{
- let stdout = Command::new("tmux")
- .arg("-L")
- .arg(socket)
- .arg("list-sessions")
- .arg("-F")
- .arg("#{?session_last_attached,#{session_last_attached},#{session_created}}:#{s/_/./:session_name}")
- .output()?
- .stdout;
-
- let mut btree_map: BTreeMap<Reverse<usize>, String> = std::str::from_utf8(&stdout)?
- .lines()
- .flat_map(|s| {
- let (t, s) = s.split_once(':')?;
- Some((Reverse(t.parse().ok()?), s.to_string()))
- })
- .collect();
-
- let stdout = Command::new("tmux")
- .arg("-L")
- .arg("default")
- .arg("list-sessions")
- .arg("-F")
- .arg("#{?session_last_attached,#{session_last_attached},#{session_created}}:#{host}")
- .output()?
- .stdout;
-
- if let Some((t, s)) = std::str::from_utf8(&stdout)?
- .lines()
- .flat_map(|s| {
- let (t, s) = s.split_once(':')?;
- Some((t.parse().ok()?, s.to_string()))
- })
- .max_by_key(|t| t.0)
- {
- btree_map.insert(Reverse(t), s);
- }
-
- for name in btree_map.into_values() {
- tx.send(name.parse()?).await?;
- }
-
- Ok(())
-}