diff options
author | Toby Vincent <tobyv@tobyvin.dev> | 2024-09-03 01:47:16 -0500 |
---|---|---|
committer | Toby Vincent <tobyv@tobyvin.dev> | 2024-09-03 01:47:16 -0500 |
commit | 7bd959111de6ce30eff9f088c0253a8fa2abf056 (patch) | |
tree | 334e5a0d3eb8a1d3304559b4f415e4e879bb4d09 | |
parent | d03160fd62335e443377c89449155f6634a72140 (diff) |
-rw-r--r-- | src/main.rs | 132 |
1 files changed, 92 insertions, 40 deletions
diff --git a/src/main.rs b/src/main.rs index 29dbb88..2982a08 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use std::{ marker::Unpin, + net::SocketAddr, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -8,18 +9,21 @@ use std::{ use tokio::{ io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}, - net::{TcpListener, TcpStream, ToSocketAddrs}, + net::{TcpListener, TcpStream}, sync::{Notify, RwLock}, task::{JoinSet, LocalSet}, }; #[derive(Debug, Default)] struct State { + remote_addr: String, + local_addr: String, headers: RwLock<Vec<u8>>, - frame_buf: RwLock<Vec<u8>>, - frame_ready: Notify, - client_req: AtomicUsize, - client_req_notify: Notify, + part: RwLock<Vec<u8>>, + jpeg: RwLock<Vec<u8>>, + part_ready: Notify, + clients: AtomicUsize, + client_ready: Notify, } #[tokio::main] @@ -38,29 +42,21 @@ async fn main() -> Result<(), Error> { let mut join_set = JoinSet::new(); let state = Arc::new(State { - frame_buf: RwLock::new(Vec::with_capacity(16_384)), + jpeg: RwLock::new(Vec::with_capacity(16_384)), + remote_addr, + local_addr, ..Default::default() }); - tokio::spawn(handle_server(remote_addr, state.clone())); + tokio::spawn(handle_server(state.clone())); - println!("listening on: {local_addr}"); - let listener = TcpListener::bind(local_addr).await?; + println!("listening on: {}", state.local_addr); + let listener = TcpListener::bind(&state.local_addr).await?; let local = LocalSet::new(); local .run_until(async move { - while let Ok((stream, addr)) = listener.accept().await { - let state = state.clone(); - state.client_req.fetch_add(1, Ordering::Relaxed); - state.client_req_notify.notify_one(); - - join_set.spawn_local(async move { - println!("Client connected: {addr}"); - let res = handle_client(stream, state.clone()).await; - println!("Client disconnected: {addr}"); - state.client_req.fetch_sub(1, Ordering::Relaxed); - res - }); + while let Ok(conn) = listener.accept().await { + join_set.spawn_local(handle_client(conn, state.clone())); } }) .await; @@ -68,17 +64,69 @@ async fn main() -> Result<(), Error> { Ok(()) } -async fn handle_client( - mut stream: impl AsyncRead + AsyncWriteExt + Unpin, +async fn handle_client<A>((stream, addr): (A, SocketAddr), state: Arc<State>) -> Result<(), Error> +where + A: AsyncRead + AsyncWriteExt + Unpin, +{ + let (mut reader, writer) = tokio::io::split(stream); + + let mut buf_reader = BufReader::new(&mut reader); + let mut buf = Vec::new(); + + buf_reader.read_until(b'\n', &mut buf).await?; + let res = match buf.as_slice() { + b"GET /jpg HTTP/1.1\r\n" | b"GET /jpeg HTTP/1.1\r\n" | b"GET /snapshot HTTP/1.1\r\n" => { + println!("Client snapshot: {addr}"); + handle_jpg(writer, state.clone()).await + } + b"GET /mjpeg HTTP/1.1\r\n" | b"GET /mjpeg/1 HTTP/1.1\r\n" | b"GET /stream HTTP/1.1\r\n" => { + println!("Client connected: {addr}"); + state.clients.fetch_add(1, Ordering::Relaxed); + state.client_ready.notify_one(); + let res = handle_mjpeg(writer, state.clone()).await; + state.clients.fetch_sub(1, Ordering::Relaxed); + println!("Client disconnected: {addr}"); + res + } + _ => Err(Error::InvalidRoute), + }; + + res +} + +async fn handle_jpg<A>(mut writer: A, state: Arc<State>) -> Result<(), Error> +where + A: AsyncWriteExt + Unpin, +{ + if state.clients.load(Ordering::Relaxed) == 0 { + let mut stream = TcpStream::connect(&state.remote_addr).await?; + let req = b"GET /jpg HTTP/1.1\r\nUser-Agent: MjpegProxy/0.1.0\r\nAccept: */*\r\n\r\n"; + stream.write_all(req).await?; + tokio::io::copy(&mut stream, &mut writer).await?; + Ok(()) + } else { + let headers = b"HTTP/1.1 200 OK\r\nContent-disposition: inline; filename=capture.jpg\r\nContent-type: image/jpeg\r\n\r\n"; + writer.write_all(headers).await?; + state.part_ready.notified().await; + let jpeg = state.jpeg.read().await; + writer.write_all(&jpeg).await?; + Ok(()) + } +} + +async fn handle_mjpeg( + mut stream: impl AsyncWriteExt + Unpin, state: Arc<State>, ) -> Result<(), Error> { let headers = state.headers.read().await; stream.write_all(&headers).await?; loop { - state.frame_ready.notified().await; - let frame_buffer = state.frame_buf.read().await; - stream.write_all(&frame_buffer).await? + state.part_ready.notified().await; + let part = state.part.read().await; + stream.write_all(&part).await?; + let jpeg = state.jpeg.read().await; + stream.write_all(&jpeg).await?; } } @@ -109,21 +157,18 @@ async fn handle_client( /// <jpeg_data>... /// ...ad infinitum /// ``` -async fn handle_server<A>(addr: A, state: Arc<State>) -> Result<(), Error> -where - A: ToSocketAddrs, -{ +async fn handle_server(state: Arc<State>) -> Result<(), Error> { loop { println!("Waiting for client..."); let mut header_buf = state.headers.write().await; header_buf.clear(); - state.client_req_notify.notified().await; - if state.client_req.load(Ordering::Relaxed) == 0 { + state.client_ready.notified().await; + if state.clients.load(Ordering::Relaxed) == 0 { continue; } - let mut stream = TcpStream::connect(&addr).await?; - println!("Connected to: {}", stream.local_addr()?); + let mut stream = TcpStream::connect(&state.remote_addr).await?; + println!("Connected to: {}", state.remote_addr); let req = b"GET /mjpeg/1 HTTP/1.1\r\nUser-Agent: MjpegProxy/0.1.0\r\nAccept: */*\r\n\r\n"; stream.write_all(req).await?; @@ -166,7 +211,7 @@ where let mut part_buf = Vec::new(); let mut data_buf = Vec::with_capacity(16_384); - while state.client_req.load(Ordering::Relaxed) > 0 { + while state.clients.load(Ordering::Relaxed) > 0 { let mut pos = buf_reader.read_until(b'\n', &mut part_buf).await?; assert_eq!(part_buf[2..pos - 2], *boundary); @@ -197,11 +242,15 @@ where buf_reader.read_exact(&mut data_buf).await?; assert_eq!(data_buf[len..], *b"\r\n"); - let mut frame_buf = state.frame_buf.write().await; - frame_buf.clear(); - frame_buf.extend_from_slice(&part_buf); - frame_buf.extend_from_slice(&data_buf); - state.frame_ready.notify_waiters(); + let mut part = state.part.write().await; + part.clear(); + part.extend_from_slice(&part_buf); + + let mut jpeg = state.jpeg.write().await; + jpeg.clear(); + jpeg.extend_from_slice(&data_buf); + + state.part_ready.notify_waiters(); part_buf.clear(); data_buf.clear(); @@ -225,6 +274,9 @@ enum Error { #[error("Missing Content-Length header")] MissingContentLength, + #[error("Invalid route in Host header")] + InvalidRoute, + #[error("Invalid Content-Length header: {0}")] InvalidContentLength(String), |