From cddd536e087a9fd3d2c9ea5b0a44e46c7b4b54c2 Mon Sep 17 00:00:00 2001 From: Steven Arcangeli Date: Mon, 28 Aug 2023 18:28:07 -0700 Subject: feat: range formatting Should work the same as vim.lsp.buf.format(). Additionally, range formatting is supported for *any* formatter. If the formatter doesn't have native support for ranges, conform will do its best to only apply the diffs that affect that range. --- lua/conform/runner.lua | 82 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 25 deletions(-) (limited to 'lua/conform/runner.lua') diff --git a/lua/conform/runner.lua b/lua/conform/runner.lua index 7ccaec6..843dae1 100644 --- a/lua/conform/runner.lua +++ b/lua/conform/runner.lua @@ -5,35 +5,51 @@ local uv = vim.uv or vim.loop local M = {} ---@param ctx conform.Context ----@param config conform.StaticFormatterConfig +---@param config conform.FormatterConfig M.build_cmd = function(ctx, config) local command = config.command if type(command) == "function" then command = command(ctx) end local cmd = { command } - if config.args then - local args = config.args + local args = {} + if ctx.range and config.range_args then + ---@cast ctx conform.RangeContext + args = config.range_args(ctx) + elseif config.args then if type(config.args) == "function" then args = config.args(ctx) + else + ---@diagnostic disable-next-line: cast-local-type + args = config.args end - ---@cast args string[] - for _, v in ipairs(args) do - if v == "$FILENAME" then - v = ctx.filename - elseif v == "$DIRNAME" then - v = ctx.dirname - end - table.insert(cmd, v) + end + + ---@diagnostic disable-next-line: param-type-mismatch + for _, v in ipairs(args) do + if v == "$FILENAME" then + v = ctx.filename + elseif v == "$DIRNAME" then + v = ctx.dirname end + table.insert(cmd, v) end return cmd end +---@param range? conform.Range +---@param start_a integer +---@param end_a integer +local function indices_in_range(range, start_a, end_a) + return not range or (start_a <= range["end"][1] and range["start"][1] <= end_a) +end + ---@param bufnr integer ---@param original_lines string[] ---@param new_lines string[] -local function apply_format(bufnr, original_lines, new_lines) +---@param range? conform.Range +---@param only_apply_range boolean +local function apply_format(bufnr, original_lines, new_lines, range, only_apply_range) local original_text = table.concat(original_lines, "\n") -- Trim off the final newline from the formatted text because that is baked in to -- the vim lines representation @@ -68,18 +84,21 @@ local function apply_format(bufnr, original_lines, new_lines) count_b = count_b + 1 end local replacement = util.tbl_slice(new_lines, start_b, start_b + count_b - 1) - vim.api.nvim_buf_set_lines(bufnr, start_a - 1, start_a - 1 + count_a, true, replacement) + local end_a = start_a + count_a + if not only_apply_range or indices_in_range(range, start_a, end_a) then + vim.api.nvim_buf_set_lines(bufnr, start_a - 1, end_a - 1, true, replacement) + end end end ---@param bufnr integer ---@param formatter conform.FormatterInfo +---@param config conform.FormatterConfig +---@param ctx conform.Context ---@param input_lines string[] ---@param callback fun(err?: string, output?: string[]) ----@return integer -local function run_formatter(bufnr, formatter, input_lines, callback) - local config = assert(require("conform").get_formatter_config(formatter.name)) - local ctx = M.build_context(bufnr, config) +---@return integer job_id +local function run_formatter(bufnr, formatter, config, ctx, input_lines, callback) local cmd = M.build_cmd(ctx, config) local cwd = nil if config.cwd then @@ -159,9 +178,10 @@ local function run_formatter(bufnr, formatter, input_lines, callback) end ---@param bufnr integer ----@param config conform.StaticFormatterConfig +---@param config conform.FormatterConfig +---@param range? conform.Range ---@return conform.Context -M.build_context = function(bufnr, config) +M.build_context = function(bufnr, config, range) if bufnr == 0 then bufnr = vim.api.nvim_get_current_buf() end @@ -193,13 +213,15 @@ M.build_context = function(bufnr, config) buf = bufnr, filename = filename, dirname = dirname, + range = range, } end ---@param bufnr integer ---@param formatters conform.FormatterInfo[] +---@param range? conform.Range ---@param callback? fun(err?: string) -M.format_async = function(bufnr, formatters, callback) +M.format_async = function(bufnr, formatters, range, callback) if bufnr == 0 then bufnr = vim.api.nvim_get_current_buf() end @@ -207,6 +229,7 @@ M.format_async = function(bufnr, formatters, callback) local changedtick = vim.b[bufnr].changedtick local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local input_lines = original_lines + local all_support_range_formatting = true -- kill previous jobs for buffer local prev_jid = vim.b[bufnr].conform_jid @@ -221,7 +244,7 @@ M.format_async = function(bufnr, formatters, callback) if not formatter then -- discard formatting if buffer has changed if vim.b[bufnr].changedtick == changedtick then - apply_format(bufnr, original_lines, input_lines) + apply_format(bufnr, original_lines, input_lines, range, not all_support_range_formatting) else log.info( "Async formatter discarding changes for %s: concurrent modification", @@ -235,8 +258,10 @@ M.format_async = function(bufnr, formatters, callback) end idx = idx + 1 + local config = assert(require("conform").get_formatter_config(formatter.name)) + local ctx = M.build_context(bufnr, config, range) local jid - jid = run_formatter(bufnr, formatter, input_lines, function(err, output) + jid = run_formatter(bufnr, formatter, config, ctx, input_lines, function(err, output) if err then -- Only log the error if the job wasn't canceled if vim.api.nvim_buf_is_valid(bufnr) and jid == vim.b[bufnr].conform_jid then @@ -250,6 +275,7 @@ M.format_async = function(bufnr, formatters, callback) input_lines = output run_next_formatter() end) + all_support_range_formatting = all_support_range_formatting and config.range_args ~= nil end run_next_formatter() end @@ -258,7 +284,8 @@ end ---@param formatters conform.FormatterInfo[] ---@param timeout_ms integer ---@param quiet boolean -M.format_sync = function(bufnr, formatters, timeout_ms, quiet) +---@param range? conform.Range +M.format_sync = function(bufnr, formatters, timeout_ms, quiet, range) if bufnr == 0 then bufnr = vim.api.nvim_get_current_buf() end @@ -274,6 +301,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet) end end + local all_support_range_formatting = true for _, formatter in ipairs(formatters) do local remaining = timeout_ms - (uv.hrtime() / 1e6 - start) if remaining <= 0 then @@ -286,13 +314,16 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet) end local done = false local result = nil - run_formatter(bufnr, formatter, input_lines, function(err, output) + local config = assert(require("conform").get_formatter_config(formatter.name)) + local ctx = M.build_context(bufnr, config, range) + local jid = run_formatter(bufnr, formatter, config, ctx, input_lines, function(err, output) if err then log.error(err) end done = true result = output end) + all_support_range_formatting = all_support_range_formatting and config.range_args ~= nil local wait_result, wait_reason = vim.wait(remaining, function() return done @@ -306,6 +337,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet) vim.notify(string.format("Formatter '%s' timed out", formatter.name), vim.log.levels.WARN) end end + vim.fn.jobstop(jid) return end @@ -317,7 +349,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet) end local final_result = input_lines - apply_format(bufnr, original_lines, final_result) + apply_format(bufnr, original_lines, final_result, range, not all_support_range_formatting) end return M -- cgit v1.2.3-70-g09d2