diff --git a/src/main.rs b/src/main.rs index ab77bc8..53a3a15 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,6 @@ use std::sync::Arc; use structopt::StructOpt; use tokio::net::UdpSocket; use tokio::net::{TcpListener, TcpStream}; -use tokio::select; use tokio_postgres::{Client, Config, NoTls, Statement}; // todo: do better errorhandling @@ -79,41 +78,70 @@ async fn main() -> Result<(), Error> { let blacklist: HashSet<_> = options.blacklist.into_iter().collect(); let blacklist = Arc::new(blacklist); - // rfc says max length for messages is 1024 - let mut buf = [0; 1024]; + // tcp handling + let tcp = { + let client = client.clone(); + let insert_statement = insert_statement.clone(); + let blacklist = blacklist.clone(); - loop { - //todo: possibly better implemented by just running two tokio::spawn tasks that loop - select! { - tcp = tcp_listener.accept() => match tcp { - Ok((socket, peer)) => { - tokio::spawn(handle_peer_and_error( - socket, - peer, - client.clone(), - insert_statement.clone(), - blacklist.clone(), - )); - } - Err(e) => eprintln!("tcp error: {:?}", e), - }, - udp = udp_socket.recv_from(&mut buf) => match udp { - Ok((len, addr)) => { - let line = &buf[0..len]; - let line = match std::str::from_utf8(&line) { - Ok(l) => l, - Err(e) => {eprintln!("udp packet is not valid utf8: {e}"); continue}, - }; - let line: String = line.into(); + tokio::spawn(async move { + loop { + match tcp_listener.accept().await { + Ok((socket, peer)) => { + handle_peer_and_error( + socket, + peer, + client.clone(), + insert_statement.clone(), + blacklist.clone(), + ) + .await + } + Err(e) => eprintln!("tcp error: {:?}", e), + }; + } + }) + }; - tokio::spawn(handle_udp_and_error( - line, addr, client.clone(), insert_statement.clone(),blacklist.clone())); + // udp handling + let udp = { + let client = client.clone(); + let insert_statement = insert_statement.clone(); + let blacklist = blacklist.clone(); + tokio::spawn(async move { + // rfc says max length for messages is 1024 + let mut buf = [0; 1024]; + loop { + match udp_socket.recv_from(&mut buf).await { + Ok((len, addr)) => { + let line = &buf[0..len]; + let line = match std::str::from_utf8(&line) { + Ok(l) => l, + Err(e) => { + eprintln!("udp packet is not valid utf8: {e}"); + continue; + } + }; + let line: String = line.into(); - }, - Err(e) => eprintln!("udp error: {:?}", e), - }, - } - } + handle_udp_and_error( + line, + addr, + client.clone(), + insert_statement.clone(), + blacklist.clone(), + ) + .await + } + Err(e) => eprintln!("udp error: {:?}", e), + }; + } + }) + }; + tcp.await?; + udp.await?; + // should be unreachable + Ok(()) } async fn handle_udp_and_error>(