diff --git a/src/ipc/server.rs b/src/ipc/server.rs index ce3f78d5..095310df 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -31,11 +31,10 @@ where let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); let data = Arc::new(RwLock::new(self.data)); - tracing::info!("address = {}", address.to_string()); + tracing::info!("address = {:?}", address); while let Ok((stream, remote_address)) = listener.protocol_accept().await { - let remote_address = remote_address.to_string(); - tracing::debug!("remote_address = {}", remote_address); + tracing::debug!("remote_address = {:?}", remote_address); let handler = Arc::clone(&handler); let namespaces = Arc::clone(&namespaces); let data = Arc::clone(&data); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 55a95e39..4837ac0f 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,14 +1,17 @@ pub mod tcp; +#[cfg(unix)] +pub mod unix_socket; + use crate::prelude::IPCResult; use async_trait::async_trait; use std::fmt::Debug; use tokio::io::{AsyncRead, AsyncWrite}; #[async_trait] -pub trait AsyncStreamProtocolListener: Sized { - type AddressType: ToString + Clone + Debug; - type RemoteAddressType: ToString; +pub trait AsyncStreamProtocolListener: Sized + Send + Sync { + type AddressType: Clone + Debug + Send + Sync; + type RemoteAddressType: Debug; type Stream: 'static + AsyncProtocolStream; async fn protocol_bind(address: Self::AddressType) -> IPCResult; @@ -25,9 +28,9 @@ pub trait AsyncProtocolStreamSplit { #[async_trait] pub trait AsyncProtocolStream: - AsyncRead + AsyncWrite + Sized + Send + Sync + AsyncProtocolStreamSplit + AsyncRead + AsyncWrite + Send + Sync + AsyncProtocolStreamSplit + Sized { - type AddressType: ToString + Clone + Debug; + type AddressType: Clone + Debug + Send + Sync; async fn protocol_connect(address: Self::AddressType) -> IPCResult; } diff --git a/src/protocol/unix_socket.rs b/src/protocol/unix_socket.rs new file mode 100644 index 00000000..91809083 --- /dev/null +++ b/src/protocol/unix_socket.rs @@ -0,0 +1,51 @@ +use crate::error::Result; +use crate::prelude::IPCResult; +use crate::protocol::{AsyncProtocolStream, AsyncProtocolStreamSplit, AsyncStreamProtocolListener}; +use async_trait::async_trait; +use std::path::PathBuf; +use tokio::io::Interest; +use tokio::net::unix::OwnedWriteHalf; +use tokio::net::unix::{OwnedReadHalf, SocketAddr}; +use tokio::net::{UnixListener, UnixStream}; + +#[async_trait] +impl AsyncStreamProtocolListener for UnixListener { + type AddressType = PathBuf; + type RemoteAddressType = SocketAddr; + type Stream = UnixStream; + + async fn protocol_bind(address: Self::AddressType) -> Result { + let listener = UnixListener::bind(address)?; + + Ok(listener) + } + + async fn protocol_accept(&self) -> Result<(Self::Stream, Self::RemoteAddressType)> { + let connection = self.accept().await?; + + Ok(connection) + } +} + +impl AsyncProtocolStreamSplit for UnixStream { + type OwnedSplitReadHalf = OwnedReadHalf; + type OwnedSplitWriteHalf = OwnedWriteHalf; + + fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) { + self.into_split() + } +} + +#[async_trait] +impl AsyncProtocolStream for UnixStream { + type AddressType = PathBuf; + + async fn protocol_connect(address: Self::AddressType) -> IPCResult { + let stream = UnixStream::connect(address).await?; + stream + .ready(Interest::READABLE | Interest::WRITABLE) + .await?; + + Ok(stream) + } +} diff --git a/src/tests/ipc_tests.rs b/src/tests/ipc_tests.rs index 1a484fbe..3760f644 100644 --- a/src/tests/ipc_tests.rs +++ b/src/tests/ipc_tests.rs @@ -3,10 +3,11 @@ use crate::prelude::*; use crate::protocol::AsyncProtocolStream; use crate::tests::utils::start_test_server; use std::net::ToSocketAddrs; +use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, UnixListener}; use typemap_rev::TypeMapKey; async fn handle_ping_event(ctx: &Context

, e: Event) -> IPCResult<()> { @@ -21,19 +22,35 @@ async fn handle_ping_event(ctx: &Context

, e: Event) - Ok(()) } -fn get_builder_with_ping(address: &str) -> IPCBuilder { +fn get_builder_with_ping(address: L::AddressType) -> IPCBuilder { IPCBuilder::new() .on("ping", |ctx, e| Box::pin(handle_ping_event(ctx, e))) - .address(address.to_socket_addrs().unwrap().next().unwrap()) + .address(address) +} + +#[tokio::test] +async fn it_receives_tcp_events() { + let socket_address = "127.0.0.1:8281".to_socket_addrs().unwrap().next().unwrap(); + it_receives_events::(socket_address).await; } +#[cfg(unix)] #[tokio::test] -async fn it_receives_events() { - let builder = get_builder_with_ping("127.0.0.1:8281"); +async fn it_receives_unix_socket_events() { + let socket_path = PathBuf::from("/tmp/test_socket"); + if socket_path.exists() { + std::fs::remove_file(&socket_path).unwrap(); + } + it_receives_events::(socket_path).await; +} + +async fn it_receives_events(address: L::AddressType) { + let builder = get_builder_with_ping::(address.clone()); let server_running = Arc::new(AtomicBool::new(false)); + tokio::spawn({ let server_running = Arc::clone(&server_running); - let builder = get_builder_with_ping("127.0.0.1:8281"); + let builder = get_builder_with_ping::(address); async move { server_running.store(true, Ordering::SeqCst); builder.build_server().await.unwrap();