diff options
author | Steven Arcangeli <506791+stevearc@users.noreply.github.com> | 2023-09-29 11:56:21 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-29 11:56:21 -0700 |
commit | a5526fb2ee963cf426ab6d6ba1f3eb82887b1c22 (patch) | |
tree | 23961a5cbc439b67efac3c1f2b22b1fb97411172 /lua/conform | |
parent | 388d6e2440bccded26d5e67ce6a7039c1953ae70 (diff) |
feat: format injected languages (#83)
Diffstat (limited to 'lua/conform')
-rw-r--r-- | lua/conform/formatters/injected.lua | 98 | ||||
-rw-r--r-- | lua/conform/init.lua | 80 | ||||
-rw-r--r-- | lua/conform/runner.lua | 84 |
3 files changed, 242 insertions, 20 deletions
diff --git a/lua/conform/formatters/injected.lua b/lua/conform/formatters/injected.lua new file mode 100644 index 0000000..c25f2f9 --- /dev/null +++ b/lua/conform/formatters/injected.lua @@ -0,0 +1,98 @@ +---@param range? conform.Range +---@param start_lnum integer +---@param end_lnum integer +---@return boolean +local function in_range(range, start_lnum, end_lnum) + return not range or (start_lnum <= range["end"][1] and range["start"][1] <= end_lnum) +end + +---@type conform.FileLuaFormatterConfig +return { + meta = { + url = "lua/conform/formatters/injected.lua", + description = "Format treesitter injected languages.", + }, + condition = function(ctx) + local ok = pcall(vim.treesitter.get_parser, ctx.buf) + return ok + end, + format = function(ctx, lines, callback) + local conform = require("conform") + local util = require("conform.util") + local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf) + if not ok then + callback("No treesitter parser for buffer") + return + end + local root_lang = parser:lang() + local regions = {} + for lang, child_lang in pairs(parser:children()) do + local formatter_names = conform.formatters_by_ft[lang] + if formatter_names and lang ~= root_lang then + for _, tree in ipairs(child_lang:trees()) do + local root = tree:root() + local start_lnum = root:start() + 1 + local end_lnum = root:end_() + if start_lnum <= end_lnum and in_range(ctx.range, start_lnum, end_lnum) then + table.insert(regions, { lang, start_lnum, end_lnum }) + end + end + end + end + -- Sort from largest start_lnum to smallest + table.sort(regions, function(a, b) + return a[2] > b[2] + end) + + local replacements = {} + local format_error = nil + + local function apply_format_results() + if format_error then + callback(format_error) + return + end + + local formatted_lines = vim.deepcopy(lines) + for _, replacement in ipairs(replacements) do + local start_lnum, end_lnum, new_lines = unpack(replacement) + for _ = start_lnum, end_lnum do + table.remove(formatted_lines, start_lnum) + end + for i = #new_lines, 1, -1 do + table.insert(formatted_lines, start_lnum, new_lines[i]) + end + end + callback(nil, formatted_lines) + end + + local num_format = 0 + local formatter_cb = function(err, idx, start_lnum, end_lnum, new_lines) + if err then + format_error = err + else + replacements[idx] = { start_lnum, end_lnum, new_lines } + end + num_format = num_format - 1 + if num_format == 0 then + apply_format_results() + end + end + local last_start_lnum = #lines + 1 + for _, region in ipairs(regions) do + local lang, start_lnum, end_lnum = unpack(region) + -- Ignore regions that overlap (contain) other regions + if end_lnum < last_start_lnum then + num_format = num_format + 1 + last_start_lnum = start_lnum + local input_lines = util.tbl_slice(lines, start_lnum, end_lnum) + local formatter_names = conform.formatters_by_ft[lang] + local format_opts = { async = true, bufnr = ctx.buf, quiet = true } + local idx = num_format + conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines) + formatter_cb(err, idx, start_lnum, end_lnum, new_lines) + end) + end + end + end, +} diff --git a/lua/conform/init.lua b/lua/conform/init.lua index 271e33d..a2b5a4c 100644 --- a/lua/conform/init.lua +++ b/lua/conform/init.lua @@ -7,7 +7,7 @@ local M = {} ---@field available boolean ---@field available_msg? string ----@class (exact) conform.FormatterConfig +---@class (exact) conform.JobFormatterConfig ---@field command string|fun(ctx: conform.Context): string ---@field args? string|string[]|fun(ctx: conform.Context): string|string[] ---@field range_args? fun(ctx: conform.RangeContext): string|string[] @@ -18,9 +18,18 @@ local M = {} ---@field exit_codes? integer[] Exit codes that indicate success (default {0}) ---@field env? table<string, any>|fun(ctx: conform.Context): table<string, any> ----@class (exact) conform.FileFormatterConfig : conform.FormatterConfig +---@class (exact) conform.LuaFormatterConfig +---@field format fun(ctx: conform.Context, lines: string[], callback: fun(err: nil|string, new_lines: nil|string[])) +---@field condition? fun(ctx: conform.Context): boolean + +---@class (exact) conform.FileLuaFormatterConfig : conform.LuaFormatterConfig ---@field meta conform.FormatterMeta +---@class (exact) conform.FileFormatterConfig : conform.JobFormatterConfig +---@field meta conform.FormatterMeta + +---@alias conform.FormatterConfig conform.JobFormatterConfig|conform.LuaFormatterConfig + ---@class (exact) conform.FormatterMeta ---@field url string ---@field description string @@ -415,6 +424,56 @@ M.format = function(opts, callback) end end +---Process lines with formatters +---@private +---@param formatter_names string[] +---@param lines string[] +---@param opts? table +--- timeout_ms nil|integer Time in milliseconds to block for formatting. Defaults to 1000. No effect if async = true. +--- bufnr nil|integer use this as the working buffer (default 0) +--- async nil|boolean If true the method won't block. Defaults to false. If the buffer is modified before the formatter completes, the formatting will be discarded. +--- quiet nil|boolean Don't show any notifications for warnings or failures. Defaults to false. +---@param callback? fun(err: nil|string, lines: nil|string[]) Called once formatting has completed +---@return nil|string error Only present if async = false +---@return nil|string[] new_lines Only present if async = false +M.format_lines = function(formatter_names, lines, opts, callback) + ---@type {timeout_ms: integer, bufnr: integer, async: boolean, quiet: boolean} + opts = vim.tbl_extend("keep", opts or {}, { + timeout_ms = 1000, + bufnr = 0, + async = false, + quiet = false, + }) + callback = callback or function(_err, _lines) end + local log = require("conform.log") + local runner = require("conform.runner") + local formatters = resolve_formatters(formatter_names, opts.bufnr, not opts.quiet) + if vim.tbl_isempty(formatters) then + callback(nil, lines) + return + end + + ---@param err? conform.Error + ---@param new_lines? string[] + local function handle_err(err, new_lines) + if err then + local level = runner.level_for_code(err.code) + log.log(level, err.message) + end + local err_message = err and err.message + callback(err_message, new_lines) + end + + if opts.async then + runner.format_lines_async(opts.bufnr, formatters, nil, lines, handle_err) + else + local err, new_lines = + runner.format_lines_sync(opts.bufnr, formatters, opts.timeout_ms, nil, lines) + handle_err(err, new_lines) + return err and err.message, new_lines + end +end + ---Retrieve the available formatters for a buffer ---@param bufnr? integer ---@return conform.FormatterInfo[] @@ -508,13 +567,26 @@ M.get_formatter_info = function(formatter, bufnr) local ctx = require("conform.runner").build_context(bufnr, config) + local available = true + local available_msg = nil + if config.format then + if config.condition and not config.condition(ctx) then + available = false + available_msg = "Condition failed" + end + return { + name = formatter, + command = formatter, + available = available, + available_msg = available_msg, + } + end + local command = config.command if type(command) == "function" then command = command(ctx) end - local available = true - local available_msg = nil if vim.fn.executable(command) == 0 then available = false available_msg = "Command not found" diff --git a/lua/conform/runner.lua b/lua/conform/runner.lua index 33429e9..7b3e468 100644 --- a/lua/conform/runner.lua +++ b/lua/conform/runner.lua @@ -47,7 +47,7 @@ M.is_execution_error = function(code) end ---@param ctx conform.Context ----@param config conform.FormatterConfig +---@param config conform.JobFormatterConfig ---@return string|string[] M.build_cmd = function(ctx, config) local command = config.command @@ -255,8 +255,14 @@ local last_run_errored = {} ---@param ctx conform.Context ---@param input_lines string[] ---@param callback fun(err?: conform.Error, output?: string[]) ----@return integer job_id +---@return integer? job_id local function run_formatter(bufnr, formatter, config, ctx, input_lines, callback) + if config.format then + ---@cast config conform.LuaFormatterConfig + config.format(ctx, input_lines, callback) + return + end + ---@cast config conform.JobFormatterConfig local cmd = M.build_cmd(ctx, config) local cwd = nil if config.cwd then @@ -440,11 +446,6 @@ M.format_async = function(bufnr, formatters, range, callback) if bufnr == 0 then bufnr = vim.api.nvim_get_current_buf() end - local idx = 1 - local changedtick = vim.b[bufnr].changedtick - local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local input_lines = original_lines - local all_support_range_formatting = true -- kill previous jobs for buffer local prev_jid = vim.b[bufnr].conform_jid @@ -454,9 +455,18 @@ M.format_async = function(bufnr, formatters, range, callback) end end - local function run_next_formatter() - local formatter = formatters[idx] - if not formatter then + local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local changedtick = vim.b[bufnr].changedtick + M.format_lines_async( + bufnr, + formatters, + range, + original_lines, + function(err, output_lines, all_support_range_formatting) + if err then + return callback(err) + end + assert(output_lines) local new_changedtick = vim.b[bufnr].changedtick -- changedtick gets set to -1 when vim is exiting. We have an autocmd that should store it in -- last_changedtick before it is set to -1. @@ -473,9 +483,29 @@ M.format_async = function(bufnr, formatters, range, callback) ), }) else - M.apply_format(bufnr, original_lines, input_lines, range, not all_support_range_formatting) + M.apply_format(bufnr, original_lines, output_lines, range, not all_support_range_formatting) callback() end + end + ) +end + +---@param bufnr integer +---@param formatters conform.FormatterInfo[] +---@param range? conform.Range +---@param input_lines string[] +---@param callback fun(err?: conform.Error, output_lines?: string[], all_support_range_formatting?: boolean) +M.format_lines_async = function(bufnr, formatters, range, input_lines, callback) + if bufnr == 0 then + bufnr = vim.api.nvim_get_current_buf() + end + local idx = 1 + local all_support_range_formatting = true + + local function run_next_formatter() + local formatter = formatters[idx] + if not formatter then + callback(nil, input_lines, all_support_range_formatting) return end idx = idx + 1 @@ -503,9 +533,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range) if bufnr == 0 then bufnr = vim.api.nvim_get_current_buf() end - local start = uv.hrtime() / 1e6 local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local input_lines = original_lines -- kill previous jobs for buffer local prev_jid = vim.b[bufnr].conform_jid @@ -515,6 +543,29 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range) end end + local err, final_result, all_support_range_formatting = + M.format_lines_sync(bufnr, formatters, timeout_ms, range, original_lines) + if err then + return err + end + assert(final_result) + + M.apply_format(bufnr, original_lines, final_result, range, not all_support_range_formatting) +end + +---@param bufnr integer +---@param formatters conform.FormatterInfo[] +---@param timeout_ms integer +---@param range? conform.Range +---@return conform.Error? error +---@return string[]? output_lines +---@return boolean? all_support_range_formatting +M.format_lines_sync = function(bufnr, formatters, timeout_ms, range, input_lines) + if bufnr == 0 then + bufnr = vim.api.nvim_get_current_buf() + end + local start = uv.hrtime() / 1e6 + local all_support_range_formatting = true for _, formatter in ipairs(formatters) do local remaining = timeout_ms - (uv.hrtime() / 1e6 - start) @@ -541,7 +592,9 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range) end, 5) if not wait_result then - vim.fn.jobstop(jid) + if jid then + vim.fn.jobstop(jid) + end if wait_reason == -1 then return { code = M.ERROR_CODE.TIMEOUT, @@ -562,8 +615,7 @@ M.format_sync = function(bufnr, formatters, timeout_ms, range) input_lines = result end - local final_result = input_lines - M.apply_format(bufnr, original_lines, final_result, range, not all_support_range_formatting) + return nil, input_lines, all_support_range_formatting end return M |