From 198f0b65ce7f554c0504f5d11c76dbc364a2c49e Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 17 Jun 2023 15:12:06 +0200 Subject: [PATCH] Add support for pre and post execution hooks --- src/main.rs | 6 +- src/server/action.rs | 118 ------------- src/server/action/mod.rs | 64 +++++++ .../template.rs} | 4 +- src/server/endpoint.rs | 159 ++++++++++++++++++ src/server/mod.rs | 7 +- src/utils/error.rs | 18 ++ src/utils/settings.rs | 10 ++ 8 files changed, 262 insertions(+), 124 deletions(-) delete mode 100644 src/server/action.rs create mode 100644 src/server/action/mod.rs rename src/server/{command_template.rs => action/template.rs} (97%) create mode 100644 src/server/endpoint.rs diff --git a/src/main.rs b/src/main.rs index 529c771..402efe9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf}; use utils::logging::init_logger; use utils::settings::get_settings; +use crate::server::endpoint::HookEndpoint; use crate::server::HookServer; mod secret_validation; @@ -34,7 +35,10 @@ async fn init_and_start() { for (name, endpoint) in &settings.endpoints { log::info!("Adding endpoint '{}' with path '{}'", name, &endpoint.path); - server.add_hook(endpoint.path.clone(), endpoint.into()) + server.add_hook( + endpoint.path.clone(), + HookEndpoint::from_config(name, &settings, &endpoint), + ) } let address = settings diff --git a/src/server/action.rs b/src/server/action.rs deleted file mode 100644 index 36efbcb..0000000 --- a/src/server/action.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::secret_validation::SecretValidator; -use crate::server::command_template::CommandTemplate; -use crate::utils::error::{MultihookError, MultihookResult}; -use crate::utils::settings::{EndpointSettings, SecretSettings}; -use hyper::http::request::Parts; -use hyper::{Body, Request}; -use serde_json::Value; -use std::fs::read_to_string; -use std::mem; -use std::path::PathBuf; -use std::sync::Arc; -use tokio::process::Command; -use tokio::sync::Semaphore; - -static MAX_CONCURRENCY: usize = 256; - -#[derive(Clone)] -pub struct HookAction { - command: CommandTemplate, - parallel_lock: Arc, - run_detached: bool, - secret: Option, -} - -impl HookAction { - pub fn new( - command: S, - parallel: bool, - detached: bool, - secret: Option, - ) -> Self { - let parallel_lock = if parallel { - Semaphore::new(MAX_CONCURRENCY) - } else { - Semaphore::new(1) - }; - Self { - command: CommandTemplate::new(command), - parallel_lock: Arc::new(parallel_lock), - run_detached: detached, - secret, - } - } - - pub async fn execute(&self, req: Request) -> MultihookResult<()> { - let (parts, body) = req.into_parts(); - let body = hyper::body::to_bytes(body).await?.to_vec(); - - self.validate_secret(&parts, &body)?; - let body = String::from_utf8(body)?; - - if self.run_detached { - tokio::spawn({ - let action = self.clone(); - async move { - if let Err(e) = action.execute_command(&body).await { - log::error!("Detached hook threw an error: {:?}", e); - } - } - }); - - Ok(()) - } else { - self.execute_command(&body).await - } - } - - fn validate_secret(&self, parts: &Parts, body: &Vec) -> MultihookResult<()> { - if let Some(secret) = &self.secret { - let validator = secret.format.validator(); - if !validator.validate(&parts.headers, &body, &secret.value.as_bytes()) { - return Err(MultihookError::InvalidSecret); - } - } - Ok(()) - } - - async fn execute_command(&self, body: &str) -> MultihookResult<()> { - let json_body: Value = serde_json::from_str(body).unwrap_or_default(); - let command = self.command.evaluate(&json_body); - log::debug!("Acquiring lock for parallel runs..."); - let permit = self.parallel_lock.acquire().await.unwrap(); - log::debug!("Lock acquired. Running command..."); - - let output = Command::new("sh") - .env("HOOK_BODY", body) - .arg("-c") - .arg(command) - .kill_on_drop(true) - .output() - .await?; - log::debug!("Command finished. Releasing parallel lock..."); - mem::drop(permit); - let stderr = String::from_utf8_lossy(&output.stderr[..]); - let stdout = String::from_utf8_lossy(&output.stdout[..]); - log::debug!("Command output is: {}", stdout); - - if stderr.len() > 0 { - log::error!("Errors occurred during command execution: {}", stderr); - } - - Ok(()) - } -} - -impl From<&EndpointSettings> for HookAction { - fn from(endpoint: &EndpointSettings) -> Self { - let action = endpoint.action.clone(); - let path = PathBuf::from(&action); - let contents = read_to_string(path).unwrap_or(action); - Self::new( - contents, - endpoint.allow_parallel, - endpoint.run_detached, - endpoint.secret.clone(), - ) - } -} diff --git a/src/server/action/mod.rs b/src/server/action/mod.rs new file mode 100644 index 0000000..f18b25d --- /dev/null +++ b/src/server/action/mod.rs @@ -0,0 +1,64 @@ +use crate::utils::error::{MultihookError, MultihookResult}; + +use self::template::ActionTemplate; +use std::{collections::HashMap, sync::Arc}; +use tokio::{process::Command, sync::Semaphore}; + +mod template; + +static MAX_CONCURRENCY: usize = 256; + +#[derive(Clone)] +pub struct Action { + template: ActionTemplate, + semaphore: Arc, +} + +impl Action { + /// Creates a new command that also checks for parallel runs + pub fn new>(command: S, allow_parallel: bool) -> Self { + let semaphore = if allow_parallel { + Semaphore::new(MAX_CONCURRENCY) + } else { + Semaphore::new(1) + }; + + Self { + template: ActionTemplate::new(command.into()), + semaphore: Arc::new(semaphore), + } + } + + /// Executes the action + pub async fn run( + &self, + body: &serde_json::Value, + env: &HashMap<&str, String>, + ) -> MultihookResult<()> { + let command = self.template.evaluate(&body); + log::debug!("Acquiring lock for parallel runs..."); + let permit = self.semaphore.acquire().await.unwrap(); + log::debug!("Lock acquired. Running command..."); + std::mem::drop(permit); + + let output = Command::new("sh") + .envs(env) + .arg("-c") + .arg(command) + .kill_on_drop(true) + .output() + .await?; + log::debug!("Command finished. Releasing parallel lock..."); + + let stderr = String::from_utf8_lossy(&output.stderr[..]); + let stdout = String::from_utf8_lossy(&output.stdout[..]); + log::debug!("Command output is: {}", stdout); + + if stderr.len() > 0 { + log::error!("Errors occurred during command execution: {}", stderr); + Err(MultihookError::ActionError(stderr.into_owned())) + } else { + Ok(()) + } + } +} diff --git a/src/server/command_template.rs b/src/server/action/template.rs similarity index 97% rename from src/server/command_template.rs rename to src/server/action/template.rs index d69b11c..acec5c7 100644 --- a/src/server/command_template.rs +++ b/src/server/action/template.rs @@ -4,12 +4,12 @@ use regex::{Match, Regex}; use serde_json::Value; #[derive(Clone)] -pub struct CommandTemplate { +pub struct ActionTemplate { src: String, matches: Vec<(usize, usize)>, } -impl CommandTemplate { +impl ActionTemplate { pub fn new(command: S) -> Self { lazy_static! { static ref PLACEHOLDER_REGEX: Regex = Regex::new(r"\{\{(.*?)\}\}").unwrap(); diff --git a/src/server/endpoint.rs b/src/server/endpoint.rs new file mode 100644 index 0000000..db0d65c --- /dev/null +++ b/src/server/endpoint.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; + +use crate::secret_validation::SecretValidator; +use crate::utils::error::{LogErr, MultihookError, MultihookResult}; +use crate::utils::settings::{EndpointSettings, SecretSettings, Settings}; +use hyper::http::request::Parts; +use hyper::{Body, Request}; +use serde_json::Value; + +use super::action::Action; + +#[derive(Clone)] +pub struct HookEndpoint { + name: String, + action: Action, + global_hooks: ActionHooks, + hooks: ActionHooks, + run_detached: bool, + secret: Option, +} + +#[derive(Clone, Default)] +struct ActionHooks { + pre: Option, + post: Option, + error: Option, +} + +impl HookEndpoint { + pub fn from_config>( + name: S, + global: &Settings, + endpoint: &EndpointSettings, + ) -> Self { + let global_hooks = global + .hooks + .as_ref() + .map(|hooks_cfg| ActionHooks { + pre: hooks_cfg.pre_action.as_ref().map(|a| Action::new(a, true)), + post: hooks_cfg.post_action.as_ref().map(|a| Action::new(a, true)), + error: hooks_cfg.err_action.as_ref().map(|a| Action::new(a, true)), + }) + .unwrap_or_default(); + + let hooks = endpoint + .hooks + .as_ref() + .map(|hooks_cfg| ActionHooks { + pre: hooks_cfg + .pre_action + .as_ref() + .map(|a| Action::new(a, endpoint.allow_parallel)), + post: hooks_cfg + .post_action + .as_ref() + .map(|a| Action::new(a, endpoint.allow_parallel)), + error: hooks_cfg + .err_action + .as_ref() + .map(|a| Action::new(a, endpoint.allow_parallel)), + }) + .unwrap_or_default(); + + Self { + name: name.into(), + action: Action::new(&endpoint.action, endpoint.allow_parallel), + run_detached: endpoint.run_detached, + secret: endpoint.secret.clone(), + global_hooks, + hooks, + } + } + + pub async fn execute(&self, req: Request) -> MultihookResult<()> { + let (parts, body) = req.into_parts(); + let body = hyper::body::to_bytes(body).await?.to_vec(); + + self.validate_secret(&parts, &body)?; + let body = String::from_utf8(body)?; + + if self.run_detached { + tokio::spawn({ + let action = self.clone(); + async move { + if let Err(e) = action.execute_command(&body).await { + log::error!("Detached hook threw an error: {:?}", e); + } + } + }); + + Ok(()) + } else { + self.execute_command(&body).await + } + } + + fn validate_secret(&self, parts: &Parts, body: &Vec) -> MultihookResult<()> { + if let Some(secret) = &self.secret { + let validator = secret.format.validator(); + if !validator.validate(&parts.headers, &body, &secret.value.as_bytes()) { + return Err(MultihookError::InvalidSecret); + } + } + Ok(()) + } + + async fn execute_command(&self, body: &str) -> MultihookResult<()> { + let json_body: Value = serde_json::from_str(body).unwrap_or_default(); + let mut env = HashMap::new(); + env.insert("HOOK_NAME", self.name.to_owned()); + env.insert("HOOK_BODY", body.to_string()); + + if let Some(global_pre) = &self.global_hooks.pre { + global_pre + .run(&json_body, &env) + .await + .log_err("Global Pre-Hook failed {e}"); + } + if let Some(pre_hook) = &self.hooks.pre { + pre_hook + .run(&json_body, &env) + .await + .log_err("Endpoint Pre-Hook failed {e}"); + } + if let Err(e) = self.action.run(&json_body, &env).await { + env.insert("HOOK_ERROR", format!("{e}")); + + if let Some(global_err_action) = &self.global_hooks.error { + global_err_action + .run(&json_body, &env) + .await + .log_err("Global Error-Hook failed {e}"); + } + if let Some(err_hook) = &self.hooks.error { + err_hook + .run(&json_body, &env) + .await + .log_err("Endpoint Error-Hook failed"); + } + + Err(e) + } else { + if let Some(global_post_hook) = &self.global_hooks.post { + global_post_hook + .run(&json_body, &env) + .await + .log_err("Global Post-Hook failed"); + } + if let Some(post_hook) = &self.hooks.post { + post_hook + .run(&json_body, &env) + .await + .log_err("Endpoint Post-Hook failed") + } + + Ok(()) + } + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 89dac9e..c5bf95f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,13 +2,13 @@ use std::sync::Arc; use hyper::{Body, Method, Response}; -use action::HookAction; +use endpoint::HookEndpoint; use crate::server::http::{HTTPCallback, HTTPServer}; use crate::utils::error::MultihookResult; pub mod action; -pub mod command_template; +pub mod endpoint; mod http; pub struct HookServer { @@ -22,7 +22,7 @@ impl HookServer { } } - pub fn add_hook(&mut self, point: String, action: HookAction) { + pub fn add_hook(&mut self, point: String, action: HookEndpoint) { let action = Arc::new(action); let cb = HTTPCallback::new({ @@ -34,6 +34,7 @@ impl HookServer { log::debug!("Executing hook {}", point); action.execute(req).await?; log::debug!("Hook {} executed", point); + Ok(Response::new(Body::from(format!( "Hook '{}' executed.", point diff --git a/src/utils/error.rs b/src/utils/error.rs index c78e5f6..56c3683 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -25,4 +25,22 @@ pub enum MultihookError { #[error("Secret validation failed.")] InvalidSecret, + + #[error(transparent)] + JsonError(#[from] serde_json::Error), + + #[error("Action failed: {0}")] + ActionError(String), +} + +pub trait LogErr { + fn log_err>(&self, template: S); +} + +impl LogErr for MultihookResult { + fn log_err>(&self, message: S) { + if let Err(e) = self.as_ref() { + log::error!("{}: {}", message.as_ref(), e); + } + } } diff --git a/src/utils/settings.rs b/src/utils/settings.rs index 9a9da7b..86fe9e9 100644 --- a/src/utils/settings.rs +++ b/src/utils/settings.rs @@ -10,6 +10,7 @@ use std::path::{Path, PathBuf}; #[derive(Serialize, Deserialize, Clone, Debug)] pub struct Settings { pub server: ServerSettings, + pub hooks: Option, pub endpoints: HashMap, } @@ -18,10 +19,18 @@ pub struct ServerSettings { pub address: Option, } +#[derive(Serialize, Deserialize, Default, Clone, Debug)] +pub struct Hooks { + pub pre_action: Option, + pub post_action: Option, + pub err_action: Option, +} + #[derive(Serialize, Deserialize, Clone, Debug)] pub struct EndpointSettings { pub path: String, pub action: String, + pub hooks: Option, #[serde(default)] pub allow_parallel: bool, #[serde(default)] @@ -40,6 +49,7 @@ impl Default for Settings { Self { endpoints: HashMap::new(), server: ServerSettings { address: None }, + hooks: None, } } }