diff --git a/Cargo.lock b/Cargo.lock index 6a1c4a4..097fe10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1018,6 +1018,7 @@ dependencies = [ "handlebars_switch", "lazy_static", "log", + "merge-struct", "miette 5.10.0", "mlua", "pretty_env_logger", @@ -2700,6 +2701,16 @@ dependencies = [ "libc", ] +[[package]] +name = "merge-struct" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d82012d21e24135b839b6b9bebd622b7ff0cb40071498bc2d066d3a6d04dd4a" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "miette" version = "5.10.0" diff --git a/Cargo.toml b/Cargo.toml index 3294984..7144797 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ handlebars = "5.0.0" handlebars_switch = "0.6.0" lazy_static = "1.4.0" log = "0.4.20" +merge-struct = "0.1.0" miette = { version = "5.10.0", features = ["serde", "fancy"] } mlua = { version = "0.9.6", features = ["serialize", "luau", "vendored"] } pretty_env_logger = "0.5.0" diff --git a/src/config.rs b/src/config.rs index ca653fe..c78286a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,16 @@ -use std::{collections::HashMap, path::Path}; +use std::{collections::HashMap, fs, path::Path}; use figment::{ providers::{Env, Format, Serialized, Toml}, Figment, }; use miette::{Context, IntoDiagnostic, Result}; +use mlua::LuaSerdeExt; use serde::{Deserialize, Serialize}; use which::which; +use crate::{scripting::create_lua, utils::Describe}; + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct SiloConfig { /// Diff tool used to display file differences @@ -39,12 +42,58 @@ fn detect_difftool() -> String { /// and the `repo.local.toml` config file /// and environment variables prefixed with `SILO_`` pub fn read_config(repo: &Path) -> Result { - Figment::from(Serialized::defaults(SiloConfig::default())) - .merge(Toml::file(dirs::config_dir().unwrap().join("silo.toml"))) + let conf_dir = dirs::config_dir().unwrap(); + let default_config = conf_dir.join("silo.config.lua"); + let old_config = conf_dir.join("silo.toml"); + + if !default_config.exists() { + let mut lines = vec![ + "local silo = require 'silo'".to_owned(), + "local utils = require 'utils'".to_owned(), + "local config = silo.config".to_owned(), + ]; + if old_config.exists() { + lines.push("".to_owned()); + lines.push("-- merge with old toml config".to_owned()); + lines.push(format!( + "config = utils.merge(config, utils.load_toml {old_config:?})" + )); + } + lines.push("-- Changes can be added to the `config` object".to_owned()); + lines.push("".to_owned()); + lines.push("return config".to_owned()); + + fs::write(&default_config, lines.join("\n")).describe("Writing default config")? + } + + let mut builder = Figment::from(Serialized::defaults(SiloConfig::default())) + .merge(Toml::file(old_config)) .merge(Toml::file(repo.join("repo.toml"))) - .merge(Toml::file(repo.join("repo.local.toml"))) + .merge(Toml::file(repo.join("repo.local.toml"))); + + let repo_defaults = repo.join("silo.config.lua"); + + if repo_defaults.exists() { + builder = builder.merge(Serialized::globals(read_lua_config(&repo_defaults)?)) + } + + builder + .merge(Serialized::globals(read_lua_config(&default_config)?)) .merge(Env::prefixed("SILO_")) .extract() .into_diagnostic() .context("parsing config file") } + +fn read_lua_config(path: &Path) -> Result { + let lua = create_lua(&())?; + let result = lua + .load(path) + .eval() + .with_describe(|| format!("evaluating config script {path:?}"))?; + let cfg = lua + .from_value(result) + .describe("deserializing lua config value")?; + + Ok(cfg) +} diff --git a/src/scripting/mod.rs b/src/scripting/mod.rs index e33628d..05e1151 100644 --- a/src/scripting/mod.rs +++ b/src/scripting/mod.rs @@ -1,6 +1,7 @@ pub mod log_module; mod require; pub mod silo_module; +pub mod utils_module; use miette::Result; use mlua::{Lua, LuaSerdeExt}; diff --git a/src/scripting/require.rs b/src/scripting/require.rs index 22017fe..6a605d9 100644 --- a/src/scripting/require.rs +++ b/src/scripting/require.rs @@ -2,6 +2,7 @@ use mlua::{Lua, Result, Table}; use super::log_module::log_module; use super::silo_module::silo_module; +use super::utils_module::utils_module; pub fn register_require(lua: &Lua) -> Result<()> { let globals = lua.globals(); @@ -16,6 +17,7 @@ fn lua_require(lua: &Lua, module: String) -> Result> { match module.as_str() { "silo" => silo_module(lua), "log" => log_module(lua), + "utils" => utils_module(lua), _ => { let old_require: mlua::Function = lua.globals().get("old_require")?; old_require.call(module) diff --git a/src/scripting/silo_module.rs b/src/scripting/silo_module.rs index 75d8a95..c3be412 100644 --- a/src/scripting/silo_module.rs +++ b/src/scripting/silo_module.rs @@ -1,6 +1,6 @@ use mlua::{Lua, LuaSerdeExt, Result, Table}; -use crate::templating::ContextData; +use crate::{config::SiloConfig, templating::ContextData}; pub fn silo_module(lua: &Lua) -> Result { let silo_ctx = ContextData::default(); @@ -10,6 +10,7 @@ pub fn silo_module(lua: &Lua) -> Result
{ exports.set("flags", lua.to_value(&silo_ctx.flags)?)?; exports.set("system", lua.to_value(&silo_ctx.system)?)?; exports.set("usercfg", lua.globals().get::<_, mlua::Value>("silo_ctx")?)?; + exports.set("config", lua.to_value(&SiloConfig::default())?)?; Ok(exports) } diff --git a/src/scripting/utils_module.rs b/src/scripting/utils_module.rs new file mode 100644 index 0000000..4c04a1a --- /dev/null +++ b/src/scripting/utils_module.rs @@ -0,0 +1,27 @@ +use std::fs; + +use mlua::{Lua, LuaSerdeExt, Result, Table}; + +pub fn utils_module(lua: &Lua) -> Result
{ + let exports = lua.create_table()?; + + exports.set("merge", lua.create_function(lua_merge)?)?; + exports.set("load_toml", lua.create_function(lua_read_toml)?)?; + + Ok(exports) +} + +fn lua_merge<'a>(lua: &'a Lua, (a, b): (mlua::Value, mlua::Value)) -> Result> { + let val_a: serde_json::Value = lua.from_value(a)?; + let val_b: serde_json::Value = lua.from_value(b)?; + let merged = merge_struct::merge(&val_a, &val_b).map_err(mlua::Error::external)?; + + lua.to_value(&merged) +} + +fn lua_read_toml<'a>(lua: &'a Lua, path: String) -> Result> { + let contents = fs::read_to_string(path)?; + let toml_value: toml::Value = toml::from_str(&contents).map_err(mlua::Error::external)?; + + lua.to_value(&toml_value) +}