aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/conform/formatters/injected.lua118
-rw-r--r--lua/conform/fs.lua20
-rw-r--r--lua/conform/init.lua9
-rw-r--r--lua/conform/runner.lua1
4 files changed, 125 insertions, 23 deletions
diff --git a/lua/conform/formatters/injected.lua b/lua/conform/formatters/injected.lua
index 77a9c0d..363889e 100644
--- a/lua/conform/formatters/injected.lua
+++ b/lua/conform/formatters/injected.lua
@@ -60,8 +60,31 @@ local function apply_indent(lines, indentation)
end
end
+---@class LangRange
+---@field [1] string language
+---@field [2] integer start lnum
+---@field [3] integer start col
+---@field [4] integer end lnum
+---@field [5] integer end col
+
+---@param ranges LangRange[]
+---@param range LangRange
+local function accum_range(ranges, range)
+ local last_range = ranges[#ranges]
+ if last_range then
+ if last_range[1] == range[1] and last_range[4] == range[2] and last_range[5] == range[3] then
+ last_range[4] = range[4]
+ last_range[5] = range[5]
+ return
+ end
+ end
+ table.insert(ranges, range)
+end
+
---@class (exact) conform.InjectedFormatterOptions
---@field ignore_errors boolean
+---@field lang_to_ext table<string, string>
+---@field lang_to_formatters table<string, conform.FiletypeFormatter>
---@type conform.FileLuaFormatterConfig
return {
@@ -72,6 +95,26 @@ return {
options = {
-- Set to true to ignore errors
ignore_errors = false,
+ -- Map of treesitter language to file extension
+ -- A temporary file name with this extension will be generated during formatting
+ -- because some formatters care about the filename.
+ lang_to_ext = {
+ bash = "sh",
+ c_sharp = "cs",
+ elixir = "exs",
+ javascript = "js",
+ julia = "jl",
+ latex = "tex",
+ markdown = "md",
+ python = "py",
+ ruby = "rb",
+ rust = "rs",
+ teal = "tl",
+ typescript = "ts",
+ },
+ -- Map of treesitter language to formatters to use
+ -- (defaults to the value from formatters_by_ft)
+ lang_to_formatters = {},
},
condition = function(self, ctx)
local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
@@ -93,12 +136,20 @@ return {
end
---@type conform.InjectedFormatterOptions
local options = self.options
+
+ ---@param lang string
+ ---@return nil|conform.FiletypeFormatter
+ local function get_formatters(lang)
+ return options.lang_to_formatters[lang] or conform.formatters_by_ft[lang]
+ end
+
--- Disable diagnostic to pass the typecheck github action
--- This is available on nightly, but not on stable
--- Stable doesn't have any parameters, so it's safe to always pass `false`
---@diagnostic disable-next-line: redundant-parameter
parser:parse(false)
local root_lang = parser:lang()
+ ---@type LangRange[]
local regions = {}
for _, tree in pairs(parser:trees()) do
@@ -124,26 +175,26 @@ return {
do
---@diagnostic disable-next-line: invisible
local lang, combined, ranges = parser:_get_injection(match, metadata)
- local has_formatters = conform.formatters_by_ft[lang] ~= nil
- if lang and has_formatters and not combined and #ranges > 0 and lang ~= root_lang then
- local start_lnum
- local end_lnum
- -- Merge all of the ranges into a single range
+ if
+ lang
+ and get_formatters(lang) ~= nil
+ and not combined
+ and #ranges > 0
+ and lang ~= root_lang
+ then
for _, range in ipairs(ranges) do
- if not start_lnum or start_lnum > range[1] + 1 then
- start_lnum = range[1] + 1
- end
- if not end_lnum or end_lnum < range[4] then
- end_lnum = range[4]
- end
- end
- if in_range(ctx.range, start_lnum, end_lnum) then
- table.insert(regions, { lang, start_lnum, end_lnum })
+ accum_range(regions, { lang, range[1] + 1, range[2], range[4] + 1, range[5] })
end
end
end
end
+ if ctx.range then
+ regions = vim.tbl_filter(function(region)
+ return in_range(ctx.range, region[2], region[4])
+ end, regions)
+ end
+
-- Sort from largest start_lnum to smallest
table.sort(regions, function(a, b)
return a[2] > b[2]
@@ -171,7 +222,11 @@ return {
local formatted_lines = vim.deepcopy(lines)
for _, replacement in ipairs(replacements) do
- local start_lnum, end_lnum, new_lines = unpack(replacement)
+ local start_lnum, start_col, end_lnum, end_col, new_lines = unpack(replacement)
+ local prefix = formatted_lines[start_lnum]:sub(1, start_col)
+ local suffix = formatted_lines[end_lnum]:sub(end_col + 1)
+ new_lines[1] = prefix .. new_lines[1]
+ new_lines[#new_lines] = new_lines[#new_lines] .. suffix
for _ = start_lnum, end_lnum do
table.remove(formatted_lines, start_lnum)
end
@@ -184,12 +239,20 @@ return {
local num_format = 0
local tmp_bufs = {}
- local formatter_cb = function(err, idx, start_lnum, end_lnum, new_lines)
+ local formatter_cb = function(err, idx, region, input_lines, new_lines)
if err then
format_error = errors.coalesce(format_error, err)
replacements[idx] = err
else
- replacements[idx] = { start_lnum, end_lnum, new_lines }
+ -- If the original lines started/ended with a newline, preserve that newline.
+ -- Many formatters will trim them, but they're important for the document structure.
+ if input_lines[1] == "" and new_lines[1] ~= "" then
+ table.insert(new_lines, 1, "")
+ end
+ if input_lines[#input_lines] == "" and new_lines[#new_lines] ~= "" then
+ table.insert(new_lines, "")
+ end
+ replacements[idx] = { region[2], region[3], region[4], region[5], new_lines }
end
num_format = num_format - 1
if num_format == 0 then
@@ -200,14 +263,22 @@ return {
end
end
local last_start_lnum = #lines + 1
- for _, region in ipairs(regions) do
- local lang, start_lnum, end_lnum = unpack(region)
+ for i, region in ipairs(regions) do
+ local lang = region[1]
+ local start_lnum = region[2]
+ local start_col = region[3]
+ local end_lnum = region[4]
+ local end_col = region[5]
-- Ignore regions that overlap (contain) other regions
if end_lnum < last_start_lnum then
num_format = num_format + 1
last_start_lnum = start_lnum
local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
- local ft_formatters = conform.formatters_by_ft[lang]
+ input_lines[#input_lines] = input_lines[#input_lines]:sub(1, end_col)
+ if start_col > 0 then
+ input_lines[1] = input_lines[1]:sub(start_col + 1)
+ end
+ local ft_formatters = assert(get_formatters(lang))
---@type string[]
local formatter_names
if type(ft_formatters) == "function" then
@@ -226,15 +297,18 @@ return {
-- extension to determine a run mode (see https://github.com/stevearc/conform.nvim/issues/194)
-- This is using the language name as the file extension, but that is a reasonable
-- approximation for now. We can add special cases as the need arises.
- local buf = vim.fn.bufadd(string.format("%s.%s", vim.api.nvim_buf_get_name(ctx.buf), lang))
+ local extension = options.lang_to_ext[lang] or lang
+ local buf =
+ vim.fn.bufadd(string.format("%s.%d.%s", vim.api.nvim_buf_get_name(ctx.buf), i, extension))
-- Actually load the buffer to set the buffer context which is required by some formatters such as `filetype`
vim.fn.bufload(buf)
tmp_bufs[buf] = true
local format_opts = { async = true, bufnr = buf, quiet = true }
conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
+ log.trace("Injected %s:%d:%d formatted lines %s", lang, start_lnum, end_lnum, new_lines)
-- Preserve indentation in case the code block is indented
apply_indent(new_lines, indent)
- formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
+ vim.schedule_wrap(formatter_cb)(err, idx, region, input_lines, new_lines)
end)
end
end
diff --git a/lua/conform/fs.lua b/lua/conform/fs.lua
index d303dbd..c33a2dc 100644
--- a/lua/conform/fs.lua
+++ b/lua/conform/fs.lua
@@ -15,4 +15,24 @@ M.join = function(...)
return table.concat({ ... }, M.sep)
end
+---@param filepath string
+---@return boolean
+M.exists = function(filepath)
+ local stat = uv.fs_stat(filepath)
+ return stat ~= nil and stat.type ~= nil
+end
+
+---@param filepath string
+---@return string?
+M.read_file = function(filepath)
+ if not M.exists(filepath) then
+ return nil
+ end
+ local fd = assert(uv.fs_open(filepath, "r", 420)) -- 0644
+ local stat = assert(uv.fs_fstat(fd))
+ local content = uv.fs_read(fd, stat.size)
+ uv.fs_close(fd)
+ return content
+end
+
return M
diff --git a/lua/conform/init.lua b/lua/conform/init.lua
index 3824b4b..4fc35b5 100644
--- a/lua/conform/init.lua
+++ b/lua/conform/init.lua
@@ -36,6 +36,7 @@ local M = {}
---@field inherit? boolean
---@field command? string|fun(self: conform.FormatterConfig, ctx: conform.Context): string
---@field prepend_args? string|string[]|fun(self: conform.FormatterConfig, ctx: conform.Context): string|string[]
+---@field format? fun(self: conform.LuaFormatterConfig, ctx: conform.Context, lines: string[], callback: fun(err: nil|string, new_lines: nil|string[])) Mutually exclusive with command
---@field options? table
---@class (exact) conform.FormatterMeta
@@ -569,6 +570,12 @@ M.get_formatter_config = function(formatter, bufnr)
if type(override) == "function" then
override = override(bufnr)
end
+ if override and override.command and override.format then
+ local msg =
+ string.format("Formatter '%s' cannot define both 'command' and 'format' function", formatter)
+ vim.notify_once(msg, vim.log.levels.ERROR)
+ return nil
+ end
---@type nil|conform.FormatterConfig
local config = override
@@ -581,7 +588,7 @@ M.get_formatter_config = function(formatter, bufnr)
config = mod_config
end
elseif override then
- if override.command then
+ if override.command or override.format then
config = override
else
local msg = string.format(
diff --git a/lua/conform/runner.lua b/lua/conform/runner.lua
index aa64e19..62d158e 100644
--- a/lua/conform/runner.lua
+++ b/lua/conform/runner.lua
@@ -342,6 +342,7 @@ local function run_formatter(bufnr, formatter, config, ctx, input_lines, opts, c
end
log.debug("%s exited with code %d", formatter.name, code)
log.trace("Output lines: %s", output)
+ log.trace("%s stderr: %s", formatter.name, stderr)
callback(nil, output)
else
log.info("%s exited with code %d", formatter.name, code)