From f245cca8ad42c9d344b53a18c3fc1a3c6724c2d4 Mon Sep 17 00:00:00 2001 From: Steven Arcangeli <506791+stevearc@users.noreply.github.com> Date: Tue, 26 Dec 2023 06:38:00 -0800 Subject: fix(injected): handle inline injections (#251) --- lua/conform/formatters/injected.lua | 118 +++++++++++++++++++++++++++++------- lua/conform/fs.lua | 20 ++++++ lua/conform/init.lua | 9 ++- lua/conform/runner.lua | 1 + 4 files changed, 125 insertions(+), 23 deletions(-) (limited to 'lua') diff --git a/lua/conform/formatters/injected.lua b/lua/conform/formatters/injected.lua index 77a9c0d..363889e 100644 --- a/lua/conform/formatters/injected.lua +++ b/lua/conform/formatters/injected.lua @@ -60,8 +60,31 @@ local function apply_indent(lines, indentation) end end +---@class LangRange +---@field [1] string language +---@field [2] integer start lnum +---@field [3] integer start col +---@field [4] integer end lnum +---@field [5] integer end col + +---@param ranges LangRange[] +---@param range LangRange +local function accum_range(ranges, range) + local last_range = ranges[#ranges] + if last_range then + if last_range[1] == range[1] and last_range[4] == range[2] and last_range[5] == range[3] then + last_range[4] = range[4] + last_range[5] = range[5] + return + end + end + table.insert(ranges, range) +end + ---@class (exact) conform.InjectedFormatterOptions ---@field ignore_errors boolean +---@field lang_to_ext table +---@field lang_to_formatters table ---@type conform.FileLuaFormatterConfig return { @@ -72,6 +95,26 @@ return { options = { -- Set to true to ignore errors ignore_errors = false, + -- Map of treesitter language to file extension + -- A temporary file name with this extension will be generated during formatting + -- because some formatters care about the filename. + lang_to_ext = { + bash = "sh", + c_sharp = "cs", + elixir = "exs", + javascript = "js", + julia = "jl", + latex = "tex", + markdown = "md", + python = "py", + ruby = "rb", + rust = "rs", + teal = "tl", + typescript = "ts", + }, + -- Map of treesitter language to formatters to use + -- (defaults to the value from formatters_by_ft) + lang_to_formatters = {}, }, condition = function(self, ctx) local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf) @@ -93,12 +136,20 @@ return { end ---@type conform.InjectedFormatterOptions local options = self.options + + ---@param lang string + ---@return nil|conform.FiletypeFormatter + local function get_formatters(lang) + return options.lang_to_formatters[lang] or conform.formatters_by_ft[lang] + end + --- Disable diagnostic to pass the typecheck github action --- This is available on nightly, but not on stable --- Stable doesn't have any parameters, so it's safe to always pass `false` ---@diagnostic disable-next-line: redundant-parameter parser:parse(false) local root_lang = parser:lang() + ---@type LangRange[] local regions = {} for _, tree in pairs(parser:trees()) do @@ -124,26 +175,26 @@ return { do ---@diagnostic disable-next-line: invisible local lang, combined, ranges = parser:_get_injection(match, metadata) - local has_formatters = conform.formatters_by_ft[lang] ~= nil - if lang and has_formatters and not combined and #ranges > 0 and lang ~= root_lang then - local start_lnum - local end_lnum - -- Merge all of the ranges into a single range + if + lang + and get_formatters(lang) ~= nil + and not combined + and #ranges > 0 + and lang ~= root_lang + then for _, range in ipairs(ranges) do - if not start_lnum or start_lnum > range[1] + 1 then - start_lnum = range[1] + 1 - end - if not end_lnum or end_lnum < range[4] then - end_lnum = range[4] - end - end - if in_range(ctx.range, start_lnum, end_lnum) then - table.insert(regions, { lang, start_lnum, end_lnum }) + accum_range(regions, { lang, range[1] + 1, range[2], range[4] + 1, range[5] }) end end end end + if ctx.range then + regions = vim.tbl_filter(function(region) + return in_range(ctx.range, region[2], region[4]) + end, regions) + end + -- Sort from largest start_lnum to smallest table.sort(regions, function(a, b) return a[2] > b[2] @@ -171,7 +222,11 @@ return { local formatted_lines = vim.deepcopy(lines) for _, replacement in ipairs(replacements) do - local start_lnum, end_lnum, new_lines = unpack(replacement) + local start_lnum, start_col, end_lnum, end_col, new_lines = unpack(replacement) + local prefix = formatted_lines[start_lnum]:sub(1, start_col) + local suffix = formatted_lines[end_lnum]:sub(end_col + 1) + new_lines[1] = prefix .. new_lines[1] + new_lines[#new_lines] = new_lines[#new_lines] .. suffix for _ = start_lnum, end_lnum do table.remove(formatted_lines, start_lnum) end @@ -184,12 +239,20 @@ return { local num_format = 0 local tmp_bufs = {} - local formatter_cb = function(err, idx, start_lnum, end_lnum, new_lines) + local formatter_cb = function(err, idx, region, input_lines, new_lines) if err then format_error = errors.coalesce(format_error, err) replacements[idx] = err else - replacements[idx] = { start_lnum, end_lnum, new_lines } + -- If the original lines started/ended with a newline, preserve that newline. + -- Many formatters will trim them, but they're important for the document structure. + if input_lines[1] == "" and new_lines[1] ~= "" then + table.insert(new_lines, 1, "") + end + if input_lines[#input_lines] == "" and new_lines[#new_lines] ~= "" then + table.insert(new_lines, "") + end + replacements[idx] = { region[2], region[3], region[4], region[5], new_lines } end num_format = num_format - 1 if num_format == 0 then @@ -200,14 +263,22 @@ return { end end local last_start_lnum = #lines + 1 - for _, region in ipairs(regions) do - local lang, start_lnum, end_lnum = unpack(region) + for i, region in ipairs(regions) do + local lang = region[1] + local start_lnum = region[2] + local start_col = region[3] + local end_lnum = region[4] + local end_col = region[5] -- Ignore regions that overlap (contain) other regions if end_lnum < last_start_lnum then num_format = num_format + 1 last_start_lnum = start_lnum local input_lines = util.tbl_slice(lines, start_lnum, end_lnum) - local ft_formatters = conform.formatters_by_ft[lang] + input_lines[#input_lines] = input_lines[#input_lines]:sub(1, end_col) + if start_col > 0 then + input_lines[1] = input_lines[1]:sub(start_col + 1) + end + local ft_formatters = assert(get_formatters(lang)) ---@type string[] local formatter_names if type(ft_formatters) == "function" then @@ -226,15 +297,18 @@ return { -- extension to determine a run mode (see https://github.com/stevearc/conform.nvim/issues/194) -- This is using the language name as the file extension, but that is a reasonable -- approximation for now. We can add special cases as the need arises. - local buf = vim.fn.bufadd(string.format("%s.%s", vim.api.nvim_buf_get_name(ctx.buf), lang)) + local extension = options.lang_to_ext[lang] or lang + local buf = + vim.fn.bufadd(string.format("%s.%d.%s", vim.api.nvim_buf_get_name(ctx.buf), i, extension)) -- Actually load the buffer to set the buffer context which is required by some formatters such as `filetype` vim.fn.bufload(buf) tmp_bufs[buf] = true local format_opts = { async = true, bufnr = buf, quiet = true } conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines) + log.trace("Injected %s:%d:%d formatted lines %s", lang, start_lnum, end_lnum, new_lines) -- Preserve indentation in case the code block is indented apply_indent(new_lines, indent) - formatter_cb(err, idx, start_lnum, end_lnum, new_lines) + vim.schedule_wrap(formatter_cb)(err, idx, region, input_lines, new_lines) end) end end diff --git a/lua/conform/fs.lua b/lua/conform/fs.lua index d303dbd..c33a2dc 100644 --- a/lua/conform/fs.lua +++ b/lua/conform/fs.lua @@ -15,4 +15,24 @@ M.join = function(...) return table.concat({ ... }, M.sep) end +---@param filepath string +---@return boolean +M.exists = function(filepath) + local stat = uv.fs_stat(filepath) + return stat ~= nil and stat.type ~= nil +end + +---@param filepath string +---@return string? +M.read_file = function(filepath) + if not M.exists(filepath) then + return nil + end + local fd = assert(uv.fs_open(filepath, "r", 420)) -- 0644 + local stat = assert(uv.fs_fstat(fd)) + local content = uv.fs_read(fd, stat.size) + uv.fs_close(fd) + return content +end + return M diff --git a/lua/conform/init.lua b/lua/conform/init.lua index 3824b4b..4fc35b5 100644 --- a/lua/conform/init.lua +++ b/lua/conform/init.lua @@ -36,6 +36,7 @@ local M = {} ---@field inherit? boolean ---@field command? string|fun(self: conform.FormatterConfig, ctx: conform.Context): string ---@field prepend_args? string|string[]|fun(self: conform.FormatterConfig, ctx: conform.Context): string|string[] +---@field format? fun(self: conform.LuaFormatterConfig, ctx: conform.Context, lines: string[], callback: fun(err: nil|string, new_lines: nil|string[])) Mutually exclusive with command ---@field options? table ---@class (exact) conform.FormatterMeta @@ -569,6 +570,12 @@ M.get_formatter_config = function(formatter, bufnr) if type(override) == "function" then override = override(bufnr) end + if override and override.command and override.format then + local msg = + string.format("Formatter '%s' cannot define both 'command' and 'format' function", formatter) + vim.notify_once(msg, vim.log.levels.ERROR) + return nil + end ---@type nil|conform.FormatterConfig local config = override @@ -581,7 +588,7 @@ M.get_formatter_config = function(formatter, bufnr) config = mod_config end elseif override then - if override.command then + if override.command or override.format then config = override else local msg = string.format( diff --git a/lua/conform/runner.lua b/lua/conform/runner.lua index aa64e19..62d158e 100644 --- a/lua/conform/runner.lua +++ b/lua/conform/runner.lua @@ -342,6 +342,7 @@ local function run_formatter(bufnr, formatter, config, ctx, input_lines, opts, c end log.debug("%s exited with code %d", formatter.name, code) log.trace("Output lines: %s", output) + log.trace("%s stderr: %s", formatter.name, stderr) callback(nil, output) else log.info("%s exited with code %d", formatter.name, code) -- cgit v1.2.3-70-g09d2