summaryrefslogtreecommitdiffstats
path: root/lua/conform/runner.lua
diff options
context:
space:
mode:
authorSteven Arcangeli <stevearc@stevearc.com>2023-08-28 18:28:07 -0700
committerSteven Arcangeli <stevearc@stevearc.com>2023-08-28 18:28:07 -0700
commitcddd536e087a9fd3d2c9ea5b0a44e46c7b4b54c2 (patch)
tree70f6868440596ae90b7f451379c3abfa5678849c /lua/conform/runner.lua
parent69c4495ab5ad3c07c3a4f3c2bcac2f070718b4cb (diff)
feat: range formatting
Should work the same as vim.lsp.buf.format(). Additionally, range formatting is supported for *any* formatter. If the formatter doesn't have native support for ranges, conform will do its best to only apply the diffs that affect that range.
Diffstat (limited to 'lua/conform/runner.lua')
-rw-r--r--lua/conform/runner.lua82
1 files changed, 57 insertions, 25 deletions
diff --git a/lua/conform/runner.lua b/lua/conform/runner.lua
index 7ccaec6..843dae1 100644
--- a/lua/conform/runner.lua
+++ b/lua/conform/runner.lua
@@ -5,35 +5,51 @@ local uv = vim.uv or vim.loop
local M = {}
---@param ctx conform.Context
----@param config conform.StaticFormatterConfig
+---@param config conform.FormatterConfig
M.build_cmd = function(ctx, config)
local command = config.command
if type(command) == "function" then
command = command(ctx)
end
local cmd = { command }
- if config.args then
- local args = config.args
+ local args = {}
+ if ctx.range and config.range_args then
+ ---@cast ctx conform.RangeContext
+ args = config.range_args(ctx)
+ elseif config.args then
if type(config.args) == "function" then
args = config.args(ctx)
+ else
+ ---@diagnostic disable-next-line: cast-local-type
+ args = config.args
end
- ---@cast args string[]
- for _, v in ipairs(args) do
- if v == "$FILENAME" then
- v = ctx.filename
- elseif v == "$DIRNAME" then
- v = ctx.dirname
- end
- table.insert(cmd, v)
+ end
+
+ ---@diagnostic disable-next-line: param-type-mismatch
+ for _, v in ipairs(args) do
+ if v == "$FILENAME" then
+ v = ctx.filename
+ elseif v == "$DIRNAME" then
+ v = ctx.dirname
end
+ table.insert(cmd, v)
end
return cmd
end
+---@param range? conform.Range
+---@param start_a integer
+---@param end_a integer
+local function indices_in_range(range, start_a, end_a)
+ return not range or (start_a <= range["end"][1] and range["start"][1] <= end_a)
+end
+
---@param bufnr integer
---@param original_lines string[]
---@param new_lines string[]
-local function apply_format(bufnr, original_lines, new_lines)
+---@param range? conform.Range
+---@param only_apply_range boolean
+local function apply_format(bufnr, original_lines, new_lines, range, only_apply_range)
local original_text = table.concat(original_lines, "\n")
-- Trim off the final newline from the formatted text because that is baked in to
-- the vim lines representation
@@ -68,18 +84,21 @@ local function apply_format(bufnr, original_lines, new_lines)
count_b = count_b + 1
end
local replacement = util.tbl_slice(new_lines, start_b, start_b + count_b - 1)
- vim.api.nvim_buf_set_lines(bufnr, start_a - 1, start_a - 1 + count_a, true, replacement)
+ local end_a = start_a + count_a
+ if not only_apply_range or indices_in_range(range, start_a, end_a) then
+ vim.api.nvim_buf_set_lines(bufnr, start_a - 1, end_a - 1, true, replacement)
+ end
end
end
---@param bufnr integer
---@param formatter conform.FormatterInfo
+---@param config conform.FormatterConfig
+---@param ctx conform.Context
---@param input_lines string[]
---@param callback fun(err?: string, output?: string[])
----@return integer
-local function run_formatter(bufnr, formatter, input_lines, callback)
- local config = assert(require("conform").get_formatter_config(formatter.name))
- local ctx = M.build_context(bufnr, config)
+---@return integer job_id
+local function run_formatter(bufnr, formatter, config, ctx, input_lines, callback)
local cmd = M.build_cmd(ctx, config)
local cwd = nil
if config.cwd then
@@ -159,9 +178,10 @@ local function run_formatter(bufnr, formatter, input_lines, callback)
end
---@param bufnr integer
----@param config conform.StaticFormatterConfig
+---@param config conform.FormatterConfig
+---@param range? conform.Range
---@return conform.Context
-M.build_context = function(bufnr, config)
+M.build_context = function(bufnr, config, range)
if bufnr == 0 then
bufnr = vim.api.nvim_get_current_buf()
end
@@ -193,13 +213,15 @@ M.build_context = function(bufnr, config)
buf = bufnr,
filename = filename,
dirname = dirname,
+ range = range,
}
end
---@param bufnr integer
---@param formatters conform.FormatterInfo[]
+---@param range? conform.Range
---@param callback? fun(err?: string)
-M.format_async = function(bufnr, formatters, callback)
+M.format_async = function(bufnr, formatters, range, callback)
if bufnr == 0 then
bufnr = vim.api.nvim_get_current_buf()
end
@@ -207,6 +229,7 @@ M.format_async = function(bufnr, formatters, callback)
local changedtick = vim.b[bufnr].changedtick
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local input_lines = original_lines
+ local all_support_range_formatting = true
-- kill previous jobs for buffer
local prev_jid = vim.b[bufnr].conform_jid
@@ -221,7 +244,7 @@ M.format_async = function(bufnr, formatters, callback)
if not formatter then
-- discard formatting if buffer has changed
if vim.b[bufnr].changedtick == changedtick then
- apply_format(bufnr, original_lines, input_lines)
+ apply_format(bufnr, original_lines, input_lines, range, not all_support_range_formatting)
else
log.info(
"Async formatter discarding changes for %s: concurrent modification",
@@ -235,8 +258,10 @@ M.format_async = function(bufnr, formatters, callback)
end
idx = idx + 1
+ local config = assert(require("conform").get_formatter_config(formatter.name))
+ local ctx = M.build_context(bufnr, config, range)
local jid
- jid = run_formatter(bufnr, formatter, input_lines, function(err, output)
+ jid = run_formatter(bufnr, formatter, config, ctx, input_lines, function(err, output)
if err then
-- Only log the error if the job wasn't canceled
if vim.api.nvim_buf_is_valid(bufnr) and jid == vim.b[bufnr].conform_jid then
@@ -250,6 +275,7 @@ M.format_async = function(bufnr, formatters, callback)
input_lines = output
run_next_formatter()
end)
+ all_support_range_formatting = all_support_range_formatting and config.range_args ~= nil
end
run_next_formatter()
end
@@ -258,7 +284,8 @@ end
---@param formatters conform.FormatterInfo[]
---@param timeout_ms integer
---@param quiet boolean
-M.format_sync = function(bufnr, formatters, timeout_ms, quiet)
+---@param range? conform.Range
+M.format_sync = function(bufnr, formatters, timeout_ms, quiet, range)
if bufnr == 0 then
bufnr = vim.api.nvim_get_current_buf()
end
@@ -274,6 +301,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet)
end
end
+ local all_support_range_formatting = true
for _, formatter in ipairs(formatters) do
local remaining = timeout_ms - (uv.hrtime() / 1e6 - start)
if remaining <= 0 then
@@ -286,13 +314,16 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet)
end
local done = false
local result = nil
- run_formatter(bufnr, formatter, input_lines, function(err, output)
+ local config = assert(require("conform").get_formatter_config(formatter.name))
+ local ctx = M.build_context(bufnr, config, range)
+ local jid = run_formatter(bufnr, formatter, config, ctx, input_lines, function(err, output)
if err then
log.error(err)
end
done = true
result = output
end)
+ all_support_range_formatting = all_support_range_formatting and config.range_args ~= nil
local wait_result, wait_reason = vim.wait(remaining, function()
return done
@@ -306,6 +337,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet)
vim.notify(string.format("Formatter '%s' timed out", formatter.name), vim.log.levels.WARN)
end
end
+ vim.fn.jobstop(jid)
return
end
@@ -317,7 +349,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, quiet)
end
local final_result = input_lines
- apply_format(bufnr, original_lines, final_result)
+ apply_format(bufnr, original_lines, final_result, range, not all_support_range_formatting)
end
return M