diff --git a/.gitignore b/.gitignore index d3041f86..a51bbf61 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ tmp.* doc/tags _neovim _runtime +deps/ diff --git a/.luacheckrc b/.luacheckrc index 363b68da..4bda2d9c 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -7,6 +7,7 @@ self = false exclude_files = { "_neovim/*", "_runtime/*", + "deps/*", } -- Glorious list of warnings: https://luacheck.readthedocs.io/en/stable/warnings.html @@ -24,4 +25,5 @@ read_globals = { "it", "describe", "before_each", + "after_each", } diff --git a/CHANGELOG.md b/CHANGELOG.md index f11d996d..49c0bd05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Allow custom directory and ID logic for templates + +### Changed + +### Fixed + +- Fixed improper tmp-file creation during template tests + ## [v3.12.0](https://github.com/obsidian-nvim/obsidian.nvim/releases/tag/v3.12.0) - 2025-06-05 ### Added diff --git a/README.md b/README.md index 3b04f718..b3424f7e 100644 --- a/README.md +++ b/README.md @@ -449,6 +449,9 @@ require("obsidian").setup { time_format = "%H:%M", -- A map for custom variables, the key should be the variable and the value a function substitutions = {}, + + -- A map for configuring unique directories and paths for specific templates + customizations = {}, }, -- Sets how you follow URLs diff --git a/lua/obsidian/commands/new_from_template.lua b/lua/obsidian/commands/new_from_template.lua index 05f0f361..4666aeda 100644 --- a/lua/obsidian/commands/new_from_template.lua +++ b/lua/obsidian/commands/new_from_template.lua @@ -1,5 +1,6 @@ local util = require "obsidian.util" local log = require "obsidian.log" +local templates = require "obsidian.templates" ---@param client obsidian.Client ---@param data CommandArgs @@ -14,30 +15,40 @@ return function(client, data) local template = data.fargs[#data.fargs] if title ~= nil and template ~= nil then + templates.load_template_customizations(template, client) local note = client:create_note { title = title, template = template, no_write = false } client:open_note(note, { sync = true }) + templates.restore_client_configurations(client) return end - if title == nil or title == "" then - title = util.input("Enter title or path (optional): ", { completion = "file" }) - if not title then - log.warn "Aborted" - return - elseif title == "" then - title = nil - end - end - picker:find_templates { callback = function(name) + templates.load_template_customizations(name, client) + if title == nil or title == "" then + -- Must use pcall in case of KeyboardInterrupt + -- We cannot place `title` where `safe_title` is because it would be redeclaring it + local success, safe_title = pcall(util.input, "Enter title or path (optional): ", { completion = "file" }) + title = safe_title + if not success or not safe_title then + log.warn "Aborted" + templates.restore_client_configurations(client) + return + elseif safe_title == "" then + title = nil + end + end + if name == nil or name == "" then log.warn "Aborted" + templates.restore_client_configurations(client) return end + ---@type obsidian.Note local note = client:create_note { title = title, template = name, no_write = false } client:open_note(note, { sync = false }) + templates.restore_client_configurations(client) end, } end diff --git a/lua/obsidian/config.lua b/lua/obsidian/config.lua index 882bb0ac..0d251ddc 100644 --- a/lua/obsidian/config.lua +++ b/lua/obsidian/config.lua @@ -452,6 +452,7 @@ end ---@field date_format string|? ---@field time_format string|? ---@field substitutions table|? +---@field customizations table config.TemplateOpts = {} --- Get defaults. @@ -463,9 +464,15 @@ config.TemplateOpts.default = function() date_format = nil, time_format = nil, substitutions = {}, + customizations = {}, } end +---@class obsidian.config.CustomTemplateOpts +--- +---@field dir string|? +---@field note_id_func function|? + ---@class obsidian.config.UIOpts --- ---@field enable boolean diff --git a/lua/obsidian/templates.lua b/lua/obsidian/templates.lua index d7e9f3e7..d4c4287a 100644 --- a/lua/obsidian/templates.lua +++ b/lua/obsidian/templates.lua @@ -1,8 +1,10 @@ local Path = require "obsidian.path" local Note = require "obsidian.note" local util = require "obsidian.util" +local config = require "obsidian.config" local M = {} +local restore_client_key = "__restore_client_key" --- Resolve a template name to a path. --- @@ -197,4 +199,58 @@ M.insert_template = function(opts) return Note.from_buffer(buf) end +--- Loads the client with customizations for a template identified by `template_name` if present +--- +--- @param template_name string The template name +--- @param client obsidian.Client The client +M.load_template_customizations = function(template_name, client) + local success, template_path = pcall(resolve_template, template_name, client) + + if not success then + return + end + + --- @type obsidian.config.CustomTemplateOpts|? + local customization = nil + + -- Check if the configuration has a custom key for this template + for template_key, template_config in pairs(client.opts.templates.customizations) do + if template_key:lower() == template_path.stem:lower() then + customization = template_config + break + end + end + + if not customization then + return + end + + local restore_values = { + dir = client.opts.notes_subdir, + note_id_func = client.opts.note_id_func, + new_notes_location = client.opts.new_notes_location, + } + + client[restore_client_key] = restore_values + client.opts.notes_subdir = customization.dir + client.opts.note_id_func = customization.note_id_func + client.opts.new_notes_location = config.NewNotesLocation.notes_subdir + return nil +end + +--- Restores the client's configuration if saved previously (during `load_template_customizations`), does nothing otherwise +--- @param client obsidian.Client The client +M.restore_client_configurations = function(client) + --- @type unknown + local restore_values = rawget(client, restore_client_key) + + if not restore_values then + return + end + + client.opts.notes_subdir = restore_values.dir + client.opts.note_id_func = restore_values.note_id_func + client.opts.new_notes_location = restore_values.new_notes_location +end + return M diff --git a/lua/test_utils.lua b/lua/test_utils.lua new file mode 100644 index 00000000..d4ddf7ba --- /dev/null +++ b/lua/test_utils.lua @@ -0,0 +1,29 @@ +local M = {} + +local obsidian = require "obsidian" + +---Get a client in a temporary directory. +---@param testid string A unique ID for the tests which prevents directory collision +---@param templates_dir string The template directory +---@return obsidian.Client +M.get_tmp_client = function(testid, templates_dir) + templates_dir = templates_dir or "templates" + + local tmpdir = "tmp-vault-" .. testid + + vim.loop.fs_mkdir(tmpdir, 448) -- <-- octal representation of 700 (RWX) + vim.loop.fs_mkdir(tmpdir .. "/" .. templates_dir, 448) + + local client = obsidian.new_from_dir(tmpdir) + client.opts.templates.folder = "templates" + return client +end + +--- Clean up a client, removing any resources created during `get_tmp_client` +--- @param client obsidian.Client The Client +M.cleanup_tmp_client = function(client) + local path = client.dir:resolve() + vim.fs.rm(tostring(path), { recursive = true, force = true }) +end + +return M diff --git a/test/obsidian/commands/new_from_template_spec.lua b/test/obsidian/commands/new_from_template_spec.lua new file mode 100644 index 00000000..2ae0453b --- /dev/null +++ b/test/obsidian/commands/new_from_template_spec.lua @@ -0,0 +1,83 @@ +local get_tmp_client = require("test_utils").get_tmp_client +local cleanup_tmp_client = require("test_utils").cleanup_tmp_client +local new_from_template = require "obsidian.commands.new_from_template" +local spy = require "luassert.spy" + +local templates_dir = "templates" + +--- @type obsidian.config.CustomTemplateOpts +local zettelConfig = { + dir = "31 Atomic", + note_id_func = function(title) + return "31.01 - " .. title + end, +} + +describe("new_from_template", function() + --- @type obsidian.Client + local client = nil + + before_each(function() + client = get_tmp_client("new_from_template", templates_dir) + vim.loop.fs_mkdir(tostring(client.dir) .. "/" .. zettelConfig.dir, 448) + client.picker = function(_) + return { + find_templates = function(_, opts) + opts.callback() + end, + } + end + client.opts.templates.customizations = { + Zettel = zettelConfig, + } + client.opts.new_notes_location = "notes_subdir" + end) + + after_each(function() + if client then + cleanup_tmp_client(client) + end + end) + + it("should always try to load and restore template configurations", function() + -- Arrange + local templates = require "obsidian.templates" + client:create_note { dir = templates_dir, id = "zettel" } + spy.on(templates, "load_template_customizations") + spy.on(templates, "restore_client_configurations") + + -- Act + ---@diagnostic disable-next-line: missing-fields + new_from_template(client, { fargs = { "Special Title", "zettel" } }) + + -- Assert + assert.spy(templates.load_template_customizations).was.called() + assert.spy(templates.restore_client_configurations).was.called() + end) + + it("should place matched templates in the custom directory", function() + -- Arrange + client:create_note { dir = templates_dir, id = "zettel" } + local expectedDir = client.dir / zettelConfig.dir + local title = "The Big Bang" + local id = zettelConfig.note_id_func(title) + local expected = string.format("%s/%s.md", expectedDir, id) + client.picker = function(_) + return { + find_templates = function(_, opts) + opts.callback "zettel" + end, + } + end + + -- Act + + ---@diagnostic disable-next-line: missing-fields + new_from_template(client, { fargs = { title, "zettel" } }) + local f = io.open(expected, "r") + + -- Assert + assert.truthy(f) + io.close(f) + end) +end) diff --git a/tests/test_templates.lua b/tests/test_templates.lua index 0d3079a4..e9bd9ad9 100644 --- a/tests/test_templates.lua +++ b/tests/test_templates.lua @@ -1,48 +1,128 @@ -local obsidian = require "obsidian" -local Path = require "obsidian.path" +local get_tmp_client = require("test_utils").get_tmp_client +local cleanup_tmp_client = require("test_utils").cleanup_tmp_client local Note = require "obsidian.note" local templates = require "obsidian.templates" +local NewNotesLocation = require("obsidian.config").NewNotesLocation ----Get a client in a temporary directory. ---- ----@return obsidian.Client -local tmp_client = function() - -- This gives us a tmp file name, but we really want a directory. - -- So we delete that file immediately. - local tmpname = os.tmpname() - os.remove(tmpname) - - local dir = Path:new(tmpname .. "-obsidian/") - dir:mkdir { parents = true } - - return obsidian.new_from_dir(tostring(dir)) -end - -describe("templates.substitute_template_variables()", function() - it("should substitute built-in variables", function() - local client = tmp_client() - local text = "today is {{date}} and the title of the note is {{title}}" - MiniTest.expect.equality( - string.format("today is %s and the title of the note is %s", os.date "%Y-%m-%d", "FOO"), - templates.substitute_template_variables(text, client, Note.new("FOO", { "FOO" }, {})) - ) +local templates_dir = "templates" + +--- @type obsidian.config.CustomTemplateOpts +local zettelConfig = { + dir = "/custom/path/to/zettels", + note_id_func = function() + return "hummus" + end, +} + +describe("template", function() + --- @type obsidian.Client|? + local client = nil + + before_each(function() + client = get_tmp_client("templates", templates_dir) end) - it("should substitute custom variables", function() - local client = tmp_client() - client.opts.templates.substitutions = { - weekday = function() - return "Monday" - end, - } - local text = "today is {{weekday}}" - MiniTest.expect.equality( - "today is Monday", - templates.substitute_template_variables(text, client, Note.new("foo", {}, {})) - ) - - -- Make sure the client opts has not been modified. - MiniTest.expect.equality(1, vim.tbl_count(client.opts.templates.substitutions)) - MiniTest.expect.equality("function", type(client.opts.templates.substitutions.weekday)) + after_each(function() + if client then + cleanup_tmp_client(client) + end + end) + + describe("templates.load_template_customizations()", function() + before_each(function() + client.opts.templates.customizations = { + Zettel = zettelConfig, + } + end) + + after_each(function() + client.opts.templates.customizations = nil + end) + + it("should not load customizations for non-existant templates", function() + local old_id_func = client.opts.note_id_func + + templates.load_template_customizations("zettel", client) + + MiniTest.expect.equality(client.opts.notes_subdir, nil) + MiniTest.expect.no_equality(zettelConfig.note_id_func, client.opts.note_id_func) + MiniTest.expect.equality(old_id_func, client.opts.note_id_func) + end) + + it("should load customizations for existing template", function() + client:create_note { dir = templates_dir, id = "zettel" } + + templates.load_template_customizations("zettel", client) + + MiniTest.expect.equality(zettelConfig.dir, client.opts.notes_subdir) + MiniTest.expect.equality(zettelConfig.note_id_func, client.opts.note_id_func) + end) + + it("should load customizations case-insensitively if template exists", function() + client:create_note { dir = templates_dir, id = "zettel" } + + templates.load_template_customizations("zettel", client) + + MiniTest.expect.equality(zettelConfig.dir, client.opts.notes_subdir) + MiniTest.expect.equality(zettelConfig.note_id_func, client.opts.note_id_func) + end) + end) + + describe("templates.restore_client_configurations()", function() + it("should do nothing if no configuration is cached", function() + local old_id_func = client.opts.note_id_func + local notes_subdir = client.opts.notes_subdir + + templates.restore_client_configurations(client) + + MiniTest.expect.equality(old_id_func, client.opts.note_id_func) + MiniTest.expect.equality(notes_subdir, client.opts.notes_subdir) + end) + + it("should reload client configuration after successfully loading previously", function() + client:create_note { dir = templates_dir, id = "zettel" } + local old_id_func = client.opts.note_id_func + local notes_subdir = client.opts.notes_subdir + client.opts.templates.customizations = { + Zettel = zettelConfig, + } + + templates.load_template_customizations("zettel", client) + MiniTest.expect.equality(zettelConfig.dir, client.opts.notes_subdir) + MiniTest.expect.equality(NewNotesLocation.notes_subdir, client.opts.new_notes_location) + MiniTest.expect.equality(zettelConfig.note_id_func, client.opts.note_id_func) + + templates.restore_client_configurations(client) + + MiniTest.expect.equality(old_id_func, client.opts.note_id_func) + MiniTest.expect.equality(NewNotesLocation.current_dir, client.opts.new_notes_location) + MiniTest.expect.equality(notes_subdir, client.opts.notes_subdir) + end) + end) + + describe("templates.substitute_template_variables()", function() + it("should substitute built-in variables", function() + local text = "today is {{date}} and the title of the note is {{title}}" + MiniTest.expect.equality( + string.format("today is %s and the title of the note is %s", os.date "%Y-%m-%d", "FOO"), + templates.substitute_template_variables(text, client, Note.new("FOO", { "FOO" }, {})) + ) + end) + + it("should substitute custom variables", function() + client.opts.templates.substitutions = { + weekday = function() + return "Monday" + end, + } + local text = "today is {{weekday}}" + MiniTest.expect.equality( + "today is Monday", + templates.substitute_template_variables(text, client, Note.new("foo", {}, {})) + ) + + MiniTest.expect.equality(1, vim.tbl_count(client.opts.templates.substitutions)) + MiniTest.expect.equality("function", type(client.opts.templates.substitutions.weekday)) + end) end) end)