aboutsummaryrefslogtreecommitdiffstats
path: root/zoned/src
diff options
context:
space:
mode:
authorToby Vincent <tobyv13@gmail.com>2022-03-20 22:46:53 -0500
committerToby Vincent <tobyv13@gmail.com>2022-03-20 22:46:53 -0500
commite95e9506a30736876a7598c30c018af51256f1ff (patch)
tree33c01bc7228f40bd7616085cc58a93fc948a22f7 /zoned/src
parent2d377920fd2b624a2a58a051607152ab324a8614 (diff)
feat: implement pty over websocket
Diffstat (limited to 'zoned/src')
-rw-r--r--zoned/src/api.rs196
-rw-r--r--zoned/src/config.rs7
-rw-r--r--zoned/src/error.rs15
-rw-r--r--zoned/src/main.rs25
4 files changed, 160 insertions, 83 deletions
diff --git a/zoned/src/api.rs b/zoned/src/api.rs
index 3e5c85e..6d48aa6 100644
--- a/zoned/src/api.rs
+++ b/zoned/src/api.rs
@@ -1,20 +1,39 @@
+use anyhow::Context;
use axum::{
extract::{
- Extension,
- Query,
ws::{Message, WebSocket, WebSocketUpgrade},
- TypedHeader,
+ Extension, Query, TypedHeader,
},
+ headers,
+ response::IntoResponse,
routing::{get, post},
- Json, Router, response::IntoResponse, headers,
+ Json, Router,
};
-use std::{sync::Arc, process::Command};
-use tracing::warn;
+use bytes::BytesMut;
+use futures::{
+ stream::{SplitSink, SplitStream, StreamExt},
+ SinkExt,
+};
+use serde::Deserialize;
+use std::sync::Arc;
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ process::Command,
+ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
+};
+use tracing::{debug, error, info, warn};
+use wspty::{PtyCommand, PtyMaster};
use zone_core::{Container, ContainerOptions, FilterContainer};
use zone_nspawn::NSpawn;
use crate::{Error, Result, State};
+#[derive(Deserialize, Debug)]
+struct WindowSize {
+ cols: u16,
+ rows: u16,
+}
+
pub fn build_routes() -> Router {
Router::new()
.route("/test", get(test_endpoint))
@@ -69,74 +88,125 @@ async fn clone_container(
async fn ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
+ Extension(state): Extension<Arc<State>>,
) -> impl IntoResponse {
+ info!("Client connected");
+
if let Some(TypedHeader(user_agent)) = user_agent {
- println!("`{}` connected", user_agent.as_str());
+ debug!("`{}` connected", user_agent.as_str());
}
- ws.on_upgrade(handle_socket)
+ ws.on_upgrade(|socket| websocket(socket, state))
}
-async fn handle_socket(mut socket: WebSocket) {
- let mut term_input: Vec<String> = Vec::new();
-
- loop {
- if let Some(msg) = socket.recv().await {
- if let Ok(msg) = msg {
- match msg {
- Message::Text(t) => {
- println!("client send str: {:?}", t);
- // Enter
- if t.eq("\r") {
- let response = parse_command(term_input.to_owned()).await;
- if socket.send(Message::Text(response)).await.is_err() {
- println!("send failed");
- return;
- }
- term_input = Vec::new();
- // Backspace
- } else if t.eq("\u{7f}") {
- term_input.pop();
- } else {
- term_input.push(t.to_owned());
- }
- }
- Message::Close(_) => {
- println!("client disconnected");
- return;
+async fn websocket(ws_stream: WebSocket, _state: Arc<State>) -> Result<()> {
+ debug!("Handling websocket!");
+
+ let (sender, receiver) = unbounded_channel();
+ let ws_sender = sender.clone();
+
+ let (ws_outgoing, ws_incoming) = ws_stream.split();
+
+ let mut cmd = Command::new("bash");
+
+ cmd.arg("-l").env("TERM", "xterm-256color");
+
+ let mut pty_cmd = PtyCommand::from(cmd);
+ let (stop_sender, stop_receiver) = unbounded_channel();
+ let pty_master = pty_cmd.run(stop_receiver).await?;
+
+ let pty_shell_writer = pty_master.clone();
+ let pty_shell_reader = pty_master.clone();
+
+ let res = tokio::select! {
+ res = handle_websocket_incoming(ws_incoming, pty_shell_writer, sender, stop_sender) => res,
+ res = handle_pty_incoming(pty_shell_reader, ws_sender) => res,
+ res = write_to_websocket(ws_outgoing, receiver) => res,
+ };
+ debug!("res = {:?}", res);
+ Ok(())
+}
+
+async fn handle_websocket_incoming(
+ mut incoming: SplitStream<WebSocket>,
+ mut pty_shell_writer: PtyMaster,
+ websocket_sender: UnboundedSender<Message>,
+ stop_sender: UnboundedSender<()>,
+) -> Result<()> {
+ while let Some(Ok(msg)) = incoming.next().await {
+ match msg {
+ Message::Binary(data) => match data[0] {
+ 0 => {
+ if data.len().gt(&0) {
+ pty_shell_writer.write_all(&data[1..]).await?;
}
- _ => {return;}
}
- } else {
- println!("client disconnected");
- return;
- }
- }
+ 1 => {
+ let resize_msg: WindowSize =
+ serde_json::from_slice(&data[1..]).context("Failed to convert")?;
+ pty_shell_writer.resize(resize_msg.cols, resize_msg.rows)?;
+ }
+ 2 => {
+ websocket_sender
+ .send(Message::Binary(vec![1u8]))
+ .context("Failed to send")?;
+ }
+ _ => (),
+ },
+ Message::Ping(data) => websocket_sender
+ .send(Message::Pong(data))
+ .context("Failed to send")?,
+ _ => (),
+ };
}
+ let _ = stop_sender
+ .send(())
+ .map_err(|e| debug!("failed to send stop signal: {:?}", e));
+ Ok(())
}
-async fn parse_command(input: Vec<String>) -> String {
- let temp = input.concat();
- let cmd: Vec<&str> = temp.split(' ').collect();
-
- let output = if cmd.len() <= 1 {
- Command::new(cmd[0])
- .output().ok()
- } else {
- let mut args = cmd.to_owned();
- args.remove(0);
-
- Command::new(cmd[0])
- .args(args)
- .output().ok()
+async fn handle_pty_incoming(
+ mut pty_shell_reader: PtyMaster,
+ websocket_sender: UnboundedSender<Message>,
+) -> Result<()> {
+ let fut = async move {
+ let mut buffer = BytesMut::with_capacity(1024);
+ buffer.resize(1024, 0u8);
+ loop {
+ buffer[0] = 0u8;
+ let mut tail = &mut buffer[1..];
+ let n = pty_shell_reader.read_buf(&mut tail).await?;
+ if n == 0 {
+ break;
+ }
+ match websocket_sender.send(Message::Binary(buffer[..n + 1].to_vec())) {
+ Ok(_) => (),
+ Err(e) => anyhow::bail!("failed to send msg to client: {:?}", e),
+ }
+ }
+ Ok::<(), anyhow::Error>(())
};
+ fut.await.map_err(|e| {
+ error!("handle pty incoming error: {:?}", &e);
+ e.into()
+ })
+}
- match output {
- Some(x) => if x.status.success() {
- String::from_utf8(x.stdout).unwrap()
- } else {
- String::from_utf8(x.stderr).unwrap()
- }
- _ => String::from("")
+async fn write_to_websocket(
+ mut outgoing: SplitSink<WebSocket, Message>,
+ mut receiver: UnboundedReceiver<Message>,
+) -> Result<()> {
+ while let Some(msg) = receiver.recv().await {
+ outgoing.send(msg).await?;
}
-} \ No newline at end of file
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ #[test]
+ fn hello_world() {
+ // use super::*;
+ assert!("true" == "true");
+ }
+}
diff --git a/zoned/src/config.rs b/zoned/src/config.rs
index f6622f1..d889cce 100644
--- a/zoned/src/config.rs
+++ b/zoned/src/config.rs
@@ -1,6 +1,6 @@
use figment::Figment;
use serde::{Deserialize, Serialize};
-use std::net::{IpAddr, SocketAddr};
+use std::net::IpAddr;
use crate::{Error, Result};
@@ -31,11 +31,6 @@ impl TryFrom<Figment> for Config {
}
}
-impl From<Config> for SocketAddr {
- fn from(val: Config) -> Self {
- SocketAddr::from((val.ip_address, val.port))
- }
-}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
diff --git a/zoned/src/error.rs b/zoned/src/error.rs
index 98fbf2f..0c3bdcf 100644
--- a/zoned/src/error.rs
+++ b/zoned/src/error.rs
@@ -10,6 +10,9 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
+ #[error("Zone Error: {0:?}")]
+ Zone(String),
+
#[error("Container Error: {0:?}")]
Container(String),
@@ -31,6 +34,18 @@ pub enum Error {
source: figment::Error,
},
+ #[error("Axum Error: {source:?}")]
+ Axum {
+ #[from]
+ source: axum::Error,
+ },
+
+ #[error("IO Error: {source:?}")]
+ IO {
+ #[from]
+ source: std::io::Error,
+ },
+
#[error("Container not found")]
NotFound,
diff --git a/zoned/src/main.rs b/zoned/src/main.rs
index 019c9e0..f275e7a 100644
--- a/zoned/src/main.rs
+++ b/zoned/src/main.rs
@@ -1,14 +1,15 @@
+use anyhow::Context;
use axum::extract::Extension;
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use std::{net::SocketAddr, sync::Arc};
-use tracing::{debug, error};
+use tracing::info;
use zoned::{build_routes, Config, State};
#[tokio::main]
-async fn main() {
+async fn main() -> Result<(), zoned::Error> {
tracing_subscriber::fmt::init();
let figment = Figment::from(Serialized::defaults(Config::default()))
@@ -17,26 +18,22 @@ async fn main() {
let config = match Config::try_from(figment) {
Ok(config) => config,
- Err(err) => {
- error!("{}", err);
- std::process::exit(1)
- }
+ Err(err) => return Err(err),
};
+ let addr = SocketAddr::from((config.ip_address, config.port));
+
let shared_state = match State::try_from(config) {
- Ok(state) => Arc::new(state),
- Err(err) => {
- error!("{}", err);
- std::process::exit(1)
- }
+ Ok(config) => Arc::new(config),
+ Err(err) => return Err(err),
};
let routes = build_routes().layer(Extension(shared_state));
- let addr = SocketAddr::from(([172, 21, 110, 173], 3001));
- debug!("listening on {}", addr);
+ info!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(routes.into_make_service())
.await
- .unwrap();
+ .context("Axum error")
+ .map_err(zoned::Error::from)
}