diff --git a/src/main.rs b/src/main.rs index ad5a484..36a30f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use std::{time::Duration}; use crossbeam_channel::unbounded; -use modules::{socket_handler::Socket, stream_states::stream_states_class::StreamState, message_handler::{MessageHandler, StateMessage}}; +use modules::{socket_handler::Socket, stream_states::stream_states_class::StreamState, message_handler::{MessageHandler}}; use workctl::sync_flag; use crate::modules::stream_states::state_update::StateUpdate; @@ -19,7 +19,7 @@ fn main() { let socket_listener = Socket::make_listener(SERVER_ADDRESS); let (from_socket_tx, from_socket_rx) = unbounded::(); let (to_socket_tx, to_socket_rx) = unbounded::(); - let (mut listener_can_run_flag, listener_join_handle) = Socket::handle_connections(socket_listener, from_socket_tx, to_socket_rx); + let mut socket = Socket::handle_connections(socket_listener, from_socket_tx, to_socket_rx); let (control_c_flag_tx, control_c_called_flag_rx) = sync_flag::new_syncflag(false); @@ -40,9 +40,7 @@ fn main() { } } - //Close the listener thread - listener_can_run_flag.set(false); - listener_join_handle.join().unwrap(); + socket.close(); } fn setup_control_c(mut control_c_flag_tx: sync_flag::SyncFlagTx) { diff --git a/src/modules/socket_handler.rs b/src/modules/socket_handler.rs index 4536c72..dfe0674 100644 --- a/src/modules/socket_handler.rs +++ b/src/modules/socket_handler.rs @@ -6,7 +6,9 @@ use std::thread::{self, JoinHandle}; use std::time::Duration; pub struct Socket{ - + socket_txs: Vec>, + stop_listener_flag: sync_flag::SyncFlagTx, + handle_connections_join_handle: Option>, } impl Socket { @@ -15,7 +17,7 @@ impl Socket { TcpListener::bind(address).unwrap() } - pub fn handle_connections(listener: TcpListener, messenger_tx: Sender, messenger_rx: Receiver) -> (sync_flag::SyncFlagTx, JoinHandle<()>){ + pub fn handle_connections(listener: TcpListener, messenger_tx: Sender, messenger_rx: Receiver) -> Self { let (tx, thread_stop_flag) = sync_flag::new_syncflag(true); let handle = thread::spawn(move || { @@ -28,7 +30,11 @@ impl Socket { } drop(listener); }); - (tx, handle) + Socket { + socket_txs: Vec::>::new(), + stop_listener_flag: tx, + handle_connections_join_handle: Some(handle), + } } pub fn handle_client(mut stream: TcpStream, update_tx: Sender, message_rx: Receiver, program_shutdown_flag: sync_flag::SyncFlagRx) { @@ -58,4 +64,11 @@ impl Socket { } stream.shutdown(Shutdown::Both).unwrap(); } + + pub fn close(&mut self) { + self.stop_listener_flag.set(false); + self.handle_connections_join_handle + .take().expect("Called on not running thread") + .join().expect("Could not join thread"); + } } \ No newline at end of file diff --git a/src/tests/socket_handler_tests.rs b/src/tests/socket_handler_tests.rs index 53d5f1b..abb3b82 100644 --- a/src/tests/socket_handler_tests.rs +++ b/src/tests/socket_handler_tests.rs @@ -1,4 +1,3 @@ -use std::sync::mpsc; use crossbeam_channel::unbounded; use std::io::{Write, Read}; use std::thread; @@ -35,7 +34,7 @@ fn can_handle_messages() { let (tx_1, rx_1) = unbounded::(); let (_stream_tx, stream_rx) = unbounded::(); - let (mut flag, connection_handle) = Socket::handle_connections(listener, tx_1, stream_rx); + let mut socket = Socket::handle_connections(listener, tx_1, stream_rx); let join_handle = std::thread::spawn(move || { let mut outgoing = std::net::TcpStream::connect("localhost:5004").unwrap(); @@ -45,10 +44,10 @@ fn can_handle_messages() { join_handle.join().unwrap(); thread::sleep(Duration::from_millis(1000)); - flag.set(false); - connection_handle.join().unwrap(); let message = rx_1.recv().unwrap(); assert_eq!(message, String::from("this is a test")); + + socket.close(); } #[test] @@ -57,7 +56,7 @@ fn can_handle_delayed_messages() { let (tx_1, rx_1) = unbounded::(); let (_stream_tx, stream_rx) = unbounded::(); - let (mut flag, connection_handle) = Socket::handle_connections(listener, tx_1, stream_rx); + let mut socket = Socket::handle_connections(listener, tx_1, stream_rx); let mut outgoing = std::net::TcpStream::connect("localhost:5005").unwrap(); outgoing.write("this is a test1\n".as_bytes()).unwrap(); @@ -74,8 +73,7 @@ fn can_handle_delayed_messages() { println!("{}", message); assert_eq!(message, String::from("this is a test3\n")); - flag.set(false); - connection_handle.join().unwrap(); + socket.close(); } #[test] @@ -84,7 +82,7 @@ fn can_send_and_receive_on_stream() { let (tx_1, rx_1) = unbounded::(); let (stream_tx, stream_rx) = unbounded::(); - let (mut close_socket_flag, connection_handle) = Socket::handle_connections(listener, tx_1, stream_rx); + let mut socket = Socket::handle_connections(listener, tx_1, stream_rx); let mut outgoing = std::net::TcpStream::connect("localhost:5006").unwrap(); outgoing.set_read_timeout(Some(Duration::from_millis(1000))).expect("couln't set timout"); @@ -105,6 +103,5 @@ fn can_send_and_receive_on_stream() { assert_eq!("this is another test!", message.into_owned()); drop(outgoing); - close_socket_flag.set(false); - connection_handle.join().unwrap(); + socket.close(); }