Add support for pre and post execution hooks

main
trivernis 1 year ago
parent dbddc5879c
commit 198f0b65ce
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: DFFFCC2C7A02DB45

@ -3,6 +3,7 @@ use std::path::{Path, PathBuf};
use utils::logging::init_logger; use utils::logging::init_logger;
use utils::settings::get_settings; use utils::settings::get_settings;
use crate::server::endpoint::HookEndpoint;
use crate::server::HookServer; use crate::server::HookServer;
mod secret_validation; mod secret_validation;
@ -34,7 +35,10 @@ async fn init_and_start() {
for (name, endpoint) in &settings.endpoints { for (name, endpoint) in &settings.endpoints {
log::info!("Adding endpoint '{}' with path '{}'", name, &endpoint.path); log::info!("Adding endpoint '{}' with path '{}'", name, &endpoint.path);
server.add_hook(endpoint.path.clone(), endpoint.into()) server.add_hook(
endpoint.path.clone(),
HookEndpoint::from_config(name, &settings, &endpoint),
)
} }
let address = settings let address = settings

@ -1,118 +0,0 @@
use crate::secret_validation::SecretValidator;
use crate::server::command_template::CommandTemplate;
use crate::utils::error::{MultihookError, MultihookResult};
use crate::utils::settings::{EndpointSettings, SecretSettings};
use hyper::http::request::Parts;
use hyper::{Body, Request};
use serde_json::Value;
use std::fs::read_to_string;
use std::mem;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::process::Command;
use tokio::sync::Semaphore;
static MAX_CONCURRENCY: usize = 256;
#[derive(Clone)]
pub struct HookAction {
command: CommandTemplate,
parallel_lock: Arc<Semaphore>,
run_detached: bool,
secret: Option<SecretSettings>,
}
impl HookAction {
pub fn new<S: ToString>(
command: S,
parallel: bool,
detached: bool,
secret: Option<SecretSettings>,
) -> Self {
let parallel_lock = if parallel {
Semaphore::new(MAX_CONCURRENCY)
} else {
Semaphore::new(1)
};
Self {
command: CommandTemplate::new(command),
parallel_lock: Arc::new(parallel_lock),
run_detached: detached,
secret,
}
}
pub async fn execute(&self, req: Request<Body>) -> MultihookResult<()> {
let (parts, body) = req.into_parts();
let body = hyper::body::to_bytes(body).await?.to_vec();
self.validate_secret(&parts, &body)?;
let body = String::from_utf8(body)?;
if self.run_detached {
tokio::spawn({
let action = self.clone();
async move {
if let Err(e) = action.execute_command(&body).await {
log::error!("Detached hook threw an error: {:?}", e);
}
}
});
Ok(())
} else {
self.execute_command(&body).await
}
}
fn validate_secret(&self, parts: &Parts, body: &Vec<u8>) -> MultihookResult<()> {
if let Some(secret) = &self.secret {
let validator = secret.format.validator();
if !validator.validate(&parts.headers, &body, &secret.value.as_bytes()) {
return Err(MultihookError::InvalidSecret);
}
}
Ok(())
}
async fn execute_command(&self, body: &str) -> MultihookResult<()> {
let json_body: Value = serde_json::from_str(body).unwrap_or_default();
let command = self.command.evaluate(&json_body);
log::debug!("Acquiring lock for parallel runs...");
let permit = self.parallel_lock.acquire().await.unwrap();
log::debug!("Lock acquired. Running command...");
let output = Command::new("sh")
.env("HOOK_BODY", body)
.arg("-c")
.arg(command)
.kill_on_drop(true)
.output()
.await?;
log::debug!("Command finished. Releasing parallel lock...");
mem::drop(permit);
let stderr = String::from_utf8_lossy(&output.stderr[..]);
let stdout = String::from_utf8_lossy(&output.stdout[..]);
log::debug!("Command output is: {}", stdout);
if stderr.len() > 0 {
log::error!("Errors occurred during command execution: {}", stderr);
}
Ok(())
}
}
impl From<&EndpointSettings> for HookAction {
fn from(endpoint: &EndpointSettings) -> Self {
let action = endpoint.action.clone();
let path = PathBuf::from(&action);
let contents = read_to_string(path).unwrap_or(action);
Self::new(
contents,
endpoint.allow_parallel,
endpoint.run_detached,
endpoint.secret.clone(),
)
}
}

@ -0,0 +1,64 @@
use crate::utils::error::{MultihookError, MultihookResult};
use self::template::ActionTemplate;
use std::{collections::HashMap, sync::Arc};
use tokio::{process::Command, sync::Semaphore};
mod template;
static MAX_CONCURRENCY: usize = 256;
#[derive(Clone)]
pub struct Action {
template: ActionTemplate,
semaphore: Arc<Semaphore>,
}
impl Action {
/// Creates a new command that also checks for parallel runs
pub fn new<S: Into<String>>(command: S, allow_parallel: bool) -> Self {
let semaphore = if allow_parallel {
Semaphore::new(MAX_CONCURRENCY)
} else {
Semaphore::new(1)
};
Self {
template: ActionTemplate::new(command.into()),
semaphore: Arc::new(semaphore),
}
}
/// Executes the action
pub async fn run(
&self,
body: &serde_json::Value,
env: &HashMap<&str, String>,
) -> MultihookResult<()> {
let command = self.template.evaluate(&body);
log::debug!("Acquiring lock for parallel runs...");
let permit = self.semaphore.acquire().await.unwrap();
log::debug!("Lock acquired. Running command...");
std::mem::drop(permit);
let output = Command::new("sh")
.envs(env)
.arg("-c")
.arg(command)
.kill_on_drop(true)
.output()
.await?;
log::debug!("Command finished. Releasing parallel lock...");
let stderr = String::from_utf8_lossy(&output.stderr[..]);
let stdout = String::from_utf8_lossy(&output.stdout[..]);
log::debug!("Command output is: {}", stdout);
if stderr.len() > 0 {
log::error!("Errors occurred during command execution: {}", stderr);
Err(MultihookError::ActionError(stderr.into_owned()))
} else {
Ok(())
}
}
}

@ -4,12 +4,12 @@ use regex::{Match, Regex};
use serde_json::Value; use serde_json::Value;
#[derive(Clone)] #[derive(Clone)]
pub struct CommandTemplate { pub struct ActionTemplate {
src: String, src: String,
matches: Vec<(usize, usize)>, matches: Vec<(usize, usize)>,
} }
impl CommandTemplate { impl ActionTemplate {
pub fn new<S: ToString>(command: S) -> Self { pub fn new<S: ToString>(command: S) -> Self {
lazy_static! { lazy_static! {
static ref PLACEHOLDER_REGEX: Regex = Regex::new(r"\{\{(.*?)\}\}").unwrap(); static ref PLACEHOLDER_REGEX: Regex = Regex::new(r"\{\{(.*?)\}\}").unwrap();

@ -0,0 +1,159 @@
use std::collections::HashMap;
use crate::secret_validation::SecretValidator;
use crate::utils::error::{LogErr, MultihookError, MultihookResult};
use crate::utils::settings::{EndpointSettings, SecretSettings, Settings};
use hyper::http::request::Parts;
use hyper::{Body, Request};
use serde_json::Value;
use super::action::Action;
#[derive(Clone)]
pub struct HookEndpoint {
name: String,
action: Action,
global_hooks: ActionHooks,
hooks: ActionHooks,
run_detached: bool,
secret: Option<SecretSettings>,
}
#[derive(Clone, Default)]
struct ActionHooks {
pre: Option<Action>,
post: Option<Action>,
error: Option<Action>,
}
impl HookEndpoint {
pub fn from_config<S: Into<String>>(
name: S,
global: &Settings,
endpoint: &EndpointSettings,
) -> Self {
let global_hooks = global
.hooks
.as_ref()
.map(|hooks_cfg| ActionHooks {
pre: hooks_cfg.pre_action.as_ref().map(|a| Action::new(a, true)),
post: hooks_cfg.post_action.as_ref().map(|a| Action::new(a, true)),
error: hooks_cfg.err_action.as_ref().map(|a| Action::new(a, true)),
})
.unwrap_or_default();
let hooks = endpoint
.hooks
.as_ref()
.map(|hooks_cfg| ActionHooks {
pre: hooks_cfg
.pre_action
.as_ref()
.map(|a| Action::new(a, endpoint.allow_parallel)),
post: hooks_cfg
.post_action
.as_ref()
.map(|a| Action::new(a, endpoint.allow_parallel)),
error: hooks_cfg
.err_action
.as_ref()
.map(|a| Action::new(a, endpoint.allow_parallel)),
})
.unwrap_or_default();
Self {
name: name.into(),
action: Action::new(&endpoint.action, endpoint.allow_parallel),
run_detached: endpoint.run_detached,
secret: endpoint.secret.clone(),
global_hooks,
hooks,
}
}
pub async fn execute(&self, req: Request<Body>) -> MultihookResult<()> {
let (parts, body) = req.into_parts();
let body = hyper::body::to_bytes(body).await?.to_vec();
self.validate_secret(&parts, &body)?;
let body = String::from_utf8(body)?;
if self.run_detached {
tokio::spawn({
let action = self.clone();
async move {
if let Err(e) = action.execute_command(&body).await {
log::error!("Detached hook threw an error: {:?}", e);
}
}
});
Ok(())
} else {
self.execute_command(&body).await
}
}
fn validate_secret(&self, parts: &Parts, body: &Vec<u8>) -> MultihookResult<()> {
if let Some(secret) = &self.secret {
let validator = secret.format.validator();
if !validator.validate(&parts.headers, &body, &secret.value.as_bytes()) {
return Err(MultihookError::InvalidSecret);
}
}
Ok(())
}
async fn execute_command(&self, body: &str) -> MultihookResult<()> {
let json_body: Value = serde_json::from_str(body).unwrap_or_default();
let mut env = HashMap::new();
env.insert("HOOK_NAME", self.name.to_owned());
env.insert("HOOK_BODY", body.to_string());
if let Some(global_pre) = &self.global_hooks.pre {
global_pre
.run(&json_body, &env)
.await
.log_err("Global Pre-Hook failed {e}");
}
if let Some(pre_hook) = &self.hooks.pre {
pre_hook
.run(&json_body, &env)
.await
.log_err("Endpoint Pre-Hook failed {e}");
}
if let Err(e) = self.action.run(&json_body, &env).await {
env.insert("HOOK_ERROR", format!("{e}"));
if let Some(global_err_action) = &self.global_hooks.error {
global_err_action
.run(&json_body, &env)
.await
.log_err("Global Error-Hook failed {e}");
}
if let Some(err_hook) = &self.hooks.error {
err_hook
.run(&json_body, &env)
.await
.log_err("Endpoint Error-Hook failed");
}
Err(e)
} else {
if let Some(global_post_hook) = &self.global_hooks.post {
global_post_hook
.run(&json_body, &env)
.await
.log_err("Global Post-Hook failed");
}
if let Some(post_hook) = &self.hooks.post {
post_hook
.run(&json_body, &env)
.await
.log_err("Endpoint Post-Hook failed")
}
Ok(())
}
}
}

@ -2,13 +2,13 @@ use std::sync::Arc;
use hyper::{Body, Method, Response}; use hyper::{Body, Method, Response};
use action::HookAction; use endpoint::HookEndpoint;
use crate::server::http::{HTTPCallback, HTTPServer}; use crate::server::http::{HTTPCallback, HTTPServer};
use crate::utils::error::MultihookResult; use crate::utils::error::MultihookResult;
pub mod action; pub mod action;
pub mod command_template; pub mod endpoint;
mod http; mod http;
pub struct HookServer { pub struct HookServer {
@ -22,7 +22,7 @@ impl HookServer {
} }
} }
pub fn add_hook(&mut self, point: String, action: HookAction) { pub fn add_hook(&mut self, point: String, action: HookEndpoint) {
let action = Arc::new(action); let action = Arc::new(action);
let cb = HTTPCallback::new({ let cb = HTTPCallback::new({
@ -34,6 +34,7 @@ impl HookServer {
log::debug!("Executing hook {}", point); log::debug!("Executing hook {}", point);
action.execute(req).await?; action.execute(req).await?;
log::debug!("Hook {} executed", point); log::debug!("Hook {} executed", point);
Ok(Response::new(Body::from(format!( Ok(Response::new(Body::from(format!(
"Hook '{}' executed.", "Hook '{}' executed.",
point point

@ -25,4 +25,22 @@ pub enum MultihookError {
#[error("Secret validation failed.")] #[error("Secret validation failed.")]
InvalidSecret, InvalidSecret,
#[error(transparent)]
JsonError(#[from] serde_json::Error),
#[error("Action failed: {0}")]
ActionError(String),
}
pub trait LogErr {
fn log_err<S: AsRef<str>>(&self, template: S);
}
impl<T> LogErr for MultihookResult<T> {
fn log_err<S: AsRef<str>>(&self, message: S) {
if let Err(e) = self.as_ref() {
log::error!("{}: {}", message.as_ref(), e);
}
}
} }

@ -10,6 +10,7 @@ use std::path::{Path, PathBuf};
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Settings { pub struct Settings {
pub server: ServerSettings, pub server: ServerSettings,
pub hooks: Option<Hooks>,
pub endpoints: HashMap<String, EndpointSettings>, pub endpoints: HashMap<String, EndpointSettings>,
} }
@ -18,10 +19,18 @@ pub struct ServerSettings {
pub address: Option<String>, pub address: Option<String>,
} }
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
pub struct Hooks {
pub pre_action: Option<String>,
pub post_action: Option<String>,
pub err_action: Option<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
pub struct EndpointSettings { pub struct EndpointSettings {
pub path: String, pub path: String,
pub action: String, pub action: String,
pub hooks: Option<Hooks>,
#[serde(default)] #[serde(default)]
pub allow_parallel: bool, pub allow_parallel: bool,
#[serde(default)] #[serde(default)]
@ -40,6 +49,7 @@ impl Default for Settings {
Self { Self {
endpoints: HashMap::new(), endpoints: HashMap::new(),
server: ServerSettings { address: None }, server: ServerSettings { address: None },
hooks: None,
} }
} }
} }

Loading…
Cancel
Save