Merge pull request #25 from Trivernis/develop

More serialization formats and change in feature names
pull/32/head
Julius Riegel 2 years ago committed by GitHub
commit ddf6e03ba2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,7 +31,7 @@ jobs:
id: extract_branch
- name: Run benchmark
run: cargo bench -- --save-baseline ${{steps.extract_branch.outputs.branch}}
run: cargo bench --features serialize_rmp -- --save-baseline ${{steps.extract_branch.outputs.branch}}
- name: upload artifact
uses: actions/upload-artifact@v2

@ -33,4 +33,16 @@ jobs:
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
run: cargo test --verbose --tests
- name: Run rmp serialization tests
run: cargo test --verbose --all --features serialize_rmp
- name: Run bincode serialization tests
run: cargo test --verbose --all --features serialize_bincode
- name: Run postcard serialization tests
run: cargo test --verbose --all --features serialize_postcard
- name: Run json serialization tests
run: cargo test --verbose --all --features serialize_json

@ -23,4 +23,4 @@ jobs:
run: cargo login "$CRATES_IO_TOKEN"
- name: Publish to crates.io
run: cargo publish
run: cargo publish --all-features

245
Cargo.lock generated

@ -2,6 +2,15 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "aho-corasick"
version = "0.7.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
dependencies = [
"memchr",
]
[[package]]
name = "async-trait"
version = "0.1.51"
@ -13,6 +22,16 @@ dependencies = [
"syn",
]
[[package]]
name = "atomic-polyfill"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e686d748538a32325b28d6411dd8a939e7ad5128e5d0023cc4fd3573db456042"
dependencies = [
"critical-section",
"riscv-target",
]
[[package]]
name = "atty"
version = "0.2.14"
@ -30,6 +49,42 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "bare-metal"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5deb64efa5bd81e31fcd1938615a6d98c82eafcbcd787162b6f63b91d6bac5b3"
dependencies = [
"rustc_version 0.2.3",
]
[[package]]
name = "bare-metal"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8fe8f5a8a398345e52358e18ff07cc17a568fbca5c6f73873d3a62056309603"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]]
name = "bit_field"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcb6dd1c2376d2e096796e234a70e17e94cc2d5d54ff8ce42b28cef1d0d359a4"
[[package]]
name = "bitfield"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46afbd2983a5d5a7bd740ccb198caf5b82f45c40c09c0eed36052d91cb92e719"
[[package]]
name = "bitflags"
version = "1.3.2"
@ -38,15 +93,19 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bromine"
version = "0.13.0"
version = "0.14.0"
dependencies = [
"async-trait",
"bincode",
"byteorder",
"criterion",
"crossbeam-utils",
"futures",
"lazy_static",
"postcard",
"rmp-serde",
"serde",
"serde_json",
"thiserror",
"tokio",
"tracing",
@ -89,7 +148,7 @@ version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c24dab4283a142afa2fdca129b80ad2c6284e073930f964c3a1293c225ee39a"
dependencies = [
"rustc_version",
"rustc_version 0.4.0",
]
[[package]]
@ -109,6 +168,18 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "cortex-m"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ac919ef424449ec8c08d515590ce15d9262c0ca5f0da5b0c901e971a3b783b3"
dependencies = [
"bare-metal 0.2.5",
"bitfield",
"embedded-hal",
"volatile-register",
]
[[package]]
name = "criterion"
version = "0.3.5"
@ -147,6 +218,18 @@ dependencies = [
"itertools",
]
[[package]]
name = "critical-section"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01e191a5a6f6edad9b679777ef6b6c0f2bdd4a333f2ecb8f61c3e28109a03d70"
dependencies = [
"bare-metal 1.0.0",
"cfg-if",
"cortex-m",
"riscv",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.1"
@ -219,6 +302,16 @@ version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
[[package]]
name = "embedded-hal"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e36cfb62ff156596c892272f3015ef952fe1525e85261fa3a7f327bd6b384ab9"
dependencies = [
"nb 0.1.3",
"void",
]
[[package]]
name = "futures"
version = "0.3.18"
@ -314,6 +407,28 @@ version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]]
name = "heapless"
version = "0.7.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c1ad878e07405df82b695089e63d278244344f80e764074d0bdfe99b89460f3"
dependencies = [
"atomic-polyfill",
"hash32",
"serde",
"spin",
"stable_deref_trait",
]
[[package]]
name = "hermit-abi"
version = "0.1.19"
@ -359,6 +474,15 @@ version = "0.2.108"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8521a1b57e76b1ec69af7599e75e38e7b7fad6610f037db8c79b127201b5d119"
[[package]]
name = "lock_api"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109"
dependencies = [
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.14"
@ -405,6 +529,21 @@ dependencies = [
"winapi",
]
[[package]]
name = "nb"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "801d31da0513b6ec5214e9bf433a77966320625a37860f910be265be6e18d06f"
dependencies = [
"nb 1.0.0",
]
[[package]]
name = "nb"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "546c37ac5d9e56f55e73b677106873d9d9f5190605e41a856503623648488cae"
[[package]]
name = "ntapi"
version = "0.3.6"
@ -479,6 +618,23 @@ dependencies = [
"plotters-backend",
]
[[package]]
name = "postcard"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8863e251332eb18520388099b8b0acc4810ed6e602e3b6f674e8a46ba20e15c"
dependencies = [
"heapless",
"postcard-cobs",
"serde",
]
[[package]]
name = "postcard-cobs"
version = "0.1.5-pre"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c68cb38ed13fd7bc9dd5db8f165b7c8d9c1a315104083a2b10f11354c2af97f"
[[package]]
name = "proc-macro2"
version = "1.0.32"
@ -528,6 +684,8 @@ version = "1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
@ -543,6 +701,27 @@ version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "riscv"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6907ccdd7a31012b70faf2af85cd9e5ba97657cc3987c4f13f8e4d2c2a088aba"
dependencies = [
"bare-metal 1.0.0",
"bit_field",
"riscv-target",
]
[[package]]
name = "riscv-target"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88aa938cda42a0cf62a20cfe8d139ff1af20c2e681212b5b34adb5a58333f222"
dependencies = [
"lazy_static",
"regex",
]
[[package]]
name = "rmp"
version = "0.8.10"
@ -564,13 +743,22 @@ dependencies = [
"serde",
]
[[package]]
name = "rustc_version"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
dependencies = [
"semver 0.9.0",
]
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
"semver 1.0.4",
]
[[package]]
@ -594,12 +782,27 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "semver"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
dependencies = [
"semver-parser",
]
[[package]]
name = "semver"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "568a8e6258aa33c13358f81fd834adb854c6f7c9468520910a9b1e8fac068012"
[[package]]
name = "semver-parser"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
[[package]]
name = "serde"
version = "1.0.130"
@ -647,6 +850,21 @@ version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5"
[[package]]
name = "spin"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "511254be0c5bcf062b019a6c89c01a664aa359ded62f78aa72c6fc137c0590e5"
dependencies = [
"lock_api",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "syn"
version = "1.0.82"
@ -775,6 +993,27 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "vcell"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77439c1b53d2303b20d9459b1ade71a83c716e3f9c34f3228c00e6f185d6c002"
[[package]]
name = "void"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
[[package]]
name = "volatile-register"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ee8f19f9d74293faf70901bc20ad067dc1ad390d2cbf1e3f75f721ffee908b6"
dependencies = [
"vcell",
]
[[package]]
name = "walkdir"
version = "2.3.2"

@ -1,6 +1,6 @@
[package]
name = "bromine"
version = "0.13.0"
version = "0.14.0"
authors = ["trivernis <trivernis@protonmail.com>"]
edition = "2018"
readme = "README.md"
@ -28,6 +28,8 @@ byteorder = "1.4.3"
async-trait = "0.1.51"
futures = "0.3.17"
rmp-serde = {version = "0.15.5", optional = true}
bincode = {version = "1.3.3", optional = true}
serde_json = {version = "1.0.72", optional = true}
[dependencies.serde]
optional = true
@ -38,8 +40,14 @@ features = []
version = "1.12.0"
features = ["net", "io-std", "io-util", "sync", "time"]
[dependencies.postcard]
version = "0.7.2"
optional = true
features = ["alloc"]
[dev-dependencies]
rmp-serde = "0.15.4"
crossbeam-utils = "0.8.5"
[dev-dependencies.serde]
version = "1.0.130"
@ -54,5 +62,9 @@ version = "1.12.0"
features = ["macros", "rt-multi-thread"]
[features]
default = ["messagepack"]
messagepack = ["serde", "rmp-serde"]
default = []
serialize = ["serde"]
serialize_rmp = ["serialize", "rmp-serde"]
serialize_bincode = ["serialize", "bincode"]
serialize_postcard = ["serialize", "postcard"]
serialize_json = ["serialize", "serde_json"]

@ -9,33 +9,29 @@ pub enum Error {
#[error(transparent)]
IoError(#[from] tokio::io::Error),
#[cfg(feature = "messagepack")]
#[error(transparent)]
Decode(#[from] rmp_serde::decode::Error),
#[cfg(feature = "messagepack")]
#[error(transparent)]
Encode(#[from] rmp_serde::encode::Error),
#[cfg(feature = "serialize")]
#[error("failed to serialize event: {0}")]
Serialization(#[from] crate::payload::SerializationError),
#[error("Build Error: {0}")]
#[error("build Error: {0}")]
BuildError(String),
#[error("{0}")]
Message(String),
#[error("Channel Error: {0}")]
#[error("channel Error: {0}")]
ReceiveError(#[from] oneshot::error::RecvError),
#[error("The received event was corrupted")]
#[error("the received event was corrupted")]
CorruptedEvent,
#[error("Send Error")]
#[error("send Error")]
SendError,
#[error("Error response: {0}")]
#[error("received error response: {0}")]
ErrorEvent(#[from] ErrorEventData),
#[error("Timed out")]
#[error("timed out")]
Timeout,
}

@ -7,6 +7,10 @@ pub mod event;
pub mod event_handler;
pub mod payload;
#[cfg(feature = "serialize")]
pub mod payload_serializer;
/// Generates a new event id
pub(crate) fn generate_event_id() -> u64 {
lazy_static::lazy_static! {

@ -2,6 +2,9 @@ use crate::prelude::IPCResult;
use byteorder::{BigEndian, ReadBytesExt};
use std::io::Read;
#[cfg(feature = "serialize")]
pub use super::payload_serializer::*;
/// Trait to convert event data into sending bytes
/// It is implemented for all types that implement Serialize
pub trait EventSendPayload {
@ -109,36 +112,3 @@ where
})
}
}
#[cfg(feature = "messagepack")]
mod rmp_impl {
use super::{EventReceivePayload, EventSendPayload};
use crate::prelude::IPCResult;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = rmp_serde::to_vec(&self)?;
Ok(bytes)
}
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
let type_data = rmp_serde::from_read(reader)?;
Ok(type_data)
}
}
}
#[cfg(feature = "messagepack")]
pub use rmp_impl::*;

@ -0,0 +1,23 @@
#[cfg(feature = "serialize_rmp")]
mod serialize_rmp;
#[cfg(feature = "serialize_rmp")]
pub use serialize_rmp::*;
#[cfg(feature = "serialize_bincode")]
mod serialize_bincode;
#[cfg(feature = "serialize_bincode")]
pub use serialize_bincode::*;
#[cfg(feature = "serialize_postcard")]
mod serialize_postcard;
#[cfg(feature = "serialize_postcard")]
pub use serialize_postcard::*;
#[cfg(feature = "serialize_json")]
mod serialize_json;
#[cfg(feature = "serialize_json")]
pub use serialize_json::*;

@ -0,0 +1,28 @@
use crate::payload::{EventReceivePayload, EventSendPayload};
use crate::prelude::IPCResult;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
pub type SerializationError = bincode::Error;
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = bincode::serialize(&self)?;
Ok(bytes)
}
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
let type_data = bincode::deserialize_from(reader)?;
Ok(type_data)
}
}

@ -0,0 +1,29 @@
use crate::payload::{EventReceivePayload, EventSendPayload};
use crate::prelude::IPCResult;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
pub type SerializationError = serde_json::Error;
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = serde_json::to_vec(&self)?;
Ok(bytes)
}
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
let type_data = serde_json::from_reader(reader)?;
Ok(type_data)
}
}

@ -0,0 +1,32 @@
use crate::payload::{EventReceivePayload, EventSendPayload};
use crate::prelude::IPCResult;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
pub type SerializationError = postcard::Error;
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = postcard::to_allocvec(&self)?.to_vec();
Ok(bytes)
}
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(mut reader: R) -> IPCResult<Self> {
let mut buf = Vec::new();
// reading to end means reading the full size of the provided data
reader.read_to_end(&mut buf)?;
let type_data = postcard::from_bytes(&buf)?;
Ok(type_data)
}
}

@ -0,0 +1,48 @@
use crate::payload::{EventReceivePayload, EventSendPayload};
use crate::prelude::{IPCError, IPCResult};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io::Read;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SerializationError {
#[error("failed to serialize with rmp: {0}")]
Serialize(#[from] rmp_serde::encode::Error),
#[error("failed to deserialize with rmp: {0}")]
Deserialize(#[from] rmp_serde::decode::Error),
}
impl From<rmp_serde::decode::Error> for IPCError {
fn from(e: rmp_serde::decode::Error) -> Self {
IPCError::Serialization(SerializationError::Deserialize(e))
}
}
impl From<rmp_serde::encode::Error> for IPCError {
fn from(e: rmp_serde::encode::Error) -> Self {
IPCError::Serialization(SerializationError::Serialize(e))
}
}
impl<T> EventSendPayload for T
where
T: Serialize,
{
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let bytes = rmp_serde::to_vec(&self)?;
Ok(bytes)
}
}
impl<T> EventReceivePayload for T
where
T: DeserializeOwned,
{
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
let type_data = rmp_serde::from_read(reader)?;
Ok(type_data)
}
}

@ -101,8 +101,52 @@
//! # }
//! ```
#[cfg(test)]
mod tests;
#[cfg(all(
feature = "serialize",
not(any(
feature = "serialize_bincode",
feature = "serialize_rmp",
feature = "serialize_postcard",
feature = "serialize_json"
))
))]
compile_error!("Feature 'serialize' cannot be used by its own. Choose one of 'serialize_rmp', 'serialize_bincode', 'serialize_postcard' instead.");
#[cfg(any(
all(
feature = "serialize_rmp",
any(
feature = "serialize_postcard",
feature = "serialize_bincode",
feature = "serialize_json"
)
),
all(
feature = "serialize_bincode",
any(
feature = "serialize_rmp",
feature = "serialize_postcard",
feature = "serialize_json"
)
),
all(
feature = "serialize_postcard",
any(
feature = "serialize_rmp",
feature = "serialize_bincode",
feature = "serialize_json"
)
),
all(
feature = "serialize_json",
any(
feature = "serialize_rmp",
feature = "serialize_bincode",
feature = "serialize_postcard"
)
)
))]
compile_error!("You cannot use two serialize_* features at the same time");
pub mod error;
mod events;

@ -1,12 +0,0 @@
use crate::events::generate_event_id;
use std::collections::HashSet;
#[test]
fn event_ids_work() {
let mut ids = HashSet::new();
// simple collision test
for _ in 0..100000 {
assert!(ids.insert(generate_event_id()))
}
}

@ -1,235 +0,0 @@
use super::utils::PingEventData;
use crate::prelude::*;
use crate::tests::utils::start_test_server;
use std::net::ToSocketAddrs;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::net::TcpListener;
use typemap_rev::TypeMapKey;
async fn handle_ping_event(ctx: &Context, e: Event) -> IPCResult<()> {
tokio::time::sleep(Duration::from_secs(1)).await;
let mut ping_data = e.data::<PingEventData>()?;
ping_data.time = SystemTime::now();
ping_data.ttl -= 1;
if ping_data.ttl > 0 {
ctx.emitter.emit_response(e.id(), "pong", ping_data).await?;
}
Ok(())
}
fn get_builder_with_ping<L: AsyncStreamProtocolListener>(address: L::AddressType) -> IPCBuilder<L> {
IPCBuilder::new()
.on("ping", |ctx, e| Box::pin(handle_ping_event(ctx, e)))
.timeout(Duration::from_secs(10))
.address(address)
}
#[tokio::test]
async fn it_receives_tcp_events() {
let socket_address = "127.0.0.1:8281".to_socket_addrs().unwrap().next().unwrap();
it_receives_events::<TcpListener>(socket_address).await;
}
#[cfg(unix)]
#[tokio::test]
async fn it_receives_unix_socket_events() {
let socket_path = PathBuf::from("/tmp/test_socket");
if socket_path.exists() {
std::fs::remove_file(&socket_path).unwrap();
}
it_receives_events::<tokio::net::UnixListener>(socket_path).await;
}
async fn it_receives_events<L: 'static + AsyncStreamProtocolListener>(address: L::AddressType) {
let builder = get_builder_with_ping::<L>(address.clone());
let server_running = Arc::new(AtomicBool::new(false));
tokio::spawn({
let server_running = Arc::clone(&server_running);
let builder = get_builder_with_ping::<L>(address);
async move {
server_running.store(true, Ordering::SeqCst);
builder.build_server().await.unwrap();
}
});
while !server_running.load(Ordering::Relaxed) {
tokio::time::sleep(Duration::from_millis(10)).await;
}
let pool = builder.build_pooled_client(8).await.unwrap();
let reply = pool
.acquire()
.emitter
.emit(
"ping",
PingEventData {
ttl: 16,
time: SystemTime::now(),
},
)
.await
.unwrap()
.await_reply(&pool.acquire())
.await
.unwrap();
assert_eq!(reply.name(), "pong");
}
fn get_builder_with_ping_namespace(address: &str) -> IPCBuilder<TcpListener> {
IPCBuilder::new()
.namespace("mainspace")
.on("ping", callback!(handle_ping_event))
.build()
.address(address.to_socket_addrs().unwrap().next().unwrap())
}
pub struct TestNamespace;
impl TestNamespace {
async fn ping(_c: &Context, _e: Event) -> IPCResult<()> {
println!("Ping received");
Ok(())
}
}
impl NamespaceProvider for TestNamespace {
fn name() -> &'static str {
"Test"
}
fn register(handler: &mut EventHandler) {
events!(handler,
"ping" => Self::ping,
"ping2" => Self::ping
);
}
}
#[tokio::test]
async fn it_receives_namespaced_events() {
let builder = get_builder_with_ping_namespace("127.0.0.1:8282");
let server_running = Arc::new(AtomicBool::new(false));
tokio::spawn({
let server_running = Arc::clone(&server_running);
let builder = get_builder_with_ping_namespace("127.0.0.1:8282");
async move {
server_running.store(true, Ordering::SeqCst);
builder.build_server().await.unwrap();
}
});
while !server_running.load(Ordering::Relaxed) {
tokio::time::sleep(Duration::from_millis(10)).await;
}
let ctx = builder
.add_namespace(namespace!(TestNamespace))
.build_client()
.await
.unwrap();
let reply = ctx
.emitter
.emit_to(
"mainspace",
"ping",
PingEventData {
ttl: 16,
time: SystemTime::now(),
},
)
.await
.unwrap()
.await_reply(&ctx)
.await
.unwrap();
assert_eq!(reply.name(), "pong");
}
struct ErrorOccurredKey;
impl TypeMapKey for ErrorOccurredKey {
type Value = Arc<AtomicBool>;
}
fn get_builder_with_error_handling(
error_occurred: Arc<AtomicBool>,
address: &str,
) -> IPCBuilder<TcpListener> {
IPCBuilder::new()
.insert::<ErrorOccurredKey>(error_occurred)
.on("ping", move |_, _| {
Box::pin(async move { Err(IPCError::from("ERRROROROROR")) })
})
.on(
"error",
callback!(ctx, event, async move {
let error = event.data::<error_event::ErrorEventData>()?;
assert!(error.message.len() > 0);
assert_eq!(error.code, 500);
{
let data = ctx.data.read().await;
let error_occurred = data.get::<ErrorOccurredKey>().unwrap();
error_occurred.store(true, Ordering::SeqCst);
}
Ok(())
}),
)
.address(address.to_socket_addrs().unwrap().next().unwrap())
}
#[tokio::test]
async fn it_handles_errors() {
let error_occurred = Arc::new(AtomicBool::new(false));
let builder = get_builder_with_error_handling(Arc::clone(&error_occurred), "127.0.0.1:8283");
let server_running = Arc::new(AtomicBool::new(false));
tokio::spawn({
let server_running = Arc::clone(&server_running);
let error_occurred = Arc::clone(&error_occurred);
let builder = get_builder_with_error_handling(error_occurred, "127.0.0.1:8283");
async move {
server_running.store(true, Ordering::SeqCst);
builder.build_server().await.unwrap();
}
});
while !server_running.load(Ordering::Relaxed) {
tokio::time::sleep(Duration::from_millis(10)).await;
}
let ctx = builder.build_client().await.unwrap();
ctx.emitter.emit("ping", ()).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
assert!(error_occurred.load(Ordering::SeqCst));
}
#[tokio::test]
async fn test_error_responses() {
static ADDRESS: &str = "127.0.0.1:8284";
start_test_server(ADDRESS).await.unwrap();
let ctx = IPCBuilder::<TcpListener>::new()
.address(ADDRESS.to_socket_addrs().unwrap().next().unwrap())
.build_client()
.await
.unwrap();
let reply = ctx
.emitter
.emit("ping", ())
.await
.unwrap()
.await_reply(&ctx)
.await
.unwrap();
assert_eq!(reply.name(), "pong");
let reply = ctx
.emitter
.emit("trigger_error", ())
.await
.unwrap()
.await_reply(&ctx)
.await;
assert!(reply.is_err());
}

@ -1,3 +0,0 @@
mod event_tests;
mod ipc_tests;
mod utils;

@ -1,37 +0,0 @@
use crate::error::Error;
use crate::IPCBuilder;
use serde::{Deserialize, Serialize};
use std::net::ToSocketAddrs;
use std::time::SystemTime;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct PingEventData {
pub time: SystemTime,
pub ttl: u8,
}
/// Starts a test IPC server
pub fn start_test_server(address: &'static str) -> oneshot::Receiver<bool> {
let (tx, rx) = oneshot::channel();
tokio::task::spawn(async move {
tx.send(true).unwrap();
IPCBuilder::<TcpListener>::new()
.address(address.to_socket_addrs().unwrap().next().unwrap())
.on("ping", |ctx, event| {
Box::pin(async move {
ctx.emitter.emit_response(event.id(), "pong", ()).await?;
Ok(())
})
})
.on("trigger_error", |_, _| {
Box::pin(async move { Err(Error::from("An error occurred.")) })
})
.build_server()
.await
.unwrap();
});
rx
}

@ -0,0 +1,141 @@
mod utils;
use crate::utils::start_server_and_client;
use bromine::prelude::*;
use payload_impl::SimplePayload;
use std::time::Duration;
use utils::call_counter::*;
use utils::get_free_port;
use utils::protocol::*;
#[tokio::test]
async fn it_sends_payloads() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
ctx.emitter
.emit(
"ping",
SimplePayload {
number: 0,
string: String::from("Hello World"),
},
)
.await
.unwrap();
// wait for the event to be handled
tokio::time::sleep(Duration::from_millis(10)).await;
let counters = get_counter_from_context(&ctx).await;
assert_eq!(counters.get("ping").await, 1);
assert_eq!(counters.get("pong").await, 1);
}
#[tokio::test]
async fn it_receives_payloads() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
let reply = ctx
.emitter
.emit(
"ping",
SimplePayload {
number: 0,
string: String::from("Hello World"),
},
)
.await
.unwrap()
.await_reply(&ctx)
.await
.unwrap();
let reply_payload = reply.data::<SimplePayload>().unwrap();
let counters = get_counter_from_context(&ctx).await;
assert_eq!(counters.get("ping").await, 1);
assert_eq!(reply_payload.string, String::from("Hello World"));
assert_eq!(reply_payload.number, 0);
}
async fn get_client_with_server(port: u8) -> Context {
start_server_and_client(move || get_builder(port)).await
}
fn get_builder(port: u8) -> IPCBuilder<TestProtocolListener> {
IPCBuilder::new()
.address(port)
.on("ping", callback!(handle_ping_event))
.on("pong", callback!(handle_pong_event))
.timeout(Duration::from_millis(10))
}
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
increment_counter_for_event(ctx, &event).await;
let payload = event.data::<SimplePayload>()?;
ctx.emitter
.emit_response(event.id(), "pong", payload)
.await?;
Ok(())
}
async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult<()> {
increment_counter_for_event(ctx, &event).await;
let _payload = event.data::<SimplePayload>()?;
Ok(())
}
#[cfg(feature = "serialize")]
mod payload_impl {
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct SimplePayload {
pub string: String,
pub number: u32,
}
}
#[cfg(not(feature = "serialize"))]
mod payload_impl {
use bromine::error::Result;
use bromine::payload::{EventReceivePayload, EventSendPayload};
use bromine::prelude::IPCResult;
use byteorder::{BigEndian, ReadBytesExt};
use std::io::Read;
pub struct SimplePayload {
pub string: String,
pub number: u32,
}
impl EventSendPayload for SimplePayload {
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
let mut buf = Vec::new();
let string_length = self.string.len() as u16;
let string_length_bytes = string_length.to_be_bytes();
buf.append(&mut string_length_bytes.to_vec());
let mut string_bytes = self.string.into_bytes();
buf.append(&mut string_bytes);
let num_bytes = self.number.to_be_bytes();
buf.append(&mut num_bytes.to_vec());
Ok(buf)
}
}
impl EventReceivePayload for SimplePayload {
fn from_payload_bytes<R: Read>(mut reader: R) -> Result<Self> {
let string_length = reader.read_u16::<BigEndian>()?;
let mut string_buf = vec![0u8; string_length as usize];
reader.read_exact(&mut string_buf)?;
let string = String::from_utf8(string_buf).unwrap();
let number = reader.read_u32::<BigEndian>()?;
Ok(Self { string, number })
}
}
}

@ -1,151 +0,0 @@
use async_trait::async_trait;
use bromine::error::Result;
use bromine::prelude::{AsyncProtocolStreamSplit, IPCError};
use bromine::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::io::Error;
use std::pin::Pin;
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::Arc;
use std::sync::Mutex;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::{
channel as async_channel, Receiver as AsyncReceiver, Sender as AsyncSender,
};
use tokio::sync::Mutex as AsyncMutex;
lazy_static! {
static ref LISTENERS_REF: Arc<AsyncMutex<HashMap<u8, AsyncSender<TestProtocolStream>>>> =
Arc::new(AsyncMutex::new(HashMap::new()));
}
async fn add_port(number: u8, sender: tokio::sync::mpsc::Sender<TestProtocolStream>) {
let mut listeners = LISTENERS_REF.lock().await;
listeners.insert(number, sender);
}
async fn get_port(number: u8) -> Option<TestProtocolStream> {
let mut listeners = LISTENERS_REF.lock().await;
if let Some(sender) = listeners.get_mut(&number) {
let (s1, r1) = mpsc::channel();
let (s2, r2) = mpsc::channel();
let stream_1 = TestProtocolStream {
sender: Arc::new(Mutex::new(s1)),
receiver: Arc::new(Mutex::new(r2)),
};
let stream_2 = TestProtocolStream {
sender: Arc::new(Mutex::new(s2)),
receiver: Arc::new(Mutex::new(r1)),
};
sender.send(stream_2).await.ok();
Some(stream_1)
} else {
None
}
}
pub struct TestProtocolListener {
receiver: Arc<AsyncMutex<AsyncReceiver<TestProtocolStream>>>,
}
#[async_trait]
impl AsyncStreamProtocolListener for TestProtocolListener {
type AddressType = u8;
type RemoteAddressType = u8;
type Stream = TestProtocolStream;
async fn protocol_bind(address: Self::AddressType) -> Result<Self> {
let (sender, receiver) = async_channel(1);
add_port(address, sender).await;
Ok(Self {
receiver: Arc::new(AsyncMutex::new(receiver)),
})
}
async fn protocol_accept(&self) -> Result<(Self::Stream, Self::RemoteAddressType)> {
self.receiver
.lock()
.await
.recv()
.await
.map(|r| (r, 0u8))
.ok_or_else(|| IPCError::from("Failed to accept"))
}
}
#[derive(Clone)]
pub struct TestProtocolStream {
sender: Arc<Mutex<Sender<Vec<u8>>>>,
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
}
impl AsyncProtocolStreamSplit for TestProtocolStream {
type OwnedSplitReadHalf = Self;
type OwnedSplitWriteHalf = Self;
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
(self.clone(), self)
}
}
#[async_trait]
impl AsyncProtocolStream for TestProtocolStream {
type AddressType = u8;
async fn protocol_connect(address: Self::AddressType) -> Result<Self> {
get_port(address)
.await
.ok_or_else(|| IPCError::from("Failed to connect"))
}
}
impl AsyncRead for TestProtocolStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let receiver = self.receiver.lock().unwrap();
if let Ok(b) = receiver.recv() {
buf.put_slice(&b);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
impl AsyncWrite for TestProtocolStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::prelude::rust_2015::Result<usize, Error>> {
let sender = self.sender.lock().unwrap();
let vec_buf = buf.to_vec();
let buf_len = vec_buf.len();
sender.send(vec_buf).unwrap();
Poll::Ready(Ok(buf_len))
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
Poll::Ready(Ok(()))
}
}

@ -1,45 +1,158 @@
mod test_protocol;
mod utils;
use crate::utils::start_server_and_client;
use bromine::prelude::*;
use std::time::Duration;
use test_protocol::*;
use utils::call_counter::*;
use utils::get_free_port;
use utils::protocol::*;
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
ctx.emitter.emit_response(event.id(), "pong", ()).await?;
/// Simple events are passed from the client to the server and responses
/// are emitted back to the client. Both will have received an event.
#[tokio::test]
async fn it_sends_events() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
ctx.emitter.emit("ping", EmptyPayload).await.unwrap();
Ok(())
// allow the event to be processed
tokio::time::sleep(Duration::from_millis(10)).await;
let counter = get_counter_from_context(&ctx).await;
assert_eq!(counter.get("ping").await, 1);
assert_eq!(counter.get("pong").await, 1);
}
async fn handle_pong_event(_ctx: &Context, _event: Event) -> IPCResult<()> {
Ok(())
/// Events sent to a specific namespace are handled by the namespace event handler
#[tokio::test]
async fn it_sends_namespaced_events() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
ctx.emitter
.emit_to("test", "ping", EmptyPayload)
.await
.unwrap();
ctx.emitter
.emit_to("test", "pong", EmptyPayload)
.await
.unwrap();
// allow the event to be processed
tokio::time::sleep(Duration::from_millis(10)).await;
let counter = get_counter_from_context(&ctx).await;
assert_eq!(counter.get("test:ping").await, 1);
assert_eq!(counter.get("test:pong").await, 1);
}
/// When awaiting the reply to an event the handler for the event doesn't get called.
/// Therefore we expect it to have a call count of 0.
#[tokio::test]
async fn it_receives_responses() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
let reply = ctx
.emitter
.emit("ping", EmptyPayload)
.await
.unwrap()
.await_reply(&ctx)
.await
.unwrap();
let counter = get_counter_from_context(&ctx).await;
assert_eq!(reply.name(), "pong");
assert_eq!(counter.get("ping").await, 1);
assert_eq!(counter.get("pong").await, 0);
}
/// When emitting errors from handlers the client should receive an error event
/// with the error that occurred on the server.
#[tokio::test]
async fn it_handles_errors() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
ctx.emitter
.emit("create_error", EmptyPayload)
.await
.unwrap();
// allow the event to be processed
tokio::time::sleep(Duration::from_millis(10)).await;
let counter = get_counter_from_context(&ctx).await;
assert_eq!(counter.get("error").await, 1);
}
/// When waiting for the reply to an event and an error occurs, the error should
/// bypass the handler and be passed as the Err variant on the await reply instead.
#[tokio::test]
async fn it_receives_error_responses() {
let port = get_free_port();
let ctx = get_client_with_server(port).await;
let result = ctx
.emitter
.emit("create_error", EmptyPayload)
.await
.unwrap()
.await_reply(&ctx)
.await;
let counter = get_counter_from_context(&ctx).await;
assert!(result.is_err());
assert_eq!(counter.get("error").await, 0);
}
async fn get_client_with_server(port: u8) -> Context {
start_server_and_client(move || get_builder(port)).await
}
fn get_builder(port: u8) -> IPCBuilder<TestProtocolListener> {
IPCBuilder::new()
.address(port)
.on(
"ping",
callback!(
ctx,
event,
async move { handle_ping_event(ctx, event).await }
),
)
.timeout(Duration::from_millis(100))
.on(
"pong",
callback!(
ctx,
event,
async move { handle_pong_event(ctx, event).await }
),
)
.on("ping", callback!(handle_ping_event))
.on("pong", callback!(handle_pong_event))
.on("create_error", callback!(handle_create_error_event))
.on("error", callback!(handle_error_event))
.namespace("test")
.on("ping", callback!(handle_ping_event))
.on("pong", callback!(handle_pong_event))
.on("create_error", callback!(handle_create_error_event))
.build()
}
#[tokio::test]
async fn it_passes_events() {
tokio::task::spawn(async { get_builder(0).build_server().await.unwrap() });
tokio::time::sleep(Duration::from_millis(100)).await;
let ctx = get_builder(0).build_client().await.unwrap();
ctx.emitter.emit("ping", ()).await.unwrap(); // todo fix reply deadlock
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
increment_counter_for_event(ctx, &event).await;
ctx.emitter
.emit_response(event.id(), "pong", EmptyPayload)
.await?;
Ok(())
}
async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult<()> {
increment_counter_for_event(ctx, &event).await;
Ok(())
}
async fn handle_create_error_event(ctx: &Context, event: Event) -> IPCResult<()> {
increment_counter_for_event(ctx, &event).await;
Err(IPCError::from("Test Error"))
}
async fn handle_error_event(ctx: &Context, event: Event) -> IPCResult<()> {
increment_counter_for_event(ctx, &event).await;
Ok(())
}
pub struct EmptyPayload;
impl EventSendPayload for EmptyPayload {
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
Ok(vec![])
}
}

@ -0,0 +1,61 @@
use bromine::context::Context;
use bromine::event::Event;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use typemap_rev::TypeMapKey;
pub async fn get_counter_from_context(ctx: &Context) -> CallCounter {
let data = ctx.data.read().await;
data.get::<CallCounterKey>().unwrap().clone()
}
pub async fn increment_counter_for_event(ctx: &Context, event: &Event) {
let data = ctx.data.read().await;
let key_name = if let Some(namespace) = event.namespace() {
format!("{}:{}", namespace, event.name())
} else {
event.name().to_string()
};
data.get::<CallCounterKey>().unwrap().incr(&key_name).await;
}
pub struct CallCounterKey;
impl TypeMapKey for CallCounterKey {
type Value = CallCounter;
}
#[derive(Clone, Default, Debug)]
pub struct CallCounter {
inner: Arc<RwLock<HashMap<String, AtomicUsize>>>,
}
impl CallCounter {
pub async fn incr(&self, name: &str) {
{
let calls = self.inner.read().await;
if let Some(call) = calls.get(name) {
call.fetch_add(1, Ordering::Relaxed);
return;
}
}
{
let mut calls = self.inner.write().await;
calls.insert(name.to_string(), AtomicUsize::new(1));
}
}
pub async fn get(&self, name: &str) -> usize {
let calls = self.inner.read().await;
calls
.get(name)
.map(|n| n.load(Ordering::SeqCst))
.unwrap_or(0)
}
}

@ -0,0 +1,45 @@
use bromine::context::Context;
use bromine::protocol::AsyncStreamProtocolListener;
use bromine::IPCBuilder;
use call_counter::*;
use lazy_static::lazy_static;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use tokio::sync::oneshot::channel;
pub mod call_counter;
pub mod protocol;
pub fn get_free_port() -> u8 {
lazy_static! {
static ref PORT_COUNTER: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
}
PORT_COUNTER.fetch_add(1, Ordering::Relaxed)
}
pub async fn start_server_and_client<
F: Fn() -> IPCBuilder<L> + Send + Sync + 'static,
L: AsyncStreamProtocolListener,
>(
builder_fn: F,
) -> Context {
let counters = CallCounter::default();
let (sender, receiver) = channel::<()>();
let client_builder = builder_fn().insert::<CallCounterKey>(counters.clone());
tokio::task::spawn({
async move {
sender.send(()).unwrap();
builder_fn()
.insert::<CallCounterKey>(counters)
.build_server()
.await
.unwrap()
}
});
receiver.await.unwrap();
let ctx = client_builder.build_client().await.unwrap();
ctx
}

@ -0,0 +1,263 @@
use async_trait::async_trait;
use bromine::error::Result;
use bromine::prelude::{AsyncProtocolStreamSplit, IPCError};
use bromine::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
use lazy_static::lazy_static;
use std::cmp::min;
use std::collections::HashMap;
use std::future::Future;
use std::io::Error;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
lazy_static! {
static ref LISTENERS_REF: Arc<Mutex<HashMap<u8, Sender<TestProtocolStream>>>> =
Arc::new(Mutex::new(HashMap::new()));
}
/// Adds a channel that receives streams to handle
async fn add_port(number: u8, sender: tokio::sync::mpsc::Sender<TestProtocolStream>) {
let mut listeners = LISTENERS_REF.lock().await;
listeners.insert(number, sender);
}
/// Returns a stream for the given port connecting with the server via channels
async fn get_port(number: u8) -> Option<TestProtocolStream> {
let mut listeners = LISTENERS_REF.lock().await;
if let Some(sender) = listeners.get_mut(&number) {
let (s1, r1) = channel(2);
let (s2, r2) = channel(2);
let stream_1 = TestProtocolStream {
sender: s1,
receiver: Arc::new(Mutex::new(r2)),
future: None,
remaining_buf: Default::default(),
};
let stream_2 = TestProtocolStream {
sender: s2,
receiver: Arc::new(Mutex::new(r1)),
future: None,
remaining_buf: Default::default(),
};
sender.send(stream_2).await.ok();
Some(stream_1)
} else {
None
}
}
pub struct TestProtocolListener {
receiver: Arc<Mutex<Receiver<TestProtocolStream>>>,
}
#[async_trait]
impl AsyncStreamProtocolListener for TestProtocolListener {
type AddressType = u8;
type RemoteAddressType = u8;
type Stream = TestProtocolStream;
async fn protocol_bind(address: Self::AddressType) -> Result<Self> {
let (sender, receiver) = channel(1);
add_port(address, sender).await;
Ok(Self {
receiver: Arc::new(Mutex::new(receiver)),
})
}
async fn protocol_accept(&self) -> Result<(Self::Stream, Self::RemoteAddressType)> {
self.receiver
.lock()
.await
.recv()
.await
.map(|r| (r, 0u8))
.ok_or_else(|| IPCError::from("Failed to accept"))
}
}
impl Clone for TestProtocolStream {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
receiver: Arc::clone(&self.receiver),
future: None,
remaining_buf: Default::default(),
}
}
}
pub struct TestProtocolStream {
sender: Sender<Vec<u8>>,
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
future: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
remaining_buf: Arc<Mutex<Vec<u8>>>,
}
impl TestProtocolStream {
/// Read from the receiver and remaining buffer
async fn read_from_receiver(
buf: &mut ReadBuf<'static>,
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
remaining_buf: Arc<Mutex<Vec<u8>>>,
) {
{
let mut remaining_buf = remaining_buf.lock().await;
if !remaining_buf.is_empty() {
if Self::read_from_remaining_buffer(buf, &mut remaining_buf).await {
return;
}
}
}
let mut receiver = receiver.lock().await;
if let Some(mut bytes) = receiver.recv().await {
let slice_len = min(bytes.len(), buf.capacity());
buf.put_slice(&bytes[0..slice_len]);
bytes.reverse();
bytes.truncate(bytes.len() - slice_len);
bytes.reverse();
let mut remaining_buf = remaining_buf.lock().await;
remaining_buf.append(&mut bytes);
}
}
/// Read from the remaining buffer returning a boolean if the
/// read buffer has been filled
async fn read_from_remaining_buffer(
buf: &mut ReadBuf<'static>,
remaining_buf: &mut Vec<u8>,
) -> bool {
if remaining_buf.len() < buf.capacity() {
buf.put_slice(&remaining_buf);
remaining_buf.clear();
false
} else if remaining_buf.len() == buf.capacity() {
buf.put_slice(&remaining_buf);
remaining_buf.clear();
true
} else {
let slice_len = buf.capacity();
let remaining_len = remaining_buf.len();
buf.put_slice(&remaining_buf[0..slice_len]);
remaining_buf.reverse();
remaining_buf.truncate(remaining_len - slice_len);
remaining_buf.reverse();
true
}
}
}
impl AsyncProtocolStreamSplit for TestProtocolStream {
type OwnedSplitReadHalf = Self;
type OwnedSplitWriteHalf = Self;
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
(self.clone(), self)
}
}
#[async_trait]
impl AsyncProtocolStream for TestProtocolStream {
type AddressType = u8;
async fn protocol_connect(address: Self::AddressType) -> Result<Self> {
get_port(address)
.await
.ok_or_else(|| IPCError::from("Failed to connect"))
}
}
impl AsyncRead for TestProtocolStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
unsafe {
// we need a mutable reference to access the inner future
let stream = self.get_unchecked_mut();
if stream.future.is_none() {
// we need to change the lifetime to be able to use the read buffer in the read future
let buf: &mut ReadBuf<'static> = mem::transmute(buf);
let receiver = Arc::clone(&stream.receiver);
let remaining_buf = Arc::clone(&stream.remaining_buf);
let future = TestProtocolStream::read_from_receiver(buf, receiver, remaining_buf);
stream.future = Some(Box::pin(future));
}
if let Some(future) = &mut stream.future {
match future.as_mut().poll(cx) {
Poll::Ready(_) => {
stream.future = None;
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
} else {
Poll::Pending
}
}
}
}
impl AsyncWrite for TestProtocolStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::prelude::rust_2015::Result<usize, Error>> {
let write_len = buf.len();
unsafe {
// we need a mutable reference to access the inner future
let stream = self.get_unchecked_mut();
if stream.future.is_none() {
// we take ownership here so that we don't need to change lifetimes here
let buf = buf.to_vec();
let sender = stream.sender.clone();
let future = async move {
sender.send(buf).await.unwrap();
};
stream.future = Some(Box::pin(future));
}
if let Some(future) = &mut stream.future {
match future.as_mut().poll(cx) {
Poll::Ready(_) => {
stream.future = None;
Poll::Ready(Ok(write_len))
}
Poll::Pending => Poll::Pending,
}
} else {
Poll::Pending
}
}
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
Poll::Ready(Ok(()))
}
}
Loading…
Cancel
Save