diff options
-rw-r--r-- | zoned/src/api.rs | 139 | ||||
-rw-r--r-- | zoned/src/error.rs | 13 | ||||
-rw-r--r-- | zoned/src/lib.rs | 4 | ||||
-rw-r--r-- | zoned/src/main.rs | 15 | ||||
-rw-r--r-- | zoned/src/ws.rs | 106 |
5 files changed, 144 insertions, 133 deletions
diff --git a/zoned/src/api.rs b/zoned/src/api.rs index 160ab15..5236764 100644 --- a/zoned/src/api.rs +++ b/zoned/src/api.rs @@ -1,47 +1,30 @@ -use anyhow::Context; use axum::{ - extract::{ - ws::{Message, WebSocket, WebSocketUpgrade}, - Extension, Query, TypedHeader, - }, + extract::{ws::WebSocketUpgrade, Extension, Query, TypedHeader}, headers, response::IntoResponse, routing::{get, post}, Json, Router, }; -use bytes::BytesMut; -use futures::{stream::StreamExt, SinkExt}; -use serde::Deserialize; use std::sync::Arc; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - process::Command, - sync::mpsc::unbounded_channel, -}; -use tracing::{debug, error, info, warn}; -use wspty::PtyCommand; +use tracing::{info, instrument, warn}; use zone_core::{Container, ContainerOptions, FilterContainer}; use zone_nspawn::NSpawn; -use crate::{Error, Result, State}; - -#[derive(Deserialize, Debug)] -struct WindowSize { - cols: u16, - rows: u16, -} +use crate::{ws, Error, Result, State}; +#[instrument()] pub fn build_routes() -> Router { Router::new() .route("/test", get(test_endpoint)) .route("/container", post(clone_container)) .route("/container/list?<container..>", get(container_list)) - .route("/ws", get(ws_handler)) + .route("/ws", get(ws_upgrade)) } /// # Test endpoint /// /// Returns a list of containers based on the query. +#[instrument(ret, skip(state))] async fn test_endpoint(Extension(state): Extension<Arc<State>>) -> Json<String> { Json(state.zfs.config.pool_name.to_owned()) } @@ -49,6 +32,7 @@ async fn test_endpoint(Extension(state): Extension<Arc<State>>) -> Json<String> /// List containers /// /// Returns a list of containers based on the query. +#[instrument(err, ret)] async fn container_list( container: Option<Query<ContainerOptions>>, ) -> Result<Json<Vec<Container>>> { @@ -67,6 +51,7 @@ async fn container_list( /// Create container /// /// Creates a new container volume from the provided container json data +#[instrument(err, ret, skip(state))] async fn clone_container( Json(container): Json<Container>, Extension(state): Extension<Arc<State>>, @@ -82,111 +67,19 @@ async fn clone_container( .map(Container::into) } -async fn ws_handler( +/// Upgrade to websocket +/// +/// Creates a new container volume from the provided container json data +#[instrument(ret, skip_all)] +async fn ws_upgrade( ws: WebSocketUpgrade, user_agent: Option<TypedHeader<headers::UserAgent>>, Extension(state): Extension<Arc<State>>, ) -> impl IntoResponse { - info!("Client connected"); - - if let Some(TypedHeader(user_agent)) = user_agent { - debug!("`{}` connected", user_agent.as_str()); - } - - ws.on_upgrade(|socket| websocket(socket, state)) -} - -async fn websocket(ws_stream: WebSocket, _state: Arc<State>) { - debug!("Handling websocket!"); - - let (mut sender, mut receiver) = ws_stream.split(); - - let (tx, mut rx) = unbounded_channel(); - let ws_tx = tx.clone(); + let ua = user_agent.map_or("Unknown".to_string(), |u| u.to_string()); + info!(%ua, "Client connected"); - let mut cmd = Command::new("bash"); - - cmd.arg("-l").env("TERM", "xterm-256color"); - - let mut pty_cmd = PtyCommand::from(cmd); - let (kill_tx, kill_rx) = unbounded_channel(); - - let (mut pty_write, mut pty_read) = match pty_cmd.run(kill_rx).await { - Ok(pty) => (pty.clone(), pty), - Err(err) => { - error!(?err); - return; - } - }; - - let recv_task = tokio::spawn(async move { - while let Some(Ok(msg)) = receiver.next().await { - match msg { - Message::Binary(data) => match data[0] { - 0 => { - if data.len().gt(&0) { - pty_write.write_all(&data[1..]).await?; - } - } - 1 => { - let resize_msg: WindowSize = - serde_json::from_slice(&data[1..]).context("Failed to convert")?; - pty_write.resize(resize_msg.cols, resize_msg.rows)?; - } - 2 => { - tx.send(Message::Binary(vec![1u8])) - .context("Failed to send")?; - } - _ => (), - }, - Message::Ping(data) => tx.send(Message::Pong(data)).context("Failed to send")?, - _ => (), - }; - } - let _ = kill_tx - .send(()) - .map_err(|e| debug!("failed to send stop signal: {:?}", e)); - Ok(()) - }); - - let read_task = tokio::spawn(async move { - let fut = async move { - let mut buffer = BytesMut::with_capacity(1024); - buffer.resize(1024, 0u8); - loop { - buffer[0] = 0u8; - let mut tail = &mut buffer[1..]; - let n = pty_read.read_buf(&mut tail).await?; - if n == 0 { - break; - } - match ws_tx.send(Message::Binary(buffer[..n + 1].to_vec())) { - Ok(_) => (), - Err(e) => anyhow::bail!("failed to send msg to client: {:?}", e), - } - } - Ok::<(), anyhow::Error>(()) - }; - fut.await.map_err(|e| { - error!("handle pty incoming error: {:?}", &e); - e.into() - }) - }); - - let send_task = tokio::spawn(async move { - while let Some(msg) = rx.recv().await { - sender.send(msg).await?; - } - Result::Ok(()) - }); - - if let Err(err) = tokio::select! { - res = recv_task => res, - res = read_task => res, - res = send_task => res, - } { - error!(?err); - } + ws.on_upgrade(|socket| ws::handler(socket, state)) } #[cfg(test)] diff --git a/zoned/src/error.rs b/zoned/src/error.rs index 0c3bdcf..c9723e3 100644 --- a/zoned/src/error.rs +++ b/zoned/src/error.rs @@ -1,7 +1,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, - Json, + Json, extract::ws::Message, }; use serde_json::json; use thiserror::Error; @@ -15,6 +15,9 @@ pub enum Error { #[error("Container Error: {0:?}")] Container(String), + + #[error("Container not found")] + NotFound, #[error("ZFS Error: {source:?}")] ZFS { @@ -46,8 +49,12 @@ pub enum Error { source: std::io::Error, }, - #[error("Container not found")] - NotFound, + #[error("Send Error: {source:?}")] + Send { + #[from] + source: tokio::sync::mpsc::error::SendError<Message>, + }, + #[error(transparent)] Other(#[from] anyhow::Error), diff --git a/zoned/src/lib.rs b/zoned/src/lib.rs index 944a264..e118e68 100644 --- a/zoned/src/lib.rs +++ b/zoned/src/lib.rs @@ -1,9 +1,9 @@ -pub use crate::api::build_routes; pub use crate::config::Config; pub use crate::error::{Error, Result}; pub use crate::state::State; -mod api; +pub mod api; mod config; mod error; mod state; +mod ws; diff --git a/zoned/src/main.rs b/zoned/src/main.rs index 8787a4a..14b5ca1 100644 --- a/zoned/src/main.rs +++ b/zoned/src/main.rs @@ -1,13 +1,16 @@ +use std::net::SocketAddr; + use anyhow::Context; use axum::extract::Extension; use figment::{ providers::{Env, Format, Serialized, Toml}, Figment, }; -use tracing::{debug_span, info, Instrument}; -use zoned::{build_routes, Config, State}; +use tracing::{debug_span, info, instrument, Instrument}; +use zoned::{api, Config, State}; #[tokio::main] +#[instrument(err)] async fn main() -> Result<(), zoned::Error> { tracing_subscriber::fmt::init(); @@ -19,11 +22,13 @@ async fn main() -> Result<(), zoned::Error> { let shared_state = State::try_from(&config)?.into_arc(); - let routes = build_routes().layer(Extension(shared_state)); + let routes = api::build_routes().layer(Extension(shared_state)); + + let socket_addr = SocketAddr::from(config); - info!(ip_address = %config.ip_address, port = %config.port, "Server listening"); + info!("Server listening on http://{}", socket_addr); - axum::Server::bind(&config.into()) + axum::Server::bind(&socket_addr) .serve(routes.into_make_service()) .instrument(debug_span!("read_task").or_current()) .await diff --git a/zoned/src/ws.rs b/zoned/src/ws.rs new file mode 100644 index 0000000..4d441f5 --- /dev/null +++ b/zoned/src/ws.rs @@ -0,0 +1,106 @@ +use anyhow::Context; +use axum::extract::ws::{Message, WebSocket}; +use bytes::BytesMut; +use futures::{ + stream::{SplitSink, SplitStream, StreamExt}, + SinkExt, +}; +use serde::Deserialize; +use std::sync::Arc; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + process::Command, + sync::mpsc, +}; +use tracing::{instrument, warn}; +use wspty::{PtyCommand, PtyMaster}; + +use crate::{Result, State}; + +#[derive(Deserialize, Debug)] +struct WindowSize { + cols: u16, + rows: u16, +} + +#[instrument(err, skip_all)] +pub async fn handler(ws_stream: WebSocket, _state: Arc<State>) -> Result<()> { + let (sender, receiver) = ws_stream.split(); + let (kill_tx, kill_rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(1024); + + let mut cmd = Command::new("bash"); + cmd.arg("-l").env("TERM", "xterm-256color"); + + let pty = PtyCommand::from(cmd).run(kill_rx).await?; + + tokio::select! { + res = msg_handler(receiver, pty.clone(), tx.clone(), kill_tx) => res, + res = pty_handler(pty, tx) => res, + res = ws_sender(rx, sender) => res, + } +} + +#[instrument(err, skip_all)] +async fn msg_handler( + mut receiver: SplitStream<WebSocket>, + mut pty_write: PtyMaster, + tx: mpsc::Sender<Message>, + kill_tx: mpsc::UnboundedSender<()>, +) -> Result<()> { + while let Some(Ok(msg)) = receiver.next().await { + match msg { + Message::Binary(data) => match data[0] { + 0 => { + if data.len().gt(&0) { + pty_write.write_all(&data[1..]).await?; + } + } + 1 => { + let resize_msg: WindowSize = serde_json::from_slice(&data[1..]) + .context("Failed to convert data to WindowSize")?; + pty_write.resize(resize_msg.cols, resize_msg.rows)?; + } + 2 => { + tx.send(Message::Binary(vec![1u8])).await?; + } + _ => (), + }, + Message::Ping(data) => tx.send(Message::Pong(data)).await?, + _ => (), + }; + } + + if kill_tx.send(()).is_err() { + warn!("kill signal sent to pty was never received") + }; + + Ok(()) +} + +#[instrument(err, skip_all)] +async fn pty_handler(mut pty_read: PtyMaster, tx: mpsc::Sender<Message>) -> Result<()> { + let mut buffer = BytesMut::with_capacity(1024); + buffer.resize(1024, 0u8); + loop { + buffer[0] = 0u8; + let mut tail = &mut buffer[1..]; + let n = pty_read.read_buf(&mut tail).await?; + if n == 0 { + break; + } + tx.send(Message::Binary(buffer[..n + 1].to_vec())).await? + } + Ok(()) +} + +#[instrument(err, skip_all)] +async fn ws_sender( + mut rx: mpsc::Receiver<Message>, + mut sender: SplitSink<WebSocket, Message>, +) -> Result<()> { + while let Some(msg) = rx.recv().await { + sender.send(msg).await?; + } + Ok(()) +} |