aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--zoned/src/api.rs139
-rw-r--r--zoned/src/error.rs13
-rw-r--r--zoned/src/lib.rs4
-rw-r--r--zoned/src/main.rs15
-rw-r--r--zoned/src/ws.rs106
5 files changed, 144 insertions, 133 deletions
diff --git a/zoned/src/api.rs b/zoned/src/api.rs
index 160ab15..5236764 100644
--- a/zoned/src/api.rs
+++ b/zoned/src/api.rs
@@ -1,47 +1,30 @@
-use anyhow::Context;
use axum::{
- extract::{
- ws::{Message, WebSocket, WebSocketUpgrade},
- Extension, Query, TypedHeader,
- },
+ extract::{ws::WebSocketUpgrade, Extension, Query, TypedHeader},
headers,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
-use bytes::BytesMut;
-use futures::{stream::StreamExt, SinkExt};
-use serde::Deserialize;
use std::sync::Arc;
-use tokio::{
- io::{AsyncReadExt, AsyncWriteExt},
- process::Command,
- sync::mpsc::unbounded_channel,
-};
-use tracing::{debug, error, info, warn};
-use wspty::PtyCommand;
+use tracing::{info, instrument, warn};
use zone_core::{Container, ContainerOptions, FilterContainer};
use zone_nspawn::NSpawn;
-use crate::{Error, Result, State};
-
-#[derive(Deserialize, Debug)]
-struct WindowSize {
- cols: u16,
- rows: u16,
-}
+use crate::{ws, Error, Result, State};
+#[instrument()]
pub fn build_routes() -> Router {
Router::new()
.route("/test", get(test_endpoint))
.route("/container", post(clone_container))
.route("/container/list?<container..>", get(container_list))
- .route("/ws", get(ws_handler))
+ .route("/ws", get(ws_upgrade))
}
/// # Test endpoint
///
/// Returns a list of containers based on the query.
+#[instrument(ret, skip(state))]
async fn test_endpoint(Extension(state): Extension<Arc<State>>) -> Json<String> {
Json(state.zfs.config.pool_name.to_owned())
}
@@ -49,6 +32,7 @@ async fn test_endpoint(Extension(state): Extension<Arc<State>>) -> Json<String>
/// List containers
///
/// Returns a list of containers based on the query.
+#[instrument(err, ret)]
async fn container_list(
container: Option<Query<ContainerOptions>>,
) -> Result<Json<Vec<Container>>> {
@@ -67,6 +51,7 @@ async fn container_list(
/// Create container
///
/// Creates a new container volume from the provided container json data
+#[instrument(err, ret, skip(state))]
async fn clone_container(
Json(container): Json<Container>,
Extension(state): Extension<Arc<State>>,
@@ -82,111 +67,19 @@ async fn clone_container(
.map(Container::into)
}
-async fn ws_handler(
+/// Upgrade to websocket
+///
+/// Creates a new container volume from the provided container json data
+#[instrument(ret, skip_all)]
+async fn ws_upgrade(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Extension(state): Extension<Arc<State>>,
) -> impl IntoResponse {
- info!("Client connected");
-
- if let Some(TypedHeader(user_agent)) = user_agent {
- debug!("`{}` connected", user_agent.as_str());
- }
-
- ws.on_upgrade(|socket| websocket(socket, state))
-}
-
-async fn websocket(ws_stream: WebSocket, _state: Arc<State>) {
- debug!("Handling websocket!");
-
- let (mut sender, mut receiver) = ws_stream.split();
-
- let (tx, mut rx) = unbounded_channel();
- let ws_tx = tx.clone();
+ let ua = user_agent.map_or("Unknown".to_string(), |u| u.to_string());
+ info!(%ua, "Client connected");
- let mut cmd = Command::new("bash");
-
- cmd.arg("-l").env("TERM", "xterm-256color");
-
- let mut pty_cmd = PtyCommand::from(cmd);
- let (kill_tx, kill_rx) = unbounded_channel();
-
- let (mut pty_write, mut pty_read) = match pty_cmd.run(kill_rx).await {
- Ok(pty) => (pty.clone(), pty),
- Err(err) => {
- error!(?err);
- return;
- }
- };
-
- let recv_task = tokio::spawn(async move {
- while let Some(Ok(msg)) = receiver.next().await {
- match msg {
- Message::Binary(data) => match data[0] {
- 0 => {
- if data.len().gt(&0) {
- pty_write.write_all(&data[1..]).await?;
- }
- }
- 1 => {
- let resize_msg: WindowSize =
- serde_json::from_slice(&data[1..]).context("Failed to convert")?;
- pty_write.resize(resize_msg.cols, resize_msg.rows)?;
- }
- 2 => {
- tx.send(Message::Binary(vec![1u8]))
- .context("Failed to send")?;
- }
- _ => (),
- },
- Message::Ping(data) => tx.send(Message::Pong(data)).context("Failed to send")?,
- _ => (),
- };
- }
- let _ = kill_tx
- .send(())
- .map_err(|e| debug!("failed to send stop signal: {:?}", e));
- Ok(())
- });
-
- let read_task = tokio::spawn(async move {
- let fut = async move {
- let mut buffer = BytesMut::with_capacity(1024);
- buffer.resize(1024, 0u8);
- loop {
- buffer[0] = 0u8;
- let mut tail = &mut buffer[1..];
- let n = pty_read.read_buf(&mut tail).await?;
- if n == 0 {
- break;
- }
- match ws_tx.send(Message::Binary(buffer[..n + 1].to_vec())) {
- Ok(_) => (),
- Err(e) => anyhow::bail!("failed to send msg to client: {:?}", e),
- }
- }
- Ok::<(), anyhow::Error>(())
- };
- fut.await.map_err(|e| {
- error!("handle pty incoming error: {:?}", &e);
- e.into()
- })
- });
-
- let send_task = tokio::spawn(async move {
- while let Some(msg) = rx.recv().await {
- sender.send(msg).await?;
- }
- Result::Ok(())
- });
-
- if let Err(err) = tokio::select! {
- res = recv_task => res,
- res = read_task => res,
- res = send_task => res,
- } {
- error!(?err);
- }
+ ws.on_upgrade(|socket| ws::handler(socket, state))
}
#[cfg(test)]
diff --git a/zoned/src/error.rs b/zoned/src/error.rs
index 0c3bdcf..c9723e3 100644
--- a/zoned/src/error.rs
+++ b/zoned/src/error.rs
@@ -1,7 +1,7 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
- Json,
+ Json, extract::ws::Message,
};
use serde_json::json;
use thiserror::Error;
@@ -15,6 +15,9 @@ pub enum Error {
#[error("Container Error: {0:?}")]
Container(String),
+
+ #[error("Container not found")]
+ NotFound,
#[error("ZFS Error: {source:?}")]
ZFS {
@@ -46,8 +49,12 @@ pub enum Error {
source: std::io::Error,
},
- #[error("Container not found")]
- NotFound,
+ #[error("Send Error: {source:?}")]
+ Send {
+ #[from]
+ source: tokio::sync::mpsc::error::SendError<Message>,
+ },
+
#[error(transparent)]
Other(#[from] anyhow::Error),
diff --git a/zoned/src/lib.rs b/zoned/src/lib.rs
index 944a264..e118e68 100644
--- a/zoned/src/lib.rs
+++ b/zoned/src/lib.rs
@@ -1,9 +1,9 @@
-pub use crate::api::build_routes;
pub use crate::config::Config;
pub use crate::error::{Error, Result};
pub use crate::state::State;
-mod api;
+pub mod api;
mod config;
mod error;
mod state;
+mod ws;
diff --git a/zoned/src/main.rs b/zoned/src/main.rs
index 8787a4a..14b5ca1 100644
--- a/zoned/src/main.rs
+++ b/zoned/src/main.rs
@@ -1,13 +1,16 @@
+use std::net::SocketAddr;
+
use anyhow::Context;
use axum::extract::Extension;
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
-use tracing::{debug_span, info, Instrument};
-use zoned::{build_routes, Config, State};
+use tracing::{debug_span, info, instrument, Instrument};
+use zoned::{api, Config, State};
#[tokio::main]
+#[instrument(err)]
async fn main() -> Result<(), zoned::Error> {
tracing_subscriber::fmt::init();
@@ -19,11 +22,13 @@ async fn main() -> Result<(), zoned::Error> {
let shared_state = State::try_from(&config)?.into_arc();
- let routes = build_routes().layer(Extension(shared_state));
+ let routes = api::build_routes().layer(Extension(shared_state));
+
+ let socket_addr = SocketAddr::from(config);
- info!(ip_address = %config.ip_address, port = %config.port, "Server listening");
+ info!("Server listening on http://{}", socket_addr);
- axum::Server::bind(&config.into())
+ axum::Server::bind(&socket_addr)
.serve(routes.into_make_service())
.instrument(debug_span!("read_task").or_current())
.await
diff --git a/zoned/src/ws.rs b/zoned/src/ws.rs
new file mode 100644
index 0000000..4d441f5
--- /dev/null
+++ b/zoned/src/ws.rs
@@ -0,0 +1,106 @@
+use anyhow::Context;
+use axum::extract::ws::{Message, WebSocket};
+use bytes::BytesMut;
+use futures::{
+ stream::{SplitSink, SplitStream, StreamExt},
+ SinkExt,
+};
+use serde::Deserialize;
+use std::sync::Arc;
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ process::Command,
+ sync::mpsc,
+};
+use tracing::{instrument, warn};
+use wspty::{PtyCommand, PtyMaster};
+
+use crate::{Result, State};
+
+#[derive(Deserialize, Debug)]
+struct WindowSize {
+ cols: u16,
+ rows: u16,
+}
+
+#[instrument(err, skip_all)]
+pub async fn handler(ws_stream: WebSocket, _state: Arc<State>) -> Result<()> {
+ let (sender, receiver) = ws_stream.split();
+ let (kill_tx, kill_rx) = mpsc::unbounded_channel();
+ let (tx, rx) = mpsc::channel(1024);
+
+ let mut cmd = Command::new("bash");
+ cmd.arg("-l").env("TERM", "xterm-256color");
+
+ let pty = PtyCommand::from(cmd).run(kill_rx).await?;
+
+ tokio::select! {
+ res = msg_handler(receiver, pty.clone(), tx.clone(), kill_tx) => res,
+ res = pty_handler(pty, tx) => res,
+ res = ws_sender(rx, sender) => res,
+ }
+}
+
+#[instrument(err, skip_all)]
+async fn msg_handler(
+ mut receiver: SplitStream<WebSocket>,
+ mut pty_write: PtyMaster,
+ tx: mpsc::Sender<Message>,
+ kill_tx: mpsc::UnboundedSender<()>,
+) -> Result<()> {
+ while let Some(Ok(msg)) = receiver.next().await {
+ match msg {
+ Message::Binary(data) => match data[0] {
+ 0 => {
+ if data.len().gt(&0) {
+ pty_write.write_all(&data[1..]).await?;
+ }
+ }
+ 1 => {
+ let resize_msg: WindowSize = serde_json::from_slice(&data[1..])
+ .context("Failed to convert data to WindowSize")?;
+ pty_write.resize(resize_msg.cols, resize_msg.rows)?;
+ }
+ 2 => {
+ tx.send(Message::Binary(vec![1u8])).await?;
+ }
+ _ => (),
+ },
+ Message::Ping(data) => tx.send(Message::Pong(data)).await?,
+ _ => (),
+ };
+ }
+
+ if kill_tx.send(()).is_err() {
+ warn!("kill signal sent to pty was never received")
+ };
+
+ Ok(())
+}
+
+#[instrument(err, skip_all)]
+async fn pty_handler(mut pty_read: PtyMaster, tx: mpsc::Sender<Message>) -> Result<()> {
+ let mut buffer = BytesMut::with_capacity(1024);
+ buffer.resize(1024, 0u8);
+ loop {
+ buffer[0] = 0u8;
+ let mut tail = &mut buffer[1..];
+ let n = pty_read.read_buf(&mut tail).await?;
+ if n == 0 {
+ break;
+ }
+ tx.send(Message::Binary(buffer[..n + 1].to_vec())).await?
+ }
+ Ok(())
+}
+
+#[instrument(err, skip_all)]
+async fn ws_sender(
+ mut rx: mpsc::Receiver<Message>,
+ mut sender: SplitSink<WebSocket, Message>,
+) -> Result<()> {
+ while let Some(msg) = rx.recv().await {
+ sender.send(msg).await?;
+ }
+ Ok(())
+}