summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorToby Vincent <tobyv13@gmail.com>2023-04-03 19:12:14 -0500
committerToby Vincent <tobyv13@gmail.com>2023-04-03 19:12:14 -0500
commitc94960b49f463b3c4a7da9fb1b6b2c122f7dd125 (patch)
tree6be101dd21b61775313e04b6c695629b183f48b4
parent8a6631053a48a64f0c3b21ec0d92b6b687de9638 (diff)
refactor: clean up logic and impl trait for writing sessions
-rw-r--r--Cargo.toml2
-rw-r--r--src/config.rs76
-rw-r--r--src/history.rs69
-rw-r--r--src/lib.rs18
-rw-r--r--src/localhost.rs25
-rw-r--r--src/main.rs73
-rw-r--r--src/session.rs102
-rw-r--r--src/ssh.rs2
-rw-r--r--src/stdio.rs46
-rw-r--r--src/tmux.rs80
-rw-r--r--src/unix.rs26
11 files changed, 301 insertions, 218 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 5d617ac..9da5cb8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,4 @@ tracing = "0.1.37"
thiserror = "1.0.40"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
hostfile = "0.2.0"
-
-[dev-dependencies]
hostname = "0.3.1"
diff --git a/src/config.rs b/src/config.rs
index 63dfbc7..258255b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,30 +1,78 @@
-use std::path::PathBuf;
-
use clap::{Args, Parser};
use tracing::{metadata::LevelFilter, Level};
+use crate::{history, session, stdio, tmux};
+
#[derive(Debug, Clone, Parser)]
pub struct Config {
- /// Update the history file from the current sessions
- #[arg(short, long)]
- pub update: bool,
+ #[command(flatten)]
+ pub enabled: Flags,
- /// tmux socket-name, equivelent to `tmux -L <socket-name>`
- #[arg(short = 'L', long, default_value = "ssh")]
- pub socket_name: String,
+ #[command(flatten)]
+ pub stdio: stdio::Config,
- /// Name of host to exclude
- #[arg(short, long)]
- pub exclude: Vec<String>,
+ #[command(flatten)]
+ pub sessions: session::Config,
+
+ #[command(flatten)]
+ pub history: history::History,
- /// path to history file [default: $XDG_DATA_HOME/sshr/history]
- #[arg(short = 'f', long)]
- pub history_file: Option<PathBuf>,
+ #[command(flatten)]
+ pub tmux: tmux::Tmux,
#[command(flatten)]
pub verbosity: Verbosity,
}
+#[derive(Debug, Clone, Args)]
+pub struct Flags {
+ /// Include localhost
+ #[arg(short, long)]
+ pub localhost: bool,
+
+ /// Include hosts from tmux session names
+ #[arg(short, long)]
+ pub tmux: bool,
+
+ /// Include hosts from history file
+ #[arg(short = 'H', long)]
+ pub history: bool,
+
+ /// Include hosts from the ssh `known_hosts`
+ #[arg(short, long)]
+ pub ssh: bool,
+
+ /// Include hosts from `/etc/hosts`
+ #[arg(short = 'o', long)]
+ pub hosts: bool,
+
+ /// Alias to include all host sources
+ #[arg(short, long)]
+ pub all: bool,
+}
+
+impl Flags {
+ pub fn localhost(&self) -> bool {
+ self.all || self.localhost
+ }
+
+ pub fn tmux(&self) -> bool {
+ self.all || self.tmux
+ }
+
+ pub fn history(&self) -> bool {
+ self.all || self.history
+ }
+
+ pub fn ssh(&self) -> bool {
+ self.all || self.ssh
+ }
+
+ pub fn hosts(&self) -> bool {
+ self.all || self.hosts
+ }
+}
+
#[derive(Debug, Default, Clone, Args)]
pub struct Verbosity {
/// Print additional information per occurrence.
diff --git a/src/history.rs b/src/history.rs
index 863be8e..0394e3d 100644
--- a/src/history.rs
+++ b/src/history.rs
@@ -1,60 +1,77 @@
use std::{
fs::File,
- io::{BufRead, BufReader, Write},
+ io::{BufRead, BufReader, ErrorKind},
path::PathBuf,
};
+use clap::Args;
use directories::ProjectDirs;
-use crate::Session;
+use crate::{session::SessionWriter, Session};
-pub use error::Error;
-
-mod error;
-
-#[derive(Debug)]
+#[derive(Debug, Clone, Args)]
+#[group(skip)]
pub struct History {
- pub file: File,
- pub entries: Vec<Session>,
+ /// Update the history file from the current sessions
+ #[arg(short, long)]
+ pub update: bool,
+
+ /// path to history file [default: $XDG_DATA_HOME/sshr/history]
+ #[arg(short = 'f', long = "history_file")]
+ path: Option<PathBuf>,
}
impl History {
- pub fn open(history_file: PathBuf) -> Result<Self, std::io::Error> {
- let file = File::options().write(true).open(&history_file)?;
+ pub fn new(History { update, path }: History) -> Self {
+ Self {
+ path: path.or_else(History::default_path),
+ update,
+ }
+ }
+
+ pub fn read(&self) -> Result<Vec<Session>, std::io::Error> {
+ let Some(path) = &self.path() else {
+ tracing::warn!(?self.path, "History file does not exist");
+ return Ok(Vec::new());
+ };
- let entries = BufReader::new(File::open(history_file)?)
+ let sessions = BufReader::new(File::open(path)?)
.lines()
.flatten()
.flat_map(|item| ron::from_str(&item))
.collect();
- Ok(Self { file, entries })
+ Ok(sessions)
}
- pub fn default_path() -> Option<PathBuf> {
+ fn default_path() -> Option<PathBuf> {
ProjectDirs::from("", "", env!("CARGO_CRATE_NAME"))?
.state_dir()?
.join("history")
.into()
}
-}
-impl IntoIterator for History {
- type Item = Session;
+ pub fn path(&self) -> Option<PathBuf> {
+ self.path.clone().or_else(History::default_path)
+ }
+}
- type IntoIter = std::vec::IntoIter<Self::Item>;
+impl SessionWriter for History {
+ type Writer = File;
+ type Error = ron::Error;
- fn into_iter(self) -> Self::IntoIter {
- self.entries.into_iter()
+ fn format(&self, session: &Session) -> Result<String, Self::Error> {
+ ron::to_string(session)
}
-}
-impl Write for History {
- fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
- self.file.write(buf)
+ fn filter(&self, session: &Session) -> bool {
+ self.update && !matches!(session.state, crate::State::Discovered)
}
- fn flush(&mut self) -> std::io::Result<()> {
- self.file.flush()
+ fn writer(&self) -> Result<Self::Writer, std::io::Error> {
+ match &self.path {
+ Some(path) => File::create(path),
+ None => Err(std::io::Error::from(ErrorKind::NotFound)),
+ }
}
}
diff --git a/src/lib.rs b/src/lib.rs
index c749b20..ebf9c8d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,13 +1,19 @@
-pub use config::Config;
-pub use history::History;
-pub use session::{Session, Sessions, State};
-pub use ssh::KnownHosts;
-pub use tmux::Tmux;
-pub use unix::Hosts;
+pub use crate::{
+ config::Config,
+ history::History,
+ localhost::HostName,
+ session::{Session, Sessions, State},
+ ssh::KnownHosts,
+ stdio::Stdout,
+ tmux::Tmux,
+ unix::Hosts,
+};
mod config;
mod history;
+mod localhost;
mod session;
mod ssh;
+mod stdio;
mod tmux;
mod unix;
diff --git a/src/localhost.rs b/src/localhost.rs
new file mode 100644
index 0000000..45b14de
--- /dev/null
+++ b/src/localhost.rs
@@ -0,0 +1,25 @@
+use std::ffi::OsString;
+
+use crate::{Session, State};
+
+pub struct HostName(OsString);
+
+impl HostName {
+ pub fn get() -> Result<Self, std::io::Error> {
+ hostname::get().map(Self)
+ }
+}
+
+impl IntoIterator for HostName {
+ type Item = Session;
+
+ type IntoIter = std::option::IntoIter<Session>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ Some(Session {
+ name: self.0.to_string_lossy().into(),
+ state: State::LocalHost,
+ })
+ .into_iter()
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index eb62d52..523ee82 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,62 +1,53 @@
-use std::io::ErrorKind;
-
+use anyhow::Context;
use clap::Parser;
-use sshr::{Config, History, Hosts, KnownHosts, Sessions, Tmux};
+use sshr::{Config, HostName, Hosts, KnownHosts, Sessions, Stdout};
fn main() -> anyhow::Result<()> {
- let mut config = Config::parse();
+ let config = Config::parse();
tracing_subscriber::fmt::fmt()
.with_max_level(&config.verbosity)
.without_time()
.init();
- config.history_file = config.history_file.or_else(History::default_path);
-
- let history = match config.history_file {
- Some(path) => History::open(path),
- None => Err(std::io::Error::from(ErrorKind::NotFound)),
- };
-
- let tmux = Tmux::new(config.socket_name);
+ let mut sessions = Sessions::new(config.sessions);
- let mut sessions = Sessions::new(config.exclude);
+ if config.enabled.localhost() {
+ let hostname = HostName::get().context("Failed to get hostname of localhost")?;
+ sessions.extend(hostname);
+ }
- match tmux.host() {
- Ok(p) => sessions.add(p),
- Err(err) => tracing::warn!(?err, "Failed to get tmux host"),
- };
+ if config.enabled.tmux() {
+ let tmux_sessions = config.tmux.list().context("Failed to list tmux sessions")?;
+ sessions.extend(tmux_sessions);
+ }
- match tmux.list(None) {
- Ok(p) => sessions.extend(p),
- Err(err) => tracing::warn!(?err, "Failed to list tmux sessions"),
- };
+ if config.enabled.ssh() {
+ let known_hosts = KnownHosts::open().context("Failed to read KnownHost file")?;
+ sessions.extend(known_hosts);
+ }
- match history {
- Ok(p) => {
- sessions.extend(p.entries);
+ if config.enabled.hosts() {
+ let hosts = Hosts::open().context("Failed to read /etc/hosts")?;
+ sessions.extend(hosts);
+ }
- if config.update {
- sessions.write_into(p.file)?;
+ if config.enabled.history() {
+ match config.history.read() {
+ Ok(h) => sessions.extend(h),
+ Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
+ tracing::warn!("Skipping non-existant history file")
}
+ Err(err) => return Err(err).context("Failed to read history file"),
}
- Err(err) => tracing::warn!(?err, "Failed to open history file"),
- };
-
- match KnownHosts::open() {
- Ok(p) => sessions.extend(p),
- Err(err) => tracing::warn!(?err, "Failed to read KnownHost file"),
- };
-
- match Hosts::open() {
- Ok(p) => sessions.extend(p),
- Err(err) => tracing::warn!(?err, "Failed to read /etc/hosts"),
- };
+ }
- for session in sessions.sorted() {
- println!("{session}");
+ if config.history.update {
+ sessions.write_sessions(config.history)?;
}
- Ok(())
+ sessions
+ .write_sessions(Stdout::new(config.stdio))
+ .context("Failed to write to stdout")
}
diff --git a/src/session.rs b/src/session.rs
index a7ed261..aad299a 100644
--- a/src/session.rs
+++ b/src/session.rs
@@ -1,40 +1,59 @@
use std::{
- collections::{hash_map::Entry, HashMap, HashSet},
+ collections::{hash_map::Entry, HashMap},
fmt::Display,
io::{BufWriter, Write},
iter::IntoIterator,
time::Duration,
};
+use clap::Args;
use serde::{Deserialize, Serialize};
+pub trait SessionWriter {
+ type Error: Display;
+ type Writer: Write;
+
+ fn writer(&self) -> Result<Self::Writer, std::io::Error>;
+ fn format(&self, session: &Session) -> Result<String, Self::Error>;
+ fn filter(&self, session: &Session) -> bool;
+}
+
+#[derive(Debug, Clone, Args)]
+#[group(skip)]
+pub struct Config {
+ /// Enable sorting
+ #[arg(long, default_value_t = true)]
+ pub sort: bool,
+}
+
#[derive(Debug, Default)]
pub struct Sessions {
inner: HashMap<String, State>,
- exclude: Vec<String>,
+ sort: bool,
}
impl Sessions {
- pub fn new(exclude: Vec<String>) -> Self {
+ pub fn new(Config { sort }: Config) -> Self {
Self {
- exclude,
+ sort,
..Default::default()
}
}
- pub fn sorted(self) -> Vec<Session> {
- let mut sessions: Vec<Session> = self.into_iter().map(Session::from).collect();
- sessions.sort();
- sessions
- }
+ pub fn write_sessions<W: SessionWriter>(&self, writer: W) -> std::io::Result<()> {
+ let mut buf_writer = BufWriter::new(writer.writer()?);
+
+ let mut sessions: Vec<Session> = self.inner.iter().map(Session::from).collect();
- pub fn write_into<W: Write>(&self, writer: W) -> std::io::Result<()> {
- let mut buf_writer = BufWriter::new(writer);
+ if self.sort {
+ sessions.sort();
+ }
- for session in self.inner.iter().map(Session::from) {
- match ron::to_string(&session) {
- Ok(ser) => writeln!(buf_writer, "{ser}")?,
- Err(err) => tracing::warn!(%err, "Failed to serialize session"),
+ for session in sessions {
+ match writer.format(&session) {
+ Ok(fmt) if writer.filter(&session) => writeln!(buf_writer, "{fmt}")?,
+ Err(err) => tracing::warn!(%err, "Failed to format session"),
+ _ => tracing::debug!(%session, "Skipping filtered session"),
}
}
@@ -45,11 +64,6 @@ impl Sessions {
let span = tracing::trace_span!("Entry", ?item);
let _guard = span.enter();
- if self.exclude.contains(&item.name) {
- tracing::debug!(item.name, "Skipping excluded item");
- return;
- }
-
match self.inner.entry(item.name) {
Entry::Occupied(mut occupied) if &item.state > occupied.get() => {
tracing::trace!(?occupied, new_value=?item.state, "New entry is more recent, replacing");
@@ -66,20 +80,6 @@ impl Sessions {
}
}
-impl IntoIterator for Sessions {
- type Item = Session;
-
- type IntoIter = std::collections::hash_set::IntoIter<Self::Item>;
-
- fn into_iter(self) -> Self::IntoIter {
- self.inner
- .into_iter()
- .map(Into::into)
- .collect::<HashSet<Session>>()
- .into_iter()
- }
-}
-
impl Extend<Session> for Sessions {
fn extend<T: IntoIterator<Item = Session>>(&mut self, iter: T) {
for item in iter {
@@ -98,12 +98,6 @@ pub enum State {
LocalHost,
}
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
-pub struct Session {
- pub state: State,
- pub name: String,
-}
-
mod epoch_timestamp {
use std::time::Duration;
@@ -124,30 +118,34 @@ mod epoch_timestamp {
}
}
-impl Display for Session {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.name)
- }
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+pub struct Session {
+ pub state: State,
+ pub name: String,
}
-impl From<&str> for Session {
- fn from(value: &str) -> Self {
+impl Session {
+ pub fn discover(name: impl Into<String>) -> Self {
Self {
+ name: name.into(),
state: State::Discovered,
- name: value.to_owned(),
}
}
-}
-impl From<String> for Session {
- fn from(name: String) -> Self {
+ pub fn localhost(name: impl Into<String>) -> Self {
Self {
- state: State::Discovered,
- name,
+ name: name.into(),
+ state: State::LocalHost,
}
}
}
+impl Display for Session {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.name)
+ }
+}
+
impl From<(String, State)> for Session {
fn from((name, state): (String, State)) -> Self {
Self { state, name }
diff --git a/src/ssh.rs b/src/ssh.rs
index 3508a36..606509a 100644
--- a/src/ssh.rs
+++ b/src/ssh.rs
@@ -23,7 +23,7 @@ impl KnownHosts {
.flatten()
.take_while(|s| !s.starts_with('#'))
.filter_map(|l| l.split_whitespace().next().map(str::to_owned))
- .map(Session::from)
+ .map(Session::discover)
.collect();
Ok(Self(inner))
diff --git a/src/stdio.rs b/src/stdio.rs
new file mode 100644
index 0000000..84319a1
--- /dev/null
+++ b/src/stdio.rs
@@ -0,0 +1,46 @@
+use std::convert::Infallible;
+
+use clap::Args;
+
+use crate::{session::SessionWriter, Session};
+
+#[derive(Debug, Clone, Args)]
+#[group(skip)]
+pub struct Config {
+ /// Exclude item from output
+ #[arg(short, long)]
+ pub exclude: Vec<String>,
+}
+
+impl Config {
+ pub fn new(Config { exclude }: Config) -> Self {
+ Self { exclude }
+ }
+}
+
+pub struct Stdout {
+ exclude: Vec<String>,
+}
+
+impl Stdout {
+ pub fn new(Config { exclude }: Config) -> Self {
+ Self { exclude }
+ }
+}
+
+impl SessionWriter for Stdout {
+ type Writer = std::io::Stdout;
+ type Error = Infallible;
+
+ fn format(&self, session: &Session) -> Result<String, Self::Error> {
+ Ok(session.name.to_string())
+ }
+
+ fn filter(&self, session: &Session) -> bool {
+ !self.exclude.contains(&session.name)
+ }
+
+ fn writer(&self) -> Result<Self::Writer, std::io::Error> {
+ Ok(std::io::stdout())
+ }
+}
diff --git a/src/tmux.rs b/src/tmux.rs
index 88f521f..e1fa3a6 100644
--- a/src/tmux.rs
+++ b/src/tmux.rs
@@ -2,32 +2,27 @@ use std::process::Command;
use crate::Session;
+use clap::Args;
pub use error::Error;
mod error;
-#[derive(Debug)]
+#[derive(Debug, Clone, Args)]
+#[group(skip)]
pub struct Tmux {
- socket_name: String,
+ /// tmux socket-name, equivelent to `tmux -L <socket-name>`
+ #[arg(short = 'L', long = "tmux_socket", default_value = "ssh")]
+ pub socket: String,
}
impl Tmux {
const SESSION_FORMAT: &str = r##"Session(name: "#S", state: #{?session_last_attached,Attached(#{session_last_attached}),Created(#{session_created})})"##;
- pub fn new(socket_name: String) -> Self {
- Self { socket_name }
- }
-
- pub fn list(&self, name: Option<String>) -> Result<Vec<Session>, Error> {
- let filter = name
- .map(|s| vec!["-f".into(), format!("#{{==:#S,{s}}}")])
- .unwrap_or_default();
-
+ pub fn list(&self) -> Result<Vec<Session>, Error> {
let stdout = Command::new("tmux")
.arg("-L")
- .arg(&self.socket_name)
+ .arg(&self.socket)
.arg("list-sessions")
- .args(filter)
.arg("-F")
.arg(Self::SESSION_FORMAT)
.output()?
@@ -46,31 +41,13 @@ impl Tmux {
Ok(sessions)
}
-
- pub fn host(&self) -> Result<Session, Error> {
- let stdout = Command::new("tmux")
- .arg("-L")
- .arg(&self.socket_name)
- .arg("display")
- .arg("-p")
- .arg("#h")
- .output()?
- .stdout;
-
- let name = std::str::from_utf8(&stdout)?.trim().into();
-
- Ok(Session {
- state: crate::State::LocalHost,
- name,
- })
- }
}
#[cfg(test)]
mod tests {
use super::*;
- const SOCKET_NAME: &str = "test";
+ const SOCKET: &str = "test";
#[test]
fn test_tmux_list() -> Result<(), Error> {
@@ -79,52 +56,25 @@ mod tests {
for name in names.iter().cloned() {
Command::new("tmux")
.arg("-L")
- .arg(SOCKET_NAME)
+ .arg(SOCKET)
.arg("new-session")
.arg("-ds")
.arg(name)
.status()?;
}
- let tmux = Tmux::new(SOCKET_NAME.to_owned());
- let sessions: Vec<_> = tmux.list(None)?.into_iter().map(|s| s.name).collect();
-
- Command::new("tmux")
- .arg("-L")
- .arg(SOCKET_NAME)
- .arg("kill-server")
- .status()?;
-
- assert_eq!(names, sessions);
-
- Ok(())
- }
-
- #[test]
- fn test_tmux_host() -> Result<(), Error> {
- Command::new("tmux")
- .arg("-L")
- .arg(SOCKET_NAME)
- .arg("new-session")
- .arg("-d")
- .status()?;
-
- let name = hostname::get()?.to_string_lossy().into();
- let expected_session = Session {
- state: crate::State::LocalHost,
- name,
+ let tmux = Tmux {
+ socket: SOCKET.to_owned(),
};
-
- let tmux = Tmux::new(SOCKET_NAME.to_owned());
- let session = tmux.host()?;
+ let sessions: Vec<_> = tmux.list()?.into_iter().map(|s| s.name).collect();
Command::new("tmux")
.arg("-L")
- .arg(SOCKET_NAME)
+ .arg(SOCKET)
.arg("kill-server")
.status()?;
- assert_eq!(expected_session, session);
+ assert_eq!(names, sessions);
Ok(())
}
diff --git a/src/unix.rs b/src/unix.rs
index 402e476..d549b63 100644
--- a/src/unix.rs
+++ b/src/unix.rs
@@ -9,20 +9,24 @@ pub struct Hosts(Vec<Session>);
impl Hosts {
pub fn open() -> Result<Self, Error> {
- let buf_reader = BufReader::new(File::open("/etc/hosts")?);
- let inner: Vec<Session> = buf_reader
+ File::open("/etc/hosts").map(Self::parse_file).map(Self)
+ }
+
+ fn parse_file(file: File) -> Vec<Session> {
+ BufReader::new(file)
.lines()
.flatten()
+ .filter_map(Self::parse_line)
+ .collect()
+ }
+
+ fn parse_line(line: String) -> Option<Session> {
+ line.split_whitespace()
.take_while(|s| !s.starts_with('#'))
- .flat_map(|l| {
- l.split_whitespace()
- .skip(1)
- .take_while(|s| !s.starts_with('#'))
- .map(Session::from)
- .collect::<Vec<_>>()
- })
- .collect();
- Ok(Self(inner))
+ .last()
+ // Skip BOM
+ .filter(|&s| s != "\u{feff}")
+ .map(Session::discover)
}
}