diff --git a/src/main.rs b/src/main.rs index 6cc9f92..ad5a484 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{sync::mpsc, time::Duration}; +use std::{time::Duration}; use crossbeam_channel::unbounded; use modules::{socket_handler::Socket, stream_states::stream_states_class::StreamState, message_handler::{MessageHandler, StateMessage}}; @@ -18,12 +18,14 @@ fn main() { let socket_listener = Socket::make_listener(SERVER_ADDRESS); let (from_socket_tx, from_socket_rx) = unbounded::(); - let (mut listener_can_run_flag, listener_join_handle) = Socket::handle_connections(socket_listener, from_socket_tx); + let (to_socket_tx, to_socket_rx) = unbounded::(); + let (mut listener_can_run_flag, listener_join_handle) = Socket::handle_connections(socket_listener, from_socket_tx, to_socket_rx); let (control_c_flag_tx, control_c_called_flag_rx) = sync_flag::new_syncflag(false); setup_control_c(control_c_flag_tx); - + let _outgoing = std::net::TcpStream::connect(SERVER_ADDRESS).unwrap(); + to_socket_tx.send("this is a message".to_string()).unwrap(); //until control_c is caught, check the queue of incoming //requests from the socket handler. while !control_c_called_flag_rx.get() { diff --git a/src/modules/socket_handler.rs b/src/modules/socket_handler.rs index b233ae4..4536c72 100644 --- a/src/modules/socket_handler.rs +++ b/src/modules/socket_handler.rs @@ -1,7 +1,7 @@ use workctl::sync_flag; use std::net::{TcpListener, TcpStream, Shutdown}; use std::io::{Read, Write}; -use crossbeam_channel::Sender; +use crossbeam_channel::{Sender, Receiver}; use std::thread::{self, JoinHandle}; use std::time::Duration; @@ -15,14 +15,14 @@ impl Socket { TcpListener::bind(address).unwrap() } - pub fn handle_connections(listener: TcpListener, messenger_tx: Sender) -> (sync_flag::SyncFlagTx, JoinHandle<()>){ + pub fn handle_connections(listener: TcpListener, messenger_tx: Sender, messenger_rx: Receiver) -> (sync_flag::SyncFlagTx, JoinHandle<()>){ let (tx, thread_stop_flag) = sync_flag::new_syncflag(true); let handle = thread::spawn(move || { listener.set_nonblocking(true).unwrap(); while thread_stop_flag.get() { for (stream, _addr) in listener.accept() { - Socket::handle_client(stream, messenger_tx.clone(), thread_stop_flag.clone()); + Socket::handle_client(stream, messenger_tx.clone(), messenger_rx.clone(), thread_stop_flag.clone()); } thread::sleep(Duration::from_millis(100)); } @@ -31,12 +31,12 @@ impl Socket { (tx, handle) } - pub fn handle_client(mut stream: TcpStream, update_tx: Sender, program_shutdown_flag: sync_flag::SyncFlagRx) { + pub fn handle_client(mut stream: TcpStream, update_tx: Sender, message_rx: Receiver, program_shutdown_flag: sync_flag::SyncFlagRx) { let mut buffer = [0; 1024]; stream.set_read_timeout(Some(Duration::from_millis(100))).expect("Could not set a read timeout"); while program_shutdown_flag.get() { match stream.read(&mut buffer) { - Err(_) => {continue}, + Err(_) => {}, Ok(read_size) => { //Tcp is supposed to have a 0 byte read if closed by client if read_size == 0 || !program_shutdown_flag.get() { @@ -48,6 +48,13 @@ impl Socket { } } } + match message_rx.try_recv() { + Err(_) => {}, + Ok(message) => { + stream.write(message.as_bytes()).unwrap(); + stream.flush().unwrap(); + } + } } stream.shutdown(Shutdown::Both).unwrap(); } diff --git a/src/tests/socket_handler_tests.rs b/src/tests/socket_handler_tests.rs index 2aa2497..53d5f1b 100644 --- a/src/tests/socket_handler_tests.rs +++ b/src/tests/socket_handler_tests.rs @@ -1,6 +1,6 @@ use std::sync::mpsc; use crossbeam_channel::unbounded; -use std::io::{Write}; +use std::io::{Write, Read}; use std::thread; use std::time::Duration; @@ -33,8 +33,9 @@ fn panic_no_listener() { fn can_handle_messages() { let listener = Socket::make_listener("localhost:5004"); let (tx_1, rx_1) = unbounded::(); + let (_stream_tx, stream_rx) = unbounded::(); - let (mut flag, connection_handle) = Socket::handle_connections(listener, tx_1); + let (mut flag, connection_handle) = Socket::handle_connections(listener, tx_1, stream_rx); let join_handle = std::thread::spawn(move || { let mut outgoing = std::net::TcpStream::connect("localhost:5004").unwrap(); @@ -54,18 +55,16 @@ fn can_handle_messages() { fn can_handle_delayed_messages() { let listener = Socket::make_listener("localhost:5005"); let (tx_1, rx_1) = unbounded::(); + let (_stream_tx, stream_rx) = unbounded::(); - let (mut flag, connection_handle) = Socket::handle_connections(listener, tx_1); + let (mut flag, connection_handle) = Socket::handle_connections(listener, tx_1, stream_rx); - let join_handle = std::thread::spawn(move || { - let mut outgoing = std::net::TcpStream::connect("localhost:5005").unwrap(); - outgoing.write("this is a test1\n".as_bytes()).unwrap(); - thread::sleep(Duration::from_millis(500)); - outgoing.write("this is a test3\n".as_bytes()).unwrap(); - drop(outgoing); - }); - join_handle.join().unwrap(); - thread::sleep(Duration::from_millis(1000)); + let mut outgoing = std::net::TcpStream::connect("localhost:5005").unwrap(); + outgoing.write("this is a test1\n".as_bytes()).unwrap(); + thread::sleep(Duration::from_millis(500)); + outgoing.write("this is a test3\n".as_bytes()).unwrap(); + drop(outgoing); + thread::sleep(Duration::from_millis(500)); let message = rx_1.recv().unwrap(); println!("{}", message); @@ -78,3 +77,34 @@ fn can_handle_delayed_messages() { flag.set(false); connection_handle.join().unwrap(); } + +#[test] +fn can_send_and_receive_on_stream() { + let listener = Socket::make_listener("localhost:5006"); + let (tx_1, rx_1) = unbounded::(); + let (stream_tx, stream_rx) = unbounded::(); + + let (mut close_socket_flag, connection_handle) = Socket::handle_connections(listener, tx_1, stream_rx); + + let mut outgoing = std::net::TcpStream::connect("localhost:5006").unwrap(); + outgoing.set_read_timeout(Some(Duration::from_millis(1000))).expect("couln't set timout"); + + outgoing.write("such a test!\n".as_bytes()).unwrap(); + outgoing.flush().unwrap(); + thread::sleep(Duration::from_millis(250)); + assert_eq!(rx_1.try_recv().unwrap(), "such a test!\n"); + + stream_tx.send("this is another test!".to_string()).unwrap(); + thread::sleep(Duration::from_millis(250)); + + let mut buffer = [0; 256]; + let msg_len = outgoing.read(&mut buffer).unwrap(); + assert!(msg_len != 0); + + let message = String::from_utf8_lossy(&buffer[0..msg_len]); + assert_eq!("this is another test!", message.into_owned()); + + drop(outgoing); + close_socket_flag.set(false); + connection_handle.join().unwrap(); +}