diff options
author | Toby Vincent <tobyv13@gmail.com> | 2022-03-20 22:46:53 -0500 |
---|---|---|
committer | Toby Vincent <tobyv13@gmail.com> | 2022-03-20 22:46:53 -0500 |
commit | e95e9506a30736876a7598c30c018af51256f1ff (patch) | |
tree | 33c01bc7228f40bd7616085cc58a93fc948a22f7 /zoned/src | |
parent | 2d377920fd2b624a2a58a051607152ab324a8614 (diff) |
feat: implement pty over websocket
Diffstat (limited to 'zoned/src')
-rw-r--r-- | zoned/src/api.rs | 196 | ||||
-rw-r--r-- | zoned/src/config.rs | 7 | ||||
-rw-r--r-- | zoned/src/error.rs | 15 | ||||
-rw-r--r-- | zoned/src/main.rs | 25 |
4 files changed, 160 insertions, 83 deletions
diff --git a/zoned/src/api.rs b/zoned/src/api.rs index 3e5c85e..6d48aa6 100644 --- a/zoned/src/api.rs +++ b/zoned/src/api.rs @@ -1,20 +1,39 @@ +use anyhow::Context; use axum::{ extract::{ - Extension, - Query, ws::{Message, WebSocket, WebSocketUpgrade}, - TypedHeader, + Extension, Query, TypedHeader, }, + headers, + response::IntoResponse, routing::{get, post}, - Json, Router, response::IntoResponse, headers, + Json, Router, }; -use std::{sync::Arc, process::Command}; -use tracing::warn; +use bytes::BytesMut; +use futures::{ + stream::{SplitSink, SplitStream, StreamExt}, + SinkExt, +}; +use serde::Deserialize; +use std::sync::Arc; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + process::Command, + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, +}; +use tracing::{debug, error, info, warn}; +use wspty::{PtyCommand, PtyMaster}; use zone_core::{Container, ContainerOptions, FilterContainer}; use zone_nspawn::NSpawn; use crate::{Error, Result, State}; +#[derive(Deserialize, Debug)] +struct WindowSize { + cols: u16, + rows: u16, +} + pub fn build_routes() -> Router { Router::new() .route("/test", get(test_endpoint)) @@ -69,74 +88,125 @@ async fn clone_container( async fn ws_handler( ws: WebSocketUpgrade, user_agent: Option<TypedHeader<headers::UserAgent>>, + Extension(state): Extension<Arc<State>>, ) -> impl IntoResponse { + info!("Client connected"); + if let Some(TypedHeader(user_agent)) = user_agent { - println!("`{}` connected", user_agent.as_str()); + debug!("`{}` connected", user_agent.as_str()); } - ws.on_upgrade(handle_socket) + ws.on_upgrade(|socket| websocket(socket, state)) } -async fn handle_socket(mut socket: WebSocket) { - let mut term_input: Vec<String> = Vec::new(); - - loop { - if let Some(msg) = socket.recv().await { - if let Ok(msg) = msg { - match msg { - Message::Text(t) => { - println!("client send str: {:?}", t); - // Enter - if t.eq("\r") { - let response = parse_command(term_input.to_owned()).await; - if socket.send(Message::Text(response)).await.is_err() { - println!("send failed"); - return; - } - term_input = Vec::new(); - // Backspace - } else if t.eq("\u{7f}") { - term_input.pop(); - } else { - term_input.push(t.to_owned()); - } - } - Message::Close(_) => { - println!("client disconnected"); - return; +async fn websocket(ws_stream: WebSocket, _state: Arc<State>) -> Result<()> { + debug!("Handling websocket!"); + + let (sender, receiver) = unbounded_channel(); + let ws_sender = sender.clone(); + + let (ws_outgoing, ws_incoming) = ws_stream.split(); + + let mut cmd = Command::new("bash"); + + cmd.arg("-l").env("TERM", "xterm-256color"); + + let mut pty_cmd = PtyCommand::from(cmd); + let (stop_sender, stop_receiver) = unbounded_channel(); + let pty_master = pty_cmd.run(stop_receiver).await?; + + let pty_shell_writer = pty_master.clone(); + let pty_shell_reader = pty_master.clone(); + + let res = tokio::select! { + res = handle_websocket_incoming(ws_incoming, pty_shell_writer, sender, stop_sender) => res, + res = handle_pty_incoming(pty_shell_reader, ws_sender) => res, + res = write_to_websocket(ws_outgoing, receiver) => res, + }; + debug!("res = {:?}", res); + Ok(()) +} + +async fn handle_websocket_incoming( + mut incoming: SplitStream<WebSocket>, + mut pty_shell_writer: PtyMaster, + websocket_sender: UnboundedSender<Message>, + stop_sender: UnboundedSender<()>, +) -> Result<()> { + while let Some(Ok(msg)) = incoming.next().await { + match msg { + Message::Binary(data) => match data[0] { + 0 => { + if data.len().gt(&0) { + pty_shell_writer.write_all(&data[1..]).await?; } - _ => {return;} } - } else { - println!("client disconnected"); - return; - } - } + 1 => { + let resize_msg: WindowSize = + serde_json::from_slice(&data[1..]).context("Failed to convert")?; + pty_shell_writer.resize(resize_msg.cols, resize_msg.rows)?; + } + 2 => { + websocket_sender + .send(Message::Binary(vec![1u8])) + .context("Failed to send")?; + } + _ => (), + }, + Message::Ping(data) => websocket_sender + .send(Message::Pong(data)) + .context("Failed to send")?, + _ => (), + }; } + let _ = stop_sender + .send(()) + .map_err(|e| debug!("failed to send stop signal: {:?}", e)); + Ok(()) } -async fn parse_command(input: Vec<String>) -> String { - let temp = input.concat(); - let cmd: Vec<&str> = temp.split(' ').collect(); - - let output = if cmd.len() <= 1 { - Command::new(cmd[0]) - .output().ok() - } else { - let mut args = cmd.to_owned(); - args.remove(0); - - Command::new(cmd[0]) - .args(args) - .output().ok() +async fn handle_pty_incoming( + mut pty_shell_reader: PtyMaster, + websocket_sender: UnboundedSender<Message>, +) -> Result<()> { + let fut = async move { + let mut buffer = BytesMut::with_capacity(1024); + buffer.resize(1024, 0u8); + loop { + buffer[0] = 0u8; + let mut tail = &mut buffer[1..]; + let n = pty_shell_reader.read_buf(&mut tail).await?; + if n == 0 { + break; + } + match websocket_sender.send(Message::Binary(buffer[..n + 1].to_vec())) { + Ok(_) => (), + Err(e) => anyhow::bail!("failed to send msg to client: {:?}", e), + } + } + Ok::<(), anyhow::Error>(()) }; + fut.await.map_err(|e| { + error!("handle pty incoming error: {:?}", &e); + e.into() + }) +} - match output { - Some(x) => if x.status.success() { - String::from_utf8(x.stdout).unwrap() - } else { - String::from_utf8(x.stderr).unwrap() - } - _ => String::from("") +async fn write_to_websocket( + mut outgoing: SplitSink<WebSocket, Message>, + mut receiver: UnboundedReceiver<Message>, +) -> Result<()> { + while let Some(msg) = receiver.recv().await { + outgoing.send(msg).await?; } -}
\ No newline at end of file + Ok(()) +} + +#[cfg(test)] +mod tests { + #[test] + fn hello_world() { + // use super::*; + assert!("true" == "true"); + } +} diff --git a/zoned/src/config.rs b/zoned/src/config.rs index f6622f1..d889cce 100644 --- a/zoned/src/config.rs +++ b/zoned/src/config.rs @@ -1,6 +1,6 @@ use figment::Figment; use serde::{Deserialize, Serialize}; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use crate::{Error, Result}; @@ -31,11 +31,6 @@ impl TryFrom<Figment> for Config { } } -impl From<Config> for SocketAddr { - fn from(val: Config) -> Self { - SocketAddr::from((val.ip_address, val.port)) - } -} #[cfg(test)] mod tests { use std::path::PathBuf; diff --git a/zoned/src/error.rs b/zoned/src/error.rs index 98fbf2f..0c3bdcf 100644 --- a/zoned/src/error.rs +++ b/zoned/src/error.rs @@ -10,6 +10,9 @@ pub type Result<T> = std::result::Result<T, Error>; #[derive(Error, Debug)] pub enum Error { + #[error("Zone Error: {0:?}")] + Zone(String), + #[error("Container Error: {0:?}")] Container(String), @@ -31,6 +34,18 @@ pub enum Error { source: figment::Error, }, + #[error("Axum Error: {source:?}")] + Axum { + #[from] + source: axum::Error, + }, + + #[error("IO Error: {source:?}")] + IO { + #[from] + source: std::io::Error, + }, + #[error("Container not found")] NotFound, diff --git a/zoned/src/main.rs b/zoned/src/main.rs index 019c9e0..f275e7a 100644 --- a/zoned/src/main.rs +++ b/zoned/src/main.rs @@ -1,14 +1,15 @@ +use anyhow::Context; use axum::extract::Extension; use figment::{ providers::{Env, Format, Serialized, Toml}, Figment, }; use std::{net::SocketAddr, sync::Arc}; -use tracing::{debug, error}; +use tracing::info; use zoned::{build_routes, Config, State}; #[tokio::main] -async fn main() { +async fn main() -> Result<(), zoned::Error> { tracing_subscriber::fmt::init(); let figment = Figment::from(Serialized::defaults(Config::default())) @@ -17,26 +18,22 @@ async fn main() { let config = match Config::try_from(figment) { Ok(config) => config, - Err(err) => { - error!("{}", err); - std::process::exit(1) - } + Err(err) => return Err(err), }; + let addr = SocketAddr::from((config.ip_address, config.port)); + let shared_state = match State::try_from(config) { - Ok(state) => Arc::new(state), - Err(err) => { - error!("{}", err); - std::process::exit(1) - } + Ok(config) => Arc::new(config), + Err(err) => return Err(err), }; let routes = build_routes().layer(Extension(shared_state)); - let addr = SocketAddr::from(([172, 21, 110, 173], 3001)); - debug!("listening on {}", addr); + info!("listening on {}", addr); axum::Server::bind(&addr) .serve(routes.into_make_service()) .await - .unwrap(); + .context("Axum error") + .map_err(zoned::Error::from) } |