diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 76 | ||||
-rw-r--r-- | src/history.rs | 69 | ||||
-rw-r--r-- | src/lib.rs | 18 | ||||
-rw-r--r-- | src/localhost.rs | 25 | ||||
-rw-r--r-- | src/main.rs | 73 | ||||
-rw-r--r-- | src/session.rs | 102 | ||||
-rw-r--r-- | src/ssh.rs | 2 | ||||
-rw-r--r-- | src/stdio.rs | 46 | ||||
-rw-r--r-- | src/tmux.rs | 80 | ||||
-rw-r--r-- | src/unix.rs | 26 |
10 files changed, 301 insertions, 216 deletions
diff --git a/src/config.rs b/src/config.rs index 63dfbc7..258255b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,30 +1,78 @@ -use std::path::PathBuf; - use clap::{Args, Parser}; use tracing::{metadata::LevelFilter, Level}; +use crate::{history, session, stdio, tmux}; + #[derive(Debug, Clone, Parser)] pub struct Config { - /// Update the history file from the current sessions - #[arg(short, long)] - pub update: bool, + #[command(flatten)] + pub enabled: Flags, - /// tmux socket-name, equivelent to `tmux -L <socket-name>` - #[arg(short = 'L', long, default_value = "ssh")] - pub socket_name: String, + #[command(flatten)] + pub stdio: stdio::Config, - /// Name of host to exclude - #[arg(short, long)] - pub exclude: Vec<String>, + #[command(flatten)] + pub sessions: session::Config, + + #[command(flatten)] + pub history: history::History, - /// path to history file [default: $XDG_DATA_HOME/sshr/history] - #[arg(short = 'f', long)] - pub history_file: Option<PathBuf>, + #[command(flatten)] + pub tmux: tmux::Tmux, #[command(flatten)] pub verbosity: Verbosity, } +#[derive(Debug, Clone, Args)] +pub struct Flags { + /// Include localhost + #[arg(short, long)] + pub localhost: bool, + + /// Include hosts from tmux session names + #[arg(short, long)] + pub tmux: bool, + + /// Include hosts from history file + #[arg(short = 'H', long)] + pub history: bool, + + /// Include hosts from the ssh `known_hosts` + #[arg(short, long)] + pub ssh: bool, + + /// Include hosts from `/etc/hosts` + #[arg(short = 'o', long)] + pub hosts: bool, + + /// Alias to include all host sources + #[arg(short, long)] + pub all: bool, +} + +impl Flags { + pub fn localhost(&self) -> bool { + self.all || self.localhost + } + + pub fn tmux(&self) -> bool { + self.all || self.tmux + } + + pub fn history(&self) -> bool { + self.all || self.history + } + + pub fn ssh(&self) -> bool { + self.all || self.ssh + } + + pub fn hosts(&self) -> bool { + self.all || self.hosts + } +} + #[derive(Debug, Default, Clone, Args)] pub struct Verbosity { /// Print additional information per occurrence. diff --git a/src/history.rs b/src/history.rs index 863be8e..0394e3d 100644 --- a/src/history.rs +++ b/src/history.rs @@ -1,60 +1,77 @@ use std::{ fs::File, - io::{BufRead, BufReader, Write}, + io::{BufRead, BufReader, ErrorKind}, path::PathBuf, }; +use clap::Args; use directories::ProjectDirs; -use crate::Session; +use crate::{session::SessionWriter, Session}; -pub use error::Error; - -mod error; - -#[derive(Debug)] +#[derive(Debug, Clone, Args)] +#[group(skip)] pub struct History { - pub file: File, - pub entries: Vec<Session>, + /// Update the history file from the current sessions + #[arg(short, long)] + pub update: bool, + + /// path to history file [default: $XDG_DATA_HOME/sshr/history] + #[arg(short = 'f', long = "history_file")] + path: Option<PathBuf>, } impl History { - pub fn open(history_file: PathBuf) -> Result<Self, std::io::Error> { - let file = File::options().write(true).open(&history_file)?; + pub fn new(History { update, path }: History) -> Self { + Self { + path: path.or_else(History::default_path), + update, + } + } + + pub fn read(&self) -> Result<Vec<Session>, std::io::Error> { + let Some(path) = &self.path() else { + tracing::warn!(?self.path, "History file does not exist"); + return Ok(Vec::new()); + }; - let entries = BufReader::new(File::open(history_file)?) + let sessions = BufReader::new(File::open(path)?) .lines() .flatten() .flat_map(|item| ron::from_str(&item)) .collect(); - Ok(Self { file, entries }) + Ok(sessions) } - pub fn default_path() -> Option<PathBuf> { + fn default_path() -> Option<PathBuf> { ProjectDirs::from("", "", env!("CARGO_CRATE_NAME"))? .state_dir()? .join("history") .into() } -} -impl IntoIterator for History { - type Item = Session; + pub fn path(&self) -> Option<PathBuf> { + self.path.clone().or_else(History::default_path) + } +} - type IntoIter = std::vec::IntoIter<Self::Item>; +impl SessionWriter for History { + type Writer = File; + type Error = ron::Error; - fn into_iter(self) -> Self::IntoIter { - self.entries.into_iter() + fn format(&self, session: &Session) -> Result<String, Self::Error> { + ron::to_string(session) } -} -impl Write for History { - fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { - self.file.write(buf) + fn filter(&self, session: &Session) -> bool { + self.update && !matches!(session.state, crate::State::Discovered) } - fn flush(&mut self) -> std::io::Result<()> { - self.file.flush() + fn writer(&self) -> Result<Self::Writer, std::io::Error> { + match &self.path { + Some(path) => File::create(path), + None => Err(std::io::Error::from(ErrorKind::NotFound)), + } } } @@ -1,13 +1,19 @@ -pub use config::Config; -pub use history::History; -pub use session::{Session, Sessions, State}; -pub use ssh::KnownHosts; -pub use tmux::Tmux; -pub use unix::Hosts; +pub use crate::{ + config::Config, + history::History, + localhost::HostName, + session::{Session, Sessions, State}, + ssh::KnownHosts, + stdio::Stdout, + tmux::Tmux, + unix::Hosts, +}; mod config; mod history; +mod localhost; mod session; mod ssh; +mod stdio; mod tmux; mod unix; diff --git a/src/localhost.rs b/src/localhost.rs new file mode 100644 index 0000000..45b14de --- /dev/null +++ b/src/localhost.rs @@ -0,0 +1,25 @@ +use std::ffi::OsString; + +use crate::{Session, State}; + +pub struct HostName(OsString); + +impl HostName { + pub fn get() -> Result<Self, std::io::Error> { + hostname::get().map(Self) + } +} + +impl IntoIterator for HostName { + type Item = Session; + + type IntoIter = std::option::IntoIter<Session>; + + fn into_iter(self) -> Self::IntoIter { + Some(Session { + name: self.0.to_string_lossy().into(), + state: State::LocalHost, + }) + .into_iter() + } +} diff --git a/src/main.rs b/src/main.rs index eb62d52..523ee82 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,62 +1,53 @@ -use std::io::ErrorKind; - +use anyhow::Context; use clap::Parser; -use sshr::{Config, History, Hosts, KnownHosts, Sessions, Tmux}; +use sshr::{Config, HostName, Hosts, KnownHosts, Sessions, Stdout}; fn main() -> anyhow::Result<()> { - let mut config = Config::parse(); + let config = Config::parse(); tracing_subscriber::fmt::fmt() .with_max_level(&config.verbosity) .without_time() .init(); - config.history_file = config.history_file.or_else(History::default_path); - - let history = match config.history_file { - Some(path) => History::open(path), - None => Err(std::io::Error::from(ErrorKind::NotFound)), - }; - - let tmux = Tmux::new(config.socket_name); + let mut sessions = Sessions::new(config.sessions); - let mut sessions = Sessions::new(config.exclude); + if config.enabled.localhost() { + let hostname = HostName::get().context("Failed to get hostname of localhost")?; + sessions.extend(hostname); + } - match tmux.host() { - Ok(p) => sessions.add(p), - Err(err) => tracing::warn!(?err, "Failed to get tmux host"), - }; + if config.enabled.tmux() { + let tmux_sessions = config.tmux.list().context("Failed to list tmux sessions")?; + sessions.extend(tmux_sessions); + } - match tmux.list(None) { - Ok(p) => sessions.extend(p), - Err(err) => tracing::warn!(?err, "Failed to list tmux sessions"), - }; + if config.enabled.ssh() { + let known_hosts = KnownHosts::open().context("Failed to read KnownHost file")?; + sessions.extend(known_hosts); + } - match history { - Ok(p) => { - sessions.extend(p.entries); + if config.enabled.hosts() { + let hosts = Hosts::open().context("Failed to read /etc/hosts")?; + sessions.extend(hosts); + } - if config.update { - sessions.write_into(p.file)?; + if config.enabled.history() { + match config.history.read() { + Ok(h) => sessions.extend(h), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + tracing::warn!("Skipping non-existant history file") } + Err(err) => return Err(err).context("Failed to read history file"), } - Err(err) => tracing::warn!(?err, "Failed to open history file"), - }; - - match KnownHosts::open() { - Ok(p) => sessions.extend(p), - Err(err) => tracing::warn!(?err, "Failed to read KnownHost file"), - }; - - match Hosts::open() { - Ok(p) => sessions.extend(p), - Err(err) => tracing::warn!(?err, "Failed to read /etc/hosts"), - }; + } - for session in sessions.sorted() { - println!("{session}"); + if config.history.update { + sessions.write_sessions(config.history)?; } - Ok(()) + sessions + .write_sessions(Stdout::new(config.stdio)) + .context("Failed to write to stdout") } diff --git a/src/session.rs b/src/session.rs index a7ed261..aad299a 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,40 +1,59 @@ use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, + collections::{hash_map::Entry, HashMap}, fmt::Display, io::{BufWriter, Write}, iter::IntoIterator, time::Duration, }; +use clap::Args; use serde::{Deserialize, Serialize}; +pub trait SessionWriter { + type Error: Display; + type Writer: Write; + + fn writer(&self) -> Result<Self::Writer, std::io::Error>; + fn format(&self, session: &Session) -> Result<String, Self::Error>; + fn filter(&self, session: &Session) -> bool; +} + +#[derive(Debug, Clone, Args)] +#[group(skip)] +pub struct Config { + /// Enable sorting + #[arg(long, default_value_t = true)] + pub sort: bool, +} + #[derive(Debug, Default)] pub struct Sessions { inner: HashMap<String, State>, - exclude: Vec<String>, + sort: bool, } impl Sessions { - pub fn new(exclude: Vec<String>) -> Self { + pub fn new(Config { sort }: Config) -> Self { Self { - exclude, + sort, ..Default::default() } } - pub fn sorted(self) -> Vec<Session> { - let mut sessions: Vec<Session> = self.into_iter().map(Session::from).collect(); - sessions.sort(); - sessions - } + pub fn write_sessions<W: SessionWriter>(&self, writer: W) -> std::io::Result<()> { + let mut buf_writer = BufWriter::new(writer.writer()?); + + let mut sessions: Vec<Session> = self.inner.iter().map(Session::from).collect(); - pub fn write_into<W: Write>(&self, writer: W) -> std::io::Result<()> { - let mut buf_writer = BufWriter::new(writer); + if self.sort { + sessions.sort(); + } - for session in self.inner.iter().map(Session::from) { - match ron::to_string(&session) { - Ok(ser) => writeln!(buf_writer, "{ser}")?, - Err(err) => tracing::warn!(%err, "Failed to serialize session"), + for session in sessions { + match writer.format(&session) { + Ok(fmt) if writer.filter(&session) => writeln!(buf_writer, "{fmt}")?, + Err(err) => tracing::warn!(%err, "Failed to format session"), + _ => tracing::debug!(%session, "Skipping filtered session"), } } @@ -45,11 +64,6 @@ impl Sessions { let span = tracing::trace_span!("Entry", ?item); let _guard = span.enter(); - if self.exclude.contains(&item.name) { - tracing::debug!(item.name, "Skipping excluded item"); - return; - } - match self.inner.entry(item.name) { Entry::Occupied(mut occupied) if &item.state > occupied.get() => { tracing::trace!(?occupied, new_value=?item.state, "New entry is more recent, replacing"); @@ -66,20 +80,6 @@ impl Sessions { } } -impl IntoIterator for Sessions { - type Item = Session; - - type IntoIter = std::collections::hash_set::IntoIter<Self::Item>; - - fn into_iter(self) -> Self::IntoIter { - self.inner - .into_iter() - .map(Into::into) - .collect::<HashSet<Session>>() - .into_iter() - } -} - impl Extend<Session> for Sessions { fn extend<T: IntoIterator<Item = Session>>(&mut self, iter: T) { for item in iter { @@ -98,12 +98,6 @@ pub enum State { LocalHost, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] -pub struct Session { - pub state: State, - pub name: String, -} - mod epoch_timestamp { use std::time::Duration; @@ -124,30 +118,34 @@ mod epoch_timestamp { } } -impl Display for Session { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.name) - } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct Session { + pub state: State, + pub name: String, } -impl From<&str> for Session { - fn from(value: &str) -> Self { +impl Session { + pub fn discover(name: impl Into<String>) -> Self { Self { + name: name.into(), state: State::Discovered, - name: value.to_owned(), } } -} -impl From<String> for Session { - fn from(name: String) -> Self { + pub fn localhost(name: impl Into<String>) -> Self { Self { - state: State::Discovered, - name, + name: name.into(), + state: State::LocalHost, } } } +impl Display for Session { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + impl From<(String, State)> for Session { fn from((name, state): (String, State)) -> Self { Self { state, name } @@ -23,7 +23,7 @@ impl KnownHosts { .flatten() .take_while(|s| !s.starts_with('#')) .filter_map(|l| l.split_whitespace().next().map(str::to_owned)) - .map(Session::from) + .map(Session::discover) .collect(); Ok(Self(inner)) diff --git a/src/stdio.rs b/src/stdio.rs new file mode 100644 index 0000000..84319a1 --- /dev/null +++ b/src/stdio.rs @@ -0,0 +1,46 @@ +use std::convert::Infallible; + +use clap::Args; + +use crate::{session::SessionWriter, Session}; + +#[derive(Debug, Clone, Args)] +#[group(skip)] +pub struct Config { + /// Exclude item from output + #[arg(short, long)] + pub exclude: Vec<String>, +} + +impl Config { + pub fn new(Config { exclude }: Config) -> Self { + Self { exclude } + } +} + +pub struct Stdout { + exclude: Vec<String>, +} + +impl Stdout { + pub fn new(Config { exclude }: Config) -> Self { + Self { exclude } + } +} + +impl SessionWriter for Stdout { + type Writer = std::io::Stdout; + type Error = Infallible; + + fn format(&self, session: &Session) -> Result<String, Self::Error> { + Ok(session.name.to_string()) + } + + fn filter(&self, session: &Session) -> bool { + !self.exclude.contains(&session.name) + } + + fn writer(&self) -> Result<Self::Writer, std::io::Error> { + Ok(std::io::stdout()) + } +} diff --git a/src/tmux.rs b/src/tmux.rs index 88f521f..e1fa3a6 100644 --- a/src/tmux.rs +++ b/src/tmux.rs @@ -2,32 +2,27 @@ use std::process::Command; use crate::Session; +use clap::Args; pub use error::Error; mod error; -#[derive(Debug)] +#[derive(Debug, Clone, Args)] +#[group(skip)] pub struct Tmux { - socket_name: String, + /// tmux socket-name, equivelent to `tmux -L <socket-name>` + #[arg(short = 'L', long = "tmux_socket", default_value = "ssh")] + pub socket: String, } impl Tmux { const SESSION_FORMAT: &str = r##"Session(name: "#S", state: #{?session_last_attached,Attached(#{session_last_attached}),Created(#{session_created})})"##; - pub fn new(socket_name: String) -> Self { - Self { socket_name } - } - - pub fn list(&self, name: Option<String>) -> Result<Vec<Session>, Error> { - let filter = name - .map(|s| vec!["-f".into(), format!("#{{==:#S,{s}}}")]) - .unwrap_or_default(); - + pub fn list(&self) -> Result<Vec<Session>, Error> { let stdout = Command::new("tmux") .arg("-L") - .arg(&self.socket_name) + .arg(&self.socket) .arg("list-sessions") - .args(filter) .arg("-F") .arg(Self::SESSION_FORMAT) .output()? @@ -46,31 +41,13 @@ impl Tmux { Ok(sessions) } - - pub fn host(&self) -> Result<Session, Error> { - let stdout = Command::new("tmux") - .arg("-L") - .arg(&self.socket_name) - .arg("display") - .arg("-p") - .arg("#h") - .output()? - .stdout; - - let name = std::str::from_utf8(&stdout)?.trim().into(); - - Ok(Session { - state: crate::State::LocalHost, - name, - }) - } } #[cfg(test)] mod tests { use super::*; - const SOCKET_NAME: &str = "test"; + const SOCKET: &str = "test"; #[test] fn test_tmux_list() -> Result<(), Error> { @@ -79,52 +56,25 @@ mod tests { for name in names.iter().cloned() { Command::new("tmux") .arg("-L") - .arg(SOCKET_NAME) + .arg(SOCKET) .arg("new-session") .arg("-ds") .arg(name) .status()?; } - let tmux = Tmux::new(SOCKET_NAME.to_owned()); - let sessions: Vec<_> = tmux.list(None)?.into_iter().map(|s| s.name).collect(); - - Command::new("tmux") - .arg("-L") - .arg(SOCKET_NAME) - .arg("kill-server") - .status()?; - - assert_eq!(names, sessions); - - Ok(()) - } - - #[test] - fn test_tmux_host() -> Result<(), Error> { - Command::new("tmux") - .arg("-L") - .arg(SOCKET_NAME) - .arg("new-session") - .arg("-d") - .status()?; - - let name = hostname::get()?.to_string_lossy().into(); - let expected_session = Session { - state: crate::State::LocalHost, - name, + let tmux = Tmux { + socket: SOCKET.to_owned(), }; - - let tmux = Tmux::new(SOCKET_NAME.to_owned()); - let session = tmux.host()?; + let sessions: Vec<_> = tmux.list()?.into_iter().map(|s| s.name).collect(); Command::new("tmux") .arg("-L") - .arg(SOCKET_NAME) + .arg(SOCKET) .arg("kill-server") .status()?; - assert_eq!(expected_session, session); + assert_eq!(names, sessions); Ok(()) } diff --git a/src/unix.rs b/src/unix.rs index 402e476..d549b63 100644 --- a/src/unix.rs +++ b/src/unix.rs @@ -9,20 +9,24 @@ pub struct Hosts(Vec<Session>); impl Hosts { pub fn open() -> Result<Self, Error> { - let buf_reader = BufReader::new(File::open("/etc/hosts")?); - let inner: Vec<Session> = buf_reader + File::open("/etc/hosts").map(Self::parse_file).map(Self) + } + + fn parse_file(file: File) -> Vec<Session> { + BufReader::new(file) .lines() .flatten() + .filter_map(Self::parse_line) + .collect() + } + + fn parse_line(line: String) -> Option<Session> { + line.split_whitespace() .take_while(|s| !s.starts_with('#')) - .flat_map(|l| { - l.split_whitespace() - .skip(1) - .take_while(|s| !s.starts_with('#')) - .map(Session::from) - .collect::<Vec<_>>() - }) - .collect(); - Ok(Self(inner)) + .last() + // Skip BOM + .filter(|&s| s != "\u{feff}") + .map(Session::discover) } } |