From 50bd608b8992a1bc84963ecefa893dd87eeb9b38 Mon Sep 17 00:00:00 2001 From: Trivernis Date: Tue, 12 Jan 2021 10:59:21 +0100 Subject: [PATCH] Change execute to allow bounded channels Signed-off-by: Trivernis --- Cargo.toml | 5 +++-- src/executor/mod.rs | 35 +++++++++++++++++++++++++++++++---- src/executor/ocl_stream.rs | 15 ++++++++++++++- src/lib.rs | 2 +- src/utils/result.rs | 34 ++++++++++------------------------ 5 files changed, 59 insertions(+), 32 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 74260cc..ba0ec54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ description = "OpenCL Stream execution framework" repository = "https://github.com/parallel-programming-hwr/ocl-stream-rs" license = "Apache-2.0" readme = "README.md" -version = "0.1.0" +version = "0.2.0" authors = ["Trivernis "] edition = "2018" @@ -14,4 +14,5 @@ edition = "2018" ocl = "0.19.3" num_cpus = "1.13.0" scheduled-thread-pool = "0.2.5" -crossbeam-channel = "0.5.0" \ No newline at end of file +crossbeam-channel = "0.5.0" +thiserror = "1.0.23" \ No newline at end of file diff --git a/src/executor/mod.rs b/src/executor/mod.rs index cdcd9c6..4a1f363 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -43,13 +43,42 @@ impl OCLStreamExecutor { self.concurrency = num_tasks; } + /// Replaces the used pool with a new one + pub fn set_pool(&mut self, pool: ScheduledThreadPool) { + self.pool = Arc::new(pool); + } + + /// Executes a closure in the ocl context with a bounded channel + pub fn execute_bounded(&self, size: usize, func: F) -> OCLStream + where + F: Fn(ExecutorContext) -> OCLStreamResult<()> + Send + Sync + 'static, + T: Send + Sync + 'static, + { + let (stream, sender) = ocl_stream::bounded(size); + self.execute(func, sender); + + stream + } + + /// Executes a closure in the ocl context with an unbounded channel + /// for streaming + pub fn execute_unbounded(&self, func: F) -> OCLStream + where + F: Fn(ExecutorContext) -> OCLStreamResult<()> + Send + Sync + 'static, + T: Send + Sync + 'static, + { + let (stream, sender) = ocl_stream::unbounded(); + self.execute(func, sender); + + stream + } + /// Executes a closure in the ocl context - pub fn execute(&self, func: F) -> OCLStream + fn execute(&self, func: F, sender: OCLStreamSender) where F: Fn(ExecutorContext) -> OCLStreamResult<()> + Send + Sync + 'static, T: Send + Sync + 'static, { - let (stream, sender) = ocl_stream::create(); let func = Arc::new(func); for task_id in 0..(self.concurrency) { @@ -64,8 +93,6 @@ impl OCLStreamExecutor { } }); } - - stream } /// Builds the executor context for the executor diff --git a/src/executor/ocl_stream.rs b/src/executor/ocl_stream.rs index 0c5ddcd..d9c57d5 100644 --- a/src/executor/ocl_stream.rs +++ b/src/executor/ocl_stream.rs @@ -10,7 +10,7 @@ use crate::utils::result::{OCLStreamError, OCLStreamResult}; /// Creates a new OCLStream with the corresponding sender /// to communicate between the scheduler thread and the receiver thread -pub fn create() -> (OCLStream, OCLStreamSender) +pub fn unbounded() -> (OCLStream, OCLStreamSender) where T: Send + Sync, { @@ -21,6 +21,19 @@ where (stream, sender) } +/// Creates a new OCLStream with the corresponding sender and a maximum capacity +/// to communicate between the scheduler thread and the receiver thread +pub fn bounded(size: usize) -> (OCLStream, OCLStreamSender) +where + T: Send + Sync, +{ + let (tx, rx) = crossbeam_channel::bounded(size); + let stream = OCLStream { rx }; + let sender = OCLStreamSender { tx }; + + (stream, sender) +} + /// Receiver for OCL Data #[derive(Clone, Debug)] pub struct OCLStream diff --git a/src/lib.rs b/src/lib.rs index fd3949e..4517b0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,7 +36,7 @@ mod tests { .unwrap(); let stream_executor = OCLStreamExecutor::new(pro_que); - let mut stream = stream_executor.execute(|ctx| { + let mut stream = stream_executor.execute_bounded(10, |ctx| { let pro_que = ctx.pro_que(); let tx = ctx.sender(); let input_buffer = pro_que.buffer_builder().len(100).fill_val(0u32).build()?; diff --git a/src/utils/result.rs b/src/utils/result.rs index 4e7329e..705af26 100644 --- a/src/utils/result.rs +++ b/src/utils/result.rs @@ -5,38 +5,24 @@ */ use crossbeam_channel::RecvError; -use std::error::Error; -use std::fmt::{self, Display, Formatter}; +use thiserror::Error; pub type OCLStreamResult = Result; -#[derive(Debug)] +#[derive(Error, Debug)] pub enum OCLStreamError { - OCLError(ocl::Error), - RecvError(RecvError), - SendError, -} + #[error("OpenCL Error {0}")] + OCLError(String), -impl Display for OCLStreamError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - OCLStreamError::OCLError(e) => write!(f, "OCL Error: {}", e), - OCLStreamError::RecvError(e) => write!(f, "Stream Receive Error: {}", e), - OCLStreamError::SendError => write!(f, "Stream Send Error"), - } - } -} + #[error("Stream Receive Error")] + RecvError(#[from] RecvError), -impl Error for OCLStreamError {} + #[error("Stream Send Error")] + SendError, +} impl From for OCLStreamError { fn from(e: ocl::Error) -> Self { - Self::OCLError(e) - } -} - -impl From for OCLStreamError { - fn from(e: RecvError) -> Self { - Self::RecvError(e) + Self::OCLError(format!("{}", e)) } }