Merge pull request #24 from Trivernis/develop

Remove useless generic bounds
pull/32/head
Julius Riegel 2 years ago committed by GitHub
commit 332461ac7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

2
Cargo.lock generated

@ -38,7 +38,7 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bromine"
version = "0.11.0"
version = "0.13.0"
dependencies = [
"async-trait",
"byteorder",

@ -1,6 +1,6 @@
[package]
name = "bromine"
version = "0.11.0"
version = "0.13.0"
authors = ["trivernis <trivernis@protonmail.com>"]
edition = "2018"
readme = "README.md"
@ -21,26 +21,38 @@ harness = false
[dependencies]
thiserror = "1.0.30"
rmp-serde = "0.15.4"
tracing = "0.1.29"
lazy_static = "1.4.0"
typemap_rev = "0.1.5"
byteorder = "1.4.3"
async-trait = "0.1.51"
futures = "0.3.17"
rmp-serde = {version = "0.15.5", optional = true}
[dependencies.serde]
optional = true
version = "1.0.130"
features = ["serde_derive"]
features = []
[dependencies.tokio]
version = "1.12.0"
features = ["net", "io-std", "io-util", "sync", "time"]
[dev-dependencies]
rmp-serde = "0.15.4"
[dev-dependencies.serde]
version = "1.0.130"
features = ["serde_derive"]
[dev-dependencies.criterion]
version = "0.3.5"
features = ["async_tokio", "html_reports"]
[dev-dependencies.tokio]
version = "1.12.0"
features = ["macros", "rt-multi-thread"]
features = ["macros", "rt-multi-thread"]
[features]
default = ["messagepack"]
messagepack = ["serde", "rmp-serde"]

@ -16,7 +16,7 @@ use bromine::prelude::*;
use tokio::net::TcpListener;
/// Callback ping function
async fn handle_ping<S: AsyncProtocolStream>(ctx: &Context<S>, event: Event) -> Result<()> {
async fn handle_ping(ctx: &Context, event: Event) -> Result<()> {
println!("Received ping event.");
ctx.emitter.emit_response(event.id(), "pong", ()).await?;
Ok(())
@ -95,7 +95,7 @@ use tokio::net::TcpListener;
pub struct MyNamespace;
impl MyNamespace {
async fn ping<S: AsyncProtocolStream>(_ctx: &Context<S>, _event: Event) -> Result<()> {
async fn ping(_ctx: &Context, _event: Event) -> Result<()> {
println!("My namespace received a ping");
Ok(())
}
@ -104,7 +104,7 @@ impl MyNamespace {
impl NamespaceProvider for MyNamespace {
fn name() -> &'static str {"my_namespace"}
fn register<S: AsyncProtocolStream>(handler: &mut EventHandler<S>) {
fn register(handler: &mut EventHandler) {
events!(handler,
"ping" => Self::ping
);

@ -3,7 +3,7 @@ use criterion::{criterion_group, criterion_main};
use criterion::{BatchSize, Criterion};
use std::io::Cursor;
use rmp_ipc::event::Event;
use bromine::event::Event;
use tokio::runtime::Runtime;
pub const EVENT_NAME: &str = "bench_event";

@ -1,7 +1,7 @@
use bromine::event::Event;
use criterion::{
black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput,
};
use rmp_ipc::event::Event;
pub const EVENT_NAME: &str = "bench_event";

@ -9,9 +9,11 @@ pub enum Error {
#[error(transparent)]
IoError(#[from] tokio::io::Error),
#[cfg(feature = "messagepack")]
#[error(transparent)]
Decode(#[from] rmp_serde::decode::Error),
#[cfg(feature = "messagepack")]
#[error(transparent)]
Encode(#[from] rmp_serde::encode::Error),
@ -24,6 +26,9 @@ pub enum Error {
#[error("Channel Error: {0}")]
ReceiveError(#[from] oneshot::error::RecvError),
#[error("The received event was corrupted")]
CorruptedEvent,
#[error("Send Error")]
SendError,

@ -1,6 +1,10 @@
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::payload::{EventReceivePayload, EventSendPayload};
use crate::prelude::{IPCError, IPCResult};
use byteorder::{BigEndian, ReadBytesExt};
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::io::Read;
pub static ERROR_EVENT_NAME: &str = "error";
@ -8,7 +12,7 @@ pub static ERROR_EVENT_NAME: &str = "error";
/// The error event has a default handler that just logs that
/// an error occurred. For a custom handler, register a handler on
/// the [ERROR_EVENT_NAME] event.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[derive(Clone, Debug)]
pub struct ErrorEventData {
pub code: u16,
pub message: String,
@ -21,3 +25,27 @@ impl Display for ErrorEventData {
write!(f, "IPC Code {}: '{}'", self.code, self.message)
}
}
impl EventSendPayload for ErrorEventData {
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let mut buf = Vec::new();
buf.append(&mut self.code.to_be_bytes().to_vec());
let message_len = self.message.len() as u32;
buf.append(&mut message_len.to_be_bytes().to_vec());
buf.append(&mut self.message.into_bytes());
Ok(buf)
}
}
impl EventReceivePayload for ErrorEventData {
fn from_payload_bytes<R: Read>(mut reader: R) -> Result<Self> {
let code = reader.read_u16::<BigEndian>()?;
let message_len = reader.read_u32::<BigEndian>()?;
let mut message_buf = vec![0u8; message_len as usize];
reader.read_exact(&mut message_buf)?;
let message = String::from_utf8(message_buf).map_err(|_| IPCError::CorruptedEvent)?;
Ok(ErrorEventData { code, message })
}
}

@ -1,7 +1,6 @@
use crate::error::Result;
use crate::error::{Error, Result};
use crate::events::generate_event_id;
use crate::events::payload::EventReceivePayload;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use tokio::io::{AsyncRead, AsyncReadExt};
@ -14,7 +13,7 @@ pub struct Event {
data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug)]
struct EventHeader {
id: u64,
ref_id: Option<u64>,
@ -94,11 +93,8 @@ impl Event {
let data_length = total_length - header_length as u64;
tracing::trace!(total_length, header_length, data_length);
let header: EventHeader = {
let mut header_bytes = vec![0u8; header_length as usize];
reader.read_exact(&mut header_bytes).await?;
rmp_serde::from_read(&header_bytes[..])?
};
let header: EventHeader = EventHeader::from_async_read(reader).await?;
let mut data = vec![0u8; data_length as usize];
reader.read_exact(&mut data).await?;
let event = Event { header, data };
@ -109,7 +105,7 @@ impl Event {
/// Encodes the event into bytes
#[tracing::instrument(level = "trace", skip(self))]
pub fn into_bytes(mut self) -> Result<Vec<u8>> {
let mut header_bytes = rmp_serde::to_vec(&self.header)?;
let mut header_bytes = self.header.into_bytes();
let header_length = header_bytes.len() as u16;
let data_length = self.data.len();
let total_length = header_length as u64 + data_length as u64;
@ -124,3 +120,61 @@ impl Event {
Ok(buf)
}
}
impl EventHeader {
/// Serializes the event header into bytes
pub fn into_bytes(self) -> Vec<u8> {
let mut buf = Vec::new();
buf.append(&mut self.id.to_be_bytes().to_vec());
if let Some(ref_id) = self.ref_id {
buf.push(0xFF);
buf.append(&mut ref_id.to_be_bytes().to_vec());
} else {
buf.push(0x00);
}
if let Some(namespace) = self.namespace {
let namespace_len = namespace.len() as u16;
buf.append(&mut namespace_len.to_be_bytes().to_vec());
buf.append(&mut namespace.into_bytes());
} else {
buf.append(&mut 0u16.to_be_bytes().to_vec());
}
let name_len = self.name.len() as u16;
buf.append(&mut name_len.to_be_bytes().to_vec());
buf.append(&mut self.name.into_bytes());
buf
}
/// Parses an event header from an async reader
pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self> {
let id = reader.read_u64().await?;
let ref_id_exists = reader.read_u8().await?;
let ref_id = match ref_id_exists {
0x00 => None,
0xFF => Some(reader.read_u64().await?),
_ => return Err(Error::CorruptedEvent),
};
let namespace_len = reader.read_u16().await?;
let namespace = if namespace_len > 0 {
let mut namespace_buf = vec![0u8; namespace_len as usize];
reader.read_exact(&mut namespace_buf).await?;
Some(String::from_utf8(namespace_buf).map_err(|_| Error::CorruptedEvent)?)
} else {
None
};
let name_len = reader.read_u16().await?;
let mut name_buf = vec![0u8; name_len as usize];
reader.read_exact(&mut name_buf).await?;
let name = String::from_utf8(name_buf).map_err(|_| Error::CorruptedEvent)?;
Ok(Self {
id,
ref_id,
namespace,
name,
})
}
}

@ -1,39 +1,25 @@
use crate::error::Result;
use crate::events::event::Event;
use crate::ipc::context::Context;
use crate::protocol::AsyncProtocolStream;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
type EventCallback<P> = Arc<
dyn for<'a> Fn(&'a Context<P>, Event) -> Pin<Box<(dyn Future<Output = Result<()>> + Send + 'a)>>
type EventCallback = Arc<
dyn for<'a> Fn(&'a Context, Event) -> Pin<Box<(dyn Future<Output = Result<()>> + Send + 'a)>>
+ Send
+ Sync,
>;
/// Handler for events
pub struct EventHandler<P: AsyncProtocolStream> {
callbacks: HashMap<String, EventCallback<P>>,
#[derive(Clone)]
pub struct EventHandler {
callbacks: HashMap<String, EventCallback>,
}
impl<S> Clone for EventHandler<S>
where
S: AsyncProtocolStream,
{
fn clone(&self) -> Self {
Self {
callbacks: self.callbacks.clone(),
}
}
}
impl<P> Debug for EventHandler<P>
where
P: AsyncProtocolStream,
{
impl Debug for EventHandler {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let callback_names: String = self
.callbacks
@ -45,10 +31,7 @@ where
}
}
impl<P> EventHandler<P>
where
P: AsyncProtocolStream,
{
impl EventHandler {
/// Creates a new event handler
pub fn new() -> Self {
Self {
@ -61,7 +44,7 @@ where
pub fn on<F: 'static>(&mut self, name: &str, callback: F)
where
F: for<'a> Fn(
&'a Context<P>,
&'a Context,
Event,
) -> Pin<Box<(dyn Future<Output = Result<()>> + Send + 'a)>>
+ Send
@ -72,7 +55,7 @@ where
/// Handles a received event
#[tracing::instrument(level = "debug", skip(self, ctx, event))]
pub async fn handle_event(&self, ctx: &Context<P>, event: Event) -> Result<()> {
pub async fn handle_event(&self, ctx: &Context, event: Event) -> Result<()> {
if let Some(cb) = self.callbacks.get(event.name()) {
cb.as_ref()(ctx, event).await?;
}

@ -1,7 +1,5 @@
use crate::prelude::IPCResult;
use byteorder::{BigEndian, ReadBytesExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
/// Trait to convert event data into sending bytes
@ -10,33 +8,12 @@ pub trait EventSendPayload {
fn to_payload_bytes(self) -> IPCResult<Vec<u8>>;
}
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = rmp_serde::to_vec(&self)?;
Ok(bytes)
}
}
/// Trait to get the event data from receiving bytes.
/// It is implemented for all types that are DeserializeOwned
pub trait EventReceivePayload: Sized {
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self>;
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
let type_data = rmp_serde::from_read(reader)?;
Ok(type_data)
}
}
/// A payload wrapper type for sending bytes directly without
/// serializing them
#[derive(Clone)]
@ -132,3 +109,36 @@ where
})
}
}
#[cfg(feature = "messagepack")]
mod rmp_impl {
use super::{EventReceivePayload, EventSendPayload};
use crate::prelude::IPCResult;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = rmp_serde::to_vec(&self)?;
Ok(bytes)
}
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
let type_data = rmp_serde::from_read(reader)?;
Ok(type_data)
}
}
}
#[cfg(feature = "messagepack")]
pub use rmp_impl::*;

@ -50,10 +50,11 @@ use typemap_rev::{TypeMap, TypeMapKey};
/// .build_server().await.unwrap();
///# }
/// ```
///
pub struct IPCBuilder<L: AsyncStreamProtocolListener> {
handler: EventHandler<L::Stream>,
handler: EventHandler,
address: Option<L::AddressType>,
namespaces: HashMap<String, Namespace<L::Stream>>,
namespaces: HashMap<String, Namespace>,
data: TypeMap,
timeout: Duration,
}
@ -93,7 +94,7 @@ where
pub fn on<F: 'static>(mut self, event: &str, callback: F) -> Self
where
F: for<'a> Fn(
&'a Context<L::Stream>,
&'a Context,
Event,
) -> Pin<Box<(dyn Future<Output = Result<()>> + Send + 'a)>>
+ Send
@ -117,7 +118,7 @@ where
}
/// Adds a namespace to the ipc server
pub fn add_namespace(mut self, namespace: Namespace<L::Stream>) -> Self {
pub fn add_namespace(mut self, namespace: Namespace) -> Self {
self.namespaces
.insert(namespace.name().to_owned(), namespace);
@ -135,20 +136,20 @@ where
#[tracing::instrument(skip(self))]
pub async fn build_server(self) -> Result<()> {
self.validate()?;
let server = IPCServer::<L> {
let server = IPCServer {
namespaces: self.namespaces,
handler: self.handler,
data: self.data,
timeout: self.timeout,
};
server.start(self.address.unwrap()).await?;
server.start::<L>(self.address.unwrap()).await?;
Ok(())
}
/// Builds an ipc client
#[tracing::instrument(skip(self))]
pub async fn build_client(self) -> Result<Context<L::Stream>> {
pub async fn build_client(self) -> Result<Context> {
self.validate()?;
let data = Arc::new(RwLock::new(self.data));
let reply_listeners = ReplyListeners::default();
@ -160,7 +161,7 @@ where
timeout: self.timeout,
};
let ctx = client.connect(self.address.unwrap()).await?;
let ctx = client.connect::<L::Stream>(self.address.unwrap()).await?;
Ok(ctx)
}
@ -170,7 +171,7 @@ where
/// return a [crate::context::PooledContext] that allows one to [crate::context::PooledContext::acquire] a single context
/// to emit events.
#[tracing::instrument(skip(self))]
pub async fn build_pooled_client(self, pool_size: usize) -> Result<PooledContext<L::Stream>> {
pub async fn build_pooled_client(self, pool_size: usize) -> Result<PooledContext> {
if pool_size == 0 {
Error::BuildError("Pool size must be greater than 0".to_string());
}
@ -189,7 +190,7 @@ where
timeout: self.timeout.clone(),
};
let ctx = client.connect(address.clone()).await?;
let ctx = client.connect::<L::Stream>(address.clone()).await?;
contexts.push(ctx);
}

@ -16,25 +16,25 @@ use typemap_rev::TypeMap;
/// Use the [IPCBuilder](crate::builder::IPCBuilder) to create the client.
/// Usually one does not need to use the IPCClient object directly.
#[derive(Clone)]
pub struct IPCClient<S: AsyncProtocolStream> {
pub(crate) handler: EventHandler<S>,
pub(crate) namespaces: HashMap<String, Namespace<S>>,
pub struct IPCClient {
pub(crate) handler: EventHandler,
pub(crate) namespaces: HashMap<String, Namespace>,
pub(crate) data: Arc<RwLock<TypeMap>>,
pub(crate) reply_listeners: ReplyListeners,
pub(crate) timeout: Duration,
}
impl<S> IPCClient<S>
where
S: 'static + AsyncProtocolStream,
{
impl IPCClient {
/// Connects to a given address and returns an emitter for events to that address.
/// Invoked by [IPCBuilder::build_client](crate::builder::IPCBuilder::build_client)
#[tracing::instrument(skip(self))]
pub async fn connect(self, address: S::AddressType) -> Result<Context<S>> {
pub async fn connect<S: AsyncProtocolStream + 'static>(
self,
address: S::AddressType,
) -> Result<Context> {
let stream = S::protocol_connect(address).await?;
let (read_half, write_half) = stream.protocol_into_split();
let emitter = StreamEmitter::new(write_half);
let emitter = StreamEmitter::new::<S>(write_half);
let (tx, rx) = oneshot::channel();
let ctx = Context::new(
StreamEmitter::clone(&emitter),
@ -49,7 +49,7 @@ where
let handle = tokio::spawn({
let ctx = Context::clone(&ctx);
async move {
handle_connection(namespaces, handler, read_half, ctx).await;
handle_connection::<S>(namespaces, handler, read_half, ctx).await;
}
});
tokio::spawn(async move {

@ -1,7 +1,6 @@
use crate::error::{Error, Result};
use crate::event::Event;
use crate::ipc::stream_emitter::StreamEmitter;
use crate::protocol::AsyncProtocolStream;
use futures::future;
use futures::future::Either;
use std::collections::HashMap;
@ -21,16 +20,17 @@ pub(crate) type ReplyListeners = Arc<Mutex<HashMap<u64, oneshot::Sender<Event>>>
/// ```rust
/// use bromine::prelude::*;
///
/// async fn my_callback<S: AsyncProtocolStream>(ctx: &Context<S>, _event: Event) -> IPCResult<()> {
/// async fn my_callback(ctx: &Context, _event: Event) -> IPCResult<()> {
/// // use the emitter on the context object to emit events
/// // inside callbacks
/// ctx.emitter.emit("ping", ()).await?;
/// Ok(())
/// }
/// ```
pub struct Context<S: AsyncProtocolStream> {
#[derive(Clone)]
pub struct Context {
/// The event emitter
pub emitter: StreamEmitter<S>,
pub emitter: StreamEmitter,
/// Field to store additional context data
pub data: Arc<RwLock<TypeMap>>,
@ -42,27 +42,9 @@ pub struct Context<S: AsyncProtocolStream> {
reply_timeout: Duration,
}
impl<S> Clone for Context<S>
where
S: AsyncProtocolStream,
{
fn clone(&self) -> Self {
Self {
emitter: self.emitter.clone(),
data: Arc::clone(&self.data),
stop_sender: Arc::clone(&self.stop_sender),
reply_listeners: Arc::clone(&self.reply_listeners),
reply_timeout: self.reply_timeout.clone(),
}
}
}
impl<P> Context<P>
where
P: AsyncProtocolStream,
{
impl Context {
pub(crate) fn new(
emitter: StreamEmitter<P>,
emitter: StreamEmitter,
data: Arc<RwLock<TypeMap>>,
stop_sender: Option<Sender<()>>,
reply_listeners: ReplyListeners,
@ -121,19 +103,8 @@ where
}
}
pub struct PooledContext<S: AsyncProtocolStream> {
contexts: Vec<PoolGuard<Context<S>>>,
}
impl<S> Clone for PooledContext<S>
where
S: AsyncProtocolStream,
{
fn clone(&self) -> Self {
Self {
contexts: self.contexts.clone(),
}
}
pub struct PooledContext {
contexts: Vec<PoolGuard<Context>>,
}
pub struct PoolGuard<T>
@ -217,12 +188,9 @@ where
}
}
impl<P> PooledContext<P>
where
P: AsyncProtocolStream,
{
impl PooledContext {
/// Creates a new pooled context from a list of contexts
pub(crate) fn new(contexts: Vec<Context<P>>) -> Self {
pub(crate) fn new(contexts: Vec<Context>) -> Self {
Self {
contexts: contexts.into_iter().map(PoolGuard::new).collect(),
}
@ -231,7 +199,7 @@ where
/// Acquires a context from the pool
/// It always chooses the one that is used the least
#[tracing::instrument(level = "trace", skip_all)]
pub fn acquire(&self) -> PoolGuard<Context<P>> {
pub fn acquire(&self) -> PoolGuard<Context> {
self.contexts
.iter()
.min_by_key(|c| c.count())

@ -14,10 +14,10 @@ pub mod stream_emitter;
/// Handles listening to a connection and triggering the corresponding event functions
async fn handle_connection<S: 'static + AsyncProtocolStream>(
namespaces: Arc<HashMap<String, Namespace<S>>>,
handler: Arc<EventHandler<S>>,
namespaces: Arc<HashMap<String, Namespace>>,
handler: Arc<EventHandler>,
mut read_half: S::OwnedSplitReadHalf,
ctx: Context<S>,
ctx: Context,
) {
while let Ok(event) = Event::from_async_read(&mut read_half).await {
tracing::trace!(
@ -52,11 +52,7 @@ async fn handle_connection<S: 'static + AsyncProtocolStream>(
}
/// Handles a single event in a different tokio context
fn handle_event<S: 'static + AsyncProtocolStream>(
ctx: Context<S>,
handler: Arc<EventHandler<S>>,
event: Event,
) {
fn handle_event(ctx: Context, handler: Arc<EventHandler>, event: Event) {
tokio::spawn(async move {
let id = event.id();
if let Err(e) = handler.handle_event(&ctx, event).await {

@ -14,21 +14,21 @@ use typemap_rev::TypeMap;
/// The IPC Server listening for connections.
/// Use the [IPCBuilder](crate::builder::IPCBuilder) to create a server.
/// Usually one does not need to use the IPCServer object directly.
pub struct IPCServer<L: AsyncStreamProtocolListener> {
pub(crate) handler: EventHandler<L::Stream>,
pub(crate) namespaces: HashMap<String, Namespace<L::Stream>>,
pub struct IPCServer {
pub(crate) handler: EventHandler,
pub(crate) namespaces: HashMap<String, Namespace>,
pub(crate) data: TypeMap,
pub(crate) timeout: Duration,
}
impl<L> IPCServer<L>
where
L: AsyncStreamProtocolListener,
{
impl IPCServer {
/// Starts the IPC Server.
/// Invoked by [IPCBuilder::build_server](crate::builder::IPCBuilder::build_server)
#[tracing::instrument(skip(self))]
pub async fn start(self, address: L::AddressType) -> Result<()> {
pub async fn start<L: AsyncStreamProtocolListener>(
self,
address: L::AddressType,
) -> Result<()> {
let listener = L::protocol_bind(address.clone()).await?;
let handler = Arc::new(self.handler);
let namespaces = Arc::new(self.namespaces);
@ -44,7 +44,7 @@ where
tokio::spawn(async move {
let (read_half, write_half) = stream.protocol_into_split();
let emitter = StreamEmitter::new(write_half);
let emitter = StreamEmitter::new::<L::Stream>(write_half);
let reply_listeners = ReplyListeners::default();
let ctx = Context::new(
StreamEmitter::clone(&emitter),
@ -54,7 +54,7 @@ where
timeout.into(),
);
handle_connection(namespaces, handler, read_half, ctx).await;
handle_connection::<L::Stream>(namespaces, handler, read_half, ctx).await;
});
}

@ -4,21 +4,19 @@ use crate::events::event::Event;
use crate::events::payload::EventSendPayload;
use crate::ipc::context::Context;
use crate::protocol::AsyncProtocolStream;
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::Mutex;
/// An abstraction over the raw tokio tcp stream
/// An abstraction over any type that implements the AsyncProtocolStream trait
/// to emit events and share a connection across multiple
/// contexts.
pub struct StreamEmitter<S: AsyncProtocolStream> {
stream: Arc<Mutex<S::OwnedSplitWriteHalf>>,
pub struct StreamEmitter {
stream: Arc<Mutex<dyn AsyncWrite + Send + Sync + Unpin + 'static>>,
}
impl<S> Clone for StreamEmitter<S>
where
S: AsyncProtocolStream,
{
impl Clone for StreamEmitter {
fn clone(&self) -> Self {
Self {
stream: Arc::clone(&self.stream),
@ -26,11 +24,8 @@ where
}
}
impl<P> StreamEmitter<P>
where
P: AsyncProtocolStream,
{
pub fn new(stream: P::OwnedSplitWriteHalf) -> Self {
impl StreamEmitter {
pub fn new<P: AsyncProtocolStream + 'static>(stream: P::OwnedSplitWriteHalf) -> Self {
Self {
stream: Arc::new(Mutex::new(stream)),
}
@ -57,7 +52,7 @@ where
let event_bytes = event.into_bytes()?;
{
let mut stream = self.stream.lock().await;
(*stream).write_all(&event_bytes[..]).await?;
stream.deref_mut().write_all(&event_bytes[..]).await?;
tracing::trace!(bytes_len = event_bytes.len());
}
@ -130,7 +125,7 @@ impl EmitMetadata {
/// Waits for a reply to the given message.
#[tracing::instrument(skip(self, ctx), fields(self.message_id))]
pub async fn await_reply<P: AsyncProtocolStream>(&self, ctx: &Context<P>) -> Result<Event> {
pub async fn await_reply(&self, ctx: &Context) -> Result<Event> {
let reply = ctx.await_reply(self.message_id).await?;
if reply.name() == ERROR_EVENT_NAME {
Err(reply.data::<ErrorEventData>()?.into())

@ -6,7 +6,7 @@
//! use tokio::net::TcpListener;
//!
//! /// Callback ping function
//! async fn handle_ping<S: AsyncProtocolStream>(ctx: &Context<S>, event: Event) -> IPCResult<()> {
//! async fn handle_ping(ctx: &Context, event: Event) -> IPCResult<()> {
//! println!("Received ping event.");
//! ctx.emitter.emit_response(event.id(), "pong", ()).await?;
//!
@ -16,7 +16,7 @@
//! pub struct MyNamespace;
//!
//! impl MyNamespace {
//! async fn ping<S: AsyncProtocolStream>(_ctx: &Context<S>, _event: Event) -> IPCResult<()> {
//! async fn ping(_ctx: &Context, _event: Event) -> IPCResult<()> {
//! println!("My namespace received a ping");
//! Ok(())
//! }
@ -25,7 +25,7 @@
//! impl NamespaceProvider for MyNamespace {
//! fn name() -> &'static str {"my_namespace"}
//!
//! fn register<S: AsyncProtocolStream>(handler: &mut EventHandler<S>) {
//! fn register(handler: &mut EventHandler) {
//! events!(handler,
//! "ping" => Self::ping,
//! "ping2" => Self::ping

@ -10,7 +10,7 @@ use std::pin::Pin;
pub struct NamespaceBuilder<L: AsyncStreamProtocolListener> {
name: String,
handler: EventHandler<L::Stream>,
handler: EventHandler,
ipc_builder: IPCBuilder<L>,
}
@ -30,7 +30,7 @@ where
pub fn on<F: 'static>(mut self, event: &str, callback: F) -> Self
where
F: for<'a> Fn(
&'a Context<L::Stream>,
&'a Context,
Event,
) -> Pin<Box<(dyn Future<Output = Result<()>> + Send + 'a)>>
+ Send

@ -1,31 +1,15 @@
use crate::events::event_handler::EventHandler;
use crate::protocol::AsyncProtocolStream;
use std::sync::Arc;
#[derive(Debug)]
pub struct Namespace<S: AsyncProtocolStream> {
#[derive(Clone, Debug)]
pub struct Namespace {
name: String,
pub(crate) handler: Arc<EventHandler<S>>,
pub(crate) handler: Arc<EventHandler>,
}
impl<S> Clone for Namespace<S>
where
S: AsyncProtocolStream,
{
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
handler: Arc::clone(&self.handler),
}
}
}
impl<S> Namespace<S>
where
S: AsyncProtocolStream,
{
impl Namespace {
/// Creates a new namespace with an event handler to register event callbacks on
pub fn new<S2: ToString>(name: S2, handler: EventHandler<S>) -> Self {
pub fn new<S2: ToString>(name: S2, handler: EventHandler) -> Self {
Self {
name: name.to_string(),
handler: Arc::new(handler),

@ -1,16 +1,12 @@
use crate::events::event_handler::EventHandler;
use crate::namespace::Namespace;
use crate::protocol::AsyncProtocolStream;
pub trait NamespaceProvider {
fn name() -> &'static str;
fn register<S: AsyncProtocolStream>(handler: &mut EventHandler<S>);
fn register(handler: &mut EventHandler);
}
impl<S> Namespace<S>
where
S: AsyncProtocolStream,
{
impl Namespace {
pub fn from_provider<N: NamespaceProvider>() -> Self {
let name = N::name();
let mut handler = EventHandler::new();

@ -1,6 +1,5 @@
use super::utils::PingEventData;
use crate::prelude::*;
use crate::protocol::AsyncProtocolStream;
use crate::tests::utils::start_test_server;
use std::net::ToSocketAddrs;
use std::path::PathBuf;
@ -10,7 +9,7 @@ use std::time::{Duration, SystemTime};
use tokio::net::TcpListener;
use typemap_rev::TypeMapKey;
async fn handle_ping_event<P: AsyncProtocolStream>(ctx: &Context<P>, e: Event) -> IPCResult<()> {
async fn handle_ping_event(ctx: &Context, e: Event) -> IPCResult<()> {
tokio::time::sleep(Duration::from_secs(1)).await;
let mut ping_data = e.data::<PingEventData>()?;
ping_data.time = SystemTime::now();
@ -91,7 +90,7 @@ fn get_builder_with_ping_namespace(address: &str) -> IPCBuilder<TcpListener> {
pub struct TestNamespace;
impl TestNamespace {
async fn ping<P: AsyncProtocolStream>(_c: &Context<P>, _e: Event) -> IPCResult<()> {
async fn ping(_c: &Context, _e: Event) -> IPCResult<()> {
println!("Ping received");
Ok(())
}
@ -102,7 +101,7 @@ impl NamespaceProvider for TestNamespace {
"Test"
}
fn register<S: AsyncProtocolStream>(handler: &mut EventHandler<S>) {
fn register(handler: &mut EventHandler) {
events!(handler,
"ping" => Self::ping,
"ping2" => Self::ping

@ -0,0 +1,151 @@
use async_trait::async_trait;
use bromine::error::Result;
use bromine::prelude::{AsyncProtocolStreamSplit, IPCError};
use bromine::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::io::Error;
use std::pin::Pin;
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::Arc;
use std::sync::Mutex;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::{
channel as async_channel, Receiver as AsyncReceiver, Sender as AsyncSender,
};
use tokio::sync::Mutex as AsyncMutex;
lazy_static! {
static ref LISTENERS_REF: Arc<AsyncMutex<HashMap<u8, AsyncSender<TestProtocolStream>>>> =
Arc::new(AsyncMutex::new(HashMap::new()));
}
async fn add_port(number: u8, sender: tokio::sync::mpsc::Sender<TestProtocolStream>) {
let mut listeners = LISTENERS_REF.lock().await;
listeners.insert(number, sender);
}
async fn get_port(number: u8) -> Option<TestProtocolStream> {
let mut listeners = LISTENERS_REF.lock().await;
if let Some(sender) = listeners.get_mut(&number) {
let (s1, r1) = mpsc::channel();
let (s2, r2) = mpsc::channel();
let stream_1 = TestProtocolStream {
sender: Arc::new(Mutex::new(s1)),
receiver: Arc::new(Mutex::new(r2)),
};
let stream_2 = TestProtocolStream {
sender: Arc::new(Mutex::new(s2)),
receiver: Arc::new(Mutex::new(r1)),
};
sender.send(stream_2).await.ok();
Some(stream_1)
} else {
None
}
}
pub struct TestProtocolListener {
receiver: Arc<AsyncMutex<AsyncReceiver<TestProtocolStream>>>,
}
#[async_trait]
impl AsyncStreamProtocolListener for TestProtocolListener {
type AddressType = u8;
type RemoteAddressType = u8;
type Stream = TestProtocolStream;
async fn protocol_bind(address: Self::AddressType) -> Result<Self> {
let (sender, receiver) = async_channel(1);
add_port(address, sender).await;
Ok(Self {
receiver: Arc::new(AsyncMutex::new(receiver)),
})
}
async fn protocol_accept(&self) -> Result<(Self::Stream, Self::RemoteAddressType)> {
self.receiver
.lock()
.await
.recv()
.await
.map(|r| (r, 0u8))
.ok_or_else(|| IPCError::from("Failed to accept"))
}
}
#[derive(Clone)]
pub struct TestProtocolStream {
sender: Arc<Mutex<Sender<Vec<u8>>>>,
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
}
impl AsyncProtocolStreamSplit for TestProtocolStream {
type OwnedSplitReadHalf = Self;
type OwnedSplitWriteHalf = Self;
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
(self.clone(), self)
}
}
#[async_trait]
impl AsyncProtocolStream for TestProtocolStream {
type AddressType = u8;
async fn protocol_connect(address: Self::AddressType) -> Result<Self> {
get_port(address)
.await
.ok_or_else(|| IPCError::from("Failed to connect"))
}
}
impl AsyncRead for TestProtocolStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let receiver = self.receiver.lock().unwrap();
if let Ok(b) = receiver.recv() {
buf.put_slice(&b);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
impl AsyncWrite for TestProtocolStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::prelude::rust_2015::Result<usize, Error>> {
let sender = self.sender.lock().unwrap();
let vec_buf = buf.to_vec();
let buf_len = vec_buf.len();
sender.send(vec_buf).unwrap();
Poll::Ready(Ok(buf_len))
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
Poll::Ready(Ok(()))
}
}

@ -0,0 +1,45 @@
mod test_protocol;
use bromine::prelude::*;
use std::time::Duration;
use test_protocol::*;
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
ctx.emitter.emit_response(event.id(), "pong", ()).await?;
Ok(())
}
async fn handle_pong_event(_ctx: &Context, _event: Event) -> IPCResult<()> {
Ok(())
}
fn get_builder(port: u8) -> IPCBuilder<TestProtocolListener> {
IPCBuilder::new()
.address(port)
.on(
"ping",
callback!(
ctx,
event,
async move { handle_ping_event(ctx, event).await }
),
)
.timeout(Duration::from_millis(100))
.on(
"pong",
callback!(
ctx,
event,
async move { handle_pong_event(ctx, event).await }
),
)
}
#[tokio::test]
async fn it_passes_events() {
tokio::task::spawn(async { get_builder(0).build_server().await.unwrap() });
tokio::time::sleep(Duration::from_millis(100)).await;
let ctx = get_builder(0).build_client().await.unwrap();
ctx.emitter.emit("ping", ()).await.unwrap(); // todo fix reply deadlock
}
Loading…
Cancel
Save