aboutsummaryrefslogtreecommitdiffstats
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/conform/formatters/injected.lua98
-rw-r--r--lua/conform/init.lua80
-rw-r--r--lua/conform/runner.lua84
3 files changed, 242 insertions, 20 deletions
diff --git a/lua/conform/formatters/injected.lua b/lua/conform/formatters/injected.lua
new file mode 100644
index 0000000..c25f2f9
--- /dev/null
+++ b/lua/conform/formatters/injected.lua
@@ -0,0 +1,98 @@
+---@param range? conform.Range
+---@param start_lnum integer
+---@param end_lnum integer
+---@return boolean
+local function in_range(range, start_lnum, end_lnum)
+ return not range or (start_lnum <= range["end"][1] and range["start"][1] <= end_lnum)
+end
+
+---@type conform.FileLuaFormatterConfig
+return {
+ meta = {
+ url = "lua/conform/formatters/injected.lua",
+ description = "Format treesitter injected languages.",
+ },
+ condition = function(ctx)
+ local ok = pcall(vim.treesitter.get_parser, ctx.buf)
+ return ok
+ end,
+ format = function(ctx, lines, callback)
+ local conform = require("conform")
+ local util = require("conform.util")
+ local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
+ if not ok then
+ callback("No treesitter parser for buffer")
+ return
+ end
+ local root_lang = parser:lang()
+ local regions = {}
+ for lang, child_lang in pairs(parser:children()) do
+ local formatter_names = conform.formatters_by_ft[lang]
+ if formatter_names and lang ~= root_lang then
+ for _, tree in ipairs(child_lang:trees()) do
+ local root = tree:root()
+ local start_lnum = root:start() + 1
+ local end_lnum = root:end_()
+ if start_lnum <= end_lnum and in_range(ctx.range, start_lnum, end_lnum) then
+ table.insert(regions, { lang, start_lnum, end_lnum })
+ end
+ end
+ end
+ end
+ -- Sort from largest start_lnum to smallest
+ table.sort(regions, function(a, b)
+ return a[2] > b[2]
+ end)
+
+ local replacements = {}
+ local format_error = nil
+
+ local function apply_format_results()
+ if format_error then
+ callback(format_error)
+ return
+ end
+
+ local formatted_lines = vim.deepcopy(lines)
+ for _, replacement in ipairs(replacements) do
+ local start_lnum, end_lnum, new_lines = unpack(replacement)
+ for _ = start_lnum, end_lnum do
+ table.remove(formatted_lines, start_lnum)
+ end
+ for i = #new_lines, 1, -1 do
+ table.insert(formatted_lines, start_lnum, new_lines[i])
+ end
+ end
+ callback(nil, formatted_lines)
+ end
+
+ local num_format = 0
+ local formatter_cb = function(err, idx, start_lnum, end_lnum, new_lines)
+ if err then
+ format_error = err
+ else
+ replacements[idx] = { start_lnum, end_lnum, new_lines }
+ end
+ num_format = num_format - 1
+ if num_format == 0 then
+ apply_format_results()
+ end
+ end
+ local last_start_lnum = #lines + 1
+ for _, region in ipairs(regions) do
+ local lang, start_lnum, end_lnum = unpack(region)
+ -- Ignore regions that overlap (contain) other regions
+ if end_lnum < last_start_lnum then
+ num_format = num_format + 1
+ last_start_lnum = start_lnum
+ local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
+ local formatter_names = conform.formatters_by_ft[lang]
+ local format_opts = { async = true, bufnr = ctx.buf, quiet = true }
+ local idx = num_format
+ conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
+ formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
+ end)
+ end
+ end
+ end,
+}
diff --git a/lua/conform/init.lua b/lua/conform/init.lua
index 271e33d..a2b5a4c 100644
--- a/lua/conform/init.lua
+++ b/lua/conform/init.lua
@@ -7,7 +7,7 @@ local M = {}
---@field available boolean
---@field available_msg? string
----@class (exact) conform.FormatterConfig
+---@class (exact) conform.JobFormatterConfig
---@field command string|fun(ctx: conform.Context): string
---@field args? string|string[]|fun(ctx: conform.Context): string|string[]
---@field range_args? fun(ctx: conform.RangeContext): string|string[]
@@ -18,9 +18,18 @@ local M = {}
---@field exit_codes? integer[] Exit codes that indicate success (default {0})
---@field env? table<string, any>|fun(ctx: conform.Context): table<string, any>
----@class (exact) conform.FileFormatterConfig : conform.FormatterConfig
+---@class (exact) conform.LuaFormatterConfig
+---@field format fun(ctx: conform.Context, lines: string[], callback: fun(err: nil|string, new_lines: nil|string[]))
+---@field condition? fun(ctx: conform.Context): boolean
+
+---@class (exact) conform.FileLuaFormatterConfig : conform.LuaFormatterConfig
---@field meta conform.FormatterMeta
+---@class (exact) conform.FileFormatterConfig : conform.JobFormatterConfig
+---@field meta conform.FormatterMeta
+
+---@alias conform.FormatterConfig conform.JobFormatterConfig|conform.LuaFormatterConfig
+
---@class (exact) conform.FormatterMeta
---@field url string
---@field description string
@@ -415,6 +424,56 @@ M.format = function(opts, callback)
end
end
+---Process lines with formatters
+---@private
+---@param formatter_names string[]
+---@param lines string[]
+---@param opts? table
+--- timeout_ms nil|integer Time in milliseconds to block for formatting. Defaults to 1000. No effect if async = true.
+--- bufnr nil|integer use this as the working buffer (default 0)
+--- async nil|boolean If true the method won't block. Defaults to false. If the buffer is modified before the formatter completes, the formatting will be discarded.
+--- quiet nil|boolean Don't show any notifications for warnings or failures. Defaults to false.
+---@param callback? fun(err: nil|string, lines: nil|string[]) Called once formatting has completed
+---@return nil|string error Only present if async = false
+---@return nil|string[] new_lines Only present if async = false
+M.format_lines = function(formatter_names, lines, opts, callback)
+ ---@type {timeout_ms: integer, bufnr: integer, async: boolean, quiet: boolean}
+ opts = vim.tbl_extend("keep", opts or {}, {
+ timeout_ms = 1000,
+ bufnr = 0,
+ async = false,
+ quiet = false,
+ })
+ callback = callback or function(_err, _lines) end
+ local log = require("conform.log")
+ local runner = require("conform.runner")
+ local formatters = resolve_formatters(formatter_names, opts.bufnr, not opts.quiet)
+ if vim.tbl_isempty(formatters) then
+ callback(nil, lines)
+ return
+ end
+
+ ---@param err? conform.Error
+ ---@param new_lines? string[]
+ local function handle_err(err, new_lines)
+ if err then
+ local level = runner.level_for_code(err.code)
+ log.log(level, err.message)
+ end
+ local err_message = err and err.message
+ callback(err_message, new_lines)
+ end
+
+ if opts.async then
+ runner.format_lines_async(opts.bufnr, formatters, nil, lines, handle_err)
+ else
+ local err, new_lines =
+ runner.format_lines_sync(opts.bufnr, formatters, opts.timeout_ms, nil, lines)
+ handle_err(err, new_lines)
+ return err and err.message, new_lines
+ end
+end
+
---Retrieve the available formatters for a buffer
---@param bufnr? integer
---@return conform.FormatterInfo[]
@@ -508,13 +567,26 @@ M.get_formatter_info = function(formatter, bufnr)
local ctx = require("conform.runner").build_context(bufnr, config)
+ local available = true
+ local available_msg = nil
+ if config.format then
+ if config.condition and not config.condition(ctx) then
+ available = false
+ available_msg = "Condition failed"
+ end
+ return {
+ name = formatter,
+ command = formatter,
+ available = available,
+ available_msg = available_msg,
+ }
+ end
+
local command = config.command
if type(command) == "function" then
command = command(ctx)
end
- local available = true
- local available_msg = nil
if vim.fn.executable(command) == 0 then
available = false
available_msg = "Command not found"
diff --git a/lua/conform/runner.lua b/lua/conform/runner.lua
index 33429e9..7b3e468 100644
--- a/lua/conform/runner.lua
+++ b/lua/conform/runner.lua
@@ -47,7 +47,7 @@ M.is_execution_error = function(code)
end
---@param ctx conform.Context
----@param config conform.FormatterConfig
+---@param config conform.JobFormatterConfig
---@return string|string[]
M.build_cmd = function(ctx, config)
local command = config.command
@@ -255,8 +255,14 @@ local last_run_errored = {}
---@param ctx conform.Context
---@param input_lines string[]
---@param callback fun(err?: conform.Error, output?: string[])
----@return integer job_id
+---@return integer? job_id
local function run_formatter(bufnr, formatter, config, ctx, input_lines, callback)
+ if config.format then
+ ---@cast config conform.LuaFormatterConfig
+ config.format(ctx, input_lines, callback)
+ return
+ end
+ ---@cast config conform.JobFormatterConfig
local cmd = M.build_cmd(ctx, config)
local cwd = nil
if config.cwd then
@@ -440,11 +446,6 @@ M.format_async = function(bufnr, formatters, range, callback)
if bufnr == 0 then
bufnr = vim.api.nvim_get_current_buf()
end
- local idx = 1
- local changedtick = vim.b[bufnr].changedtick
- local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
- local input_lines = original_lines
- local all_support_range_formatting = true
-- kill previous jobs for buffer
local prev_jid = vim.b[bufnr].conform_jid
@@ -454,9 +455,18 @@ M.format_async = function(bufnr, formatters, range, callback)
end
end
- local function run_next_formatter()
- local formatter = formatters[idx]
- if not formatter then
+ local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
+ local changedtick = vim.b[bufnr].changedtick
+ M.format_lines_async(
+ bufnr,
+ formatters,
+ range,
+ original_lines,
+ function(err, output_lines, all_support_range_formatting)
+ if err then
+ return callback(err)
+ end
+ assert(output_lines)
local new_changedtick = vim.b[bufnr].changedtick
-- changedtick gets set to -1 when vim is exiting. We have an autocmd that should store it in
-- last_changedtick before it is set to -1.
@@ -473,9 +483,29 @@ M.format_async = function(bufnr, formatters, range, callback)
),
})
else
- M.apply_format(bufnr, original_lines, input_lines, range, not all_support_range_formatting)
+ M.apply_format(bufnr, original_lines, output_lines, range, not all_support_range_formatting)
callback()
end
+ end
+ )
+end
+
+---@param bufnr integer
+---@param formatters conform.FormatterInfo[]
+---@param range? conform.Range
+---@param input_lines string[]
+---@param callback fun(err?: conform.Error, output_lines?: string[], all_support_range_formatting?: boolean)
+M.format_lines_async = function(bufnr, formatters, range, input_lines, callback)
+ if bufnr == 0 then
+ bufnr = vim.api.nvim_get_current_buf()
+ end
+ local idx = 1
+ local all_support_range_formatting = true
+
+ local function run_next_formatter()
+ local formatter = formatters[idx]
+ if not formatter then
+ callback(nil, input_lines, all_support_range_formatting)
return
end
idx = idx + 1
@@ -503,9 +533,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range)
if bufnr == 0 then
bufnr = vim.api.nvim_get_current_buf()
end
- local start = uv.hrtime() / 1e6
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
- local input_lines = original_lines
-- kill previous jobs for buffer
local prev_jid = vim.b[bufnr].conform_jid
@@ -515,6 +543,29 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range)
end
end
+ local err, final_result, all_support_range_formatting =
+ M.format_lines_sync(bufnr, formatters, timeout_ms, range, original_lines)
+ if err then
+ return err
+ end
+ assert(final_result)
+
+ M.apply_format(bufnr, original_lines, final_result, range, not all_support_range_formatting)
+end
+
+---@param bufnr integer
+---@param formatters conform.FormatterInfo[]
+---@param timeout_ms integer
+---@param range? conform.Range
+---@return conform.Error? error
+---@return string[]? output_lines
+---@return boolean? all_support_range_formatting
+M.format_lines_sync = function(bufnr, formatters, timeout_ms, range, input_lines)
+ if bufnr == 0 then
+ bufnr = vim.api.nvim_get_current_buf()
+ end
+ local start = uv.hrtime() / 1e6
+
local all_support_range_formatting = true
for _, formatter in ipairs(formatters) do
local remaining = timeout_ms - (uv.hrtime() / 1e6 - start)
@@ -541,7 +592,9 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range)
end, 5)
if not wait_result then
- vim.fn.jobstop(jid)
+ if jid then
+ vim.fn.jobstop(jid)
+ end
if wait_reason == -1 then
return {
code = M.ERROR_CODE.TIMEOUT,
@@ -562,8 +615,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range)
input_lines = result
end
- local final_result = input_lines
- M.apply_format(bufnr, original_lines, final_result, range, not all_support_range_formatting)
+ return nil, input_lines, all_support_range_formatting
end
return M