1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
|
---Test whether the line span [start_lnum, end_lnum] intersects `range`.
---When no range is given the whole buffer is in scope, so every span matches.
---@param range? conform.Range
---@param start_lnum integer
---@param end_lnum integer
---@return boolean
local function in_range(range, start_lnum, end_lnum)
  if not range then
    return true
  end
  -- The span starts after the range ends: no overlap.
  if start_lnum > range["end"][1] then
    return false
  end
  -- Otherwise they overlap unless the range starts after the span ends.
  return range["start"][1] <= end_lnum
end
---Compute the leading indentation shared by every non-empty line.
---Returns the longest common leading-whitespace prefix of the non-empty lines. Returns nil
---when there are no non-empty lines, when any non-empty line is unindented, or when the lines
---share no common prefix (e.g. one is indented with a tab and another with spaces).
---@param lines string[]
---@param language? string The language of the buffer
---@return string?
local function get_indent(lines, language)
  -- Handle markdown code blocks that are inside blockquotes
  -- > ```lua
  -- > local x = 1
  -- > ```
  local pattern = language == "markdown" and "^>?%s*" or "^%s*"
  local indent = nil
  for _, line in ipairs(lines) do
    if line ~= "" then
      local whitespace = line:match(pattern)
      if whitespace == "" then
        return nil
      end
      if not indent then
        indent = whitespace
      else
        -- Shrink `indent` to the prefix it shares with this line. Comparing lengths alone is
        -- not enough: "\t" is shorter than "    " but is not a prefix of it, and callers strip
        -- the indent by length, so a non-shared prefix would delete the wrong characters.
        local limit = math.min(#indent, #whitespace)
        local common = 0
        while common < limit and indent:byte(common + 1) == whitespace:byte(common + 1) do
          common = common + 1
        end
        if common == 0 then
          return nil
        end
        indent = indent:sub(1, common)
      end
    end
  end
  return indent
end
---Strip the indentation common to all non-empty lines, modifying `lines` in place.
---Empty lines are left as they are.
---@param lines string[]
---@param language? string The language of the buffer
---@return string? indent The prefix that was removed, or nil if nothing was removed
local function remove_indent(lines, language)
  local prefix = get_indent(lines, language)
  if not prefix then
    return
  end
  -- Every non-empty line begins with `prefix`, so dropping that many bytes removes it.
  local first_kept = #prefix + 1
  for idx = 1, #lines do
    local text = lines[idx]
    if text ~= "" then
      lines[idx] = text:sub(first_kept)
    end
  end
  return prefix
end
---Prefix every non-empty line with `indentation`, modifying `lines` in place.
---Does nothing when either argument is nil.
---@param lines string[]?
---@param indentation string?
local function apply_indent(lines, indentation)
  if lines and indentation then
    for idx = 1, #lines do
      -- Blank lines stay blank rather than gaining trailing whitespace.
      if lines[idx] ~= "" then
        lines[idx] = indentation .. lines[idx]
      end
    end
  end
end
---A span of text written in an injected language.
---@class LangRange
---@field [1] string language
---@field [2] integer start lnum
---@field [3] integer start col
---@field [4] integer end lnum
---@field [5] integer end col
---Append `range` to `ranges`. If it has the same language as the last entry and begins exactly
---where that entry ends, extend the last entry instead of adding a new one.
---@param ranges LangRange[]
---@param range LangRange
local function accum_range(ranges, range)
  local prev = ranges[#ranges]
  local contiguous = prev
    and prev[1] == range[1]
    and prev[4] == range[2]
    and prev[5] == range[3]
  if contiguous then
    prev[4], prev[5] = range[4], range[5]
  else
    ranges[#ranges + 1] = range
  end
end
---@class (exact) conform.InjectedFormatterOptions
---@field ignore_errors boolean
---@field lang_to_ext table<string, string>
---@field lang_to_formatters table<string, conform.FiletypeFormatter>

---Formatter that locates treesitter language injections in a buffer, formats each injected
---region with the formatters configured for its language, and splices the results back in.
---@type conform.FileLuaFormatterConfig
return {
  meta = {
    url = "doc/advanced_topics.md#injected-language-formatting-code-blocks",
    description = "Format treesitter injected languages.",
  },
  options = {
    -- Set to true to ignore errors
    ignore_errors = false,
    -- Map of treesitter language to file extension
    -- A temporary file name with this extension will be generated during formatting
    -- because some formatters care about the filename.
    lang_to_ext = {
      bash = "sh",
      c_sharp = "cs",
      elixir = "exs",
      javascript = "js",
      julia = "jl",
      latex = "tex",
      markdown = "md",
      python = "py",
      ruby = "rb",
      rust = "rs",
      teal = "tl",
      typescript = "ts",
    },
    -- Map of treesitter language to formatters to use
    -- (defaults to the value from formatters_by_ft)
    lang_to_formatters = {},
  },
  ---Enable only when the buffer has a treesitter parser with an injection query.
  condition = function(self, ctx)
    local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
    -- get_parser may return nil instead of raising when no parser is available (newer Neovim),
    -- so check the value as well as the pcall status before indexing it.
    -- Require Neovim 0.9 because the treesitter API has changed significantly
    ---@diagnostic disable-next-line: invisible
    return ok and parser ~= nil and parser._injection_query and vim.fn.has("nvim-0.9") == 1
  end,
  ---Format every injected region of `lines` and pass the combined result to `callback`.
  format = function(self, ctx, lines, callback)
    local conform = require("conform")
    local errors = require("conform.errors")
    local log = require("conform.log")
    local util = require("conform.util")
    -- Need to add a trailing newline; some parsers need this.
    -- For example, if a markdown code block ends at the end of the file, a trailing newline is
    -- required otherwise the ``` will be grabbed as part of the injected block
    local text = table.concat(lines, "\n") .. "\n"
    local buf_lang = vim.treesitter.language.get_lang(vim.bo[ctx.buf].filetype)
    local ok, parser = pcall(vim.treesitter.get_string_parser, text, buf_lang)
    if not ok or not parser then
      callback("No treesitter parser for buffer")
      return
    end
    ---@type conform.InjectedFormatterOptions
    local options = self.options

    ---@param lang string
    ---@return nil|conform.FiletypeFormatter
    local function get_formatters(lang)
      return options.lang_to_formatters[lang] or conform.formatters_by_ft[lang]
    end

    --- Disable diagnostic to pass the typecheck github action
    --- This is available on nightly, but not on stable
    --- Stable doesn't have any parameters, so it's safe to always pass `false`
    ---@diagnostic disable-next-line: redundant-parameter
    parser:parse(false)
    local root_lang = parser:lang()

    -- Collect the regions of each child language that has formatters configured. Touching
    -- pieces of the same language are merged by accum_range.
    ---@type LangRange[]
    local regions = {}
    for lang, lang_tree in pairs(parser:children()) do
      -- The formatter lookup depends only on the language, so do it once per child tree.
      if lang ~= root_lang and get_formatters(lang) ~= nil then
        for _, ranges in ipairs(lang_tree:included_regions()) do
          for _, region in ipairs(ranges) do
            -- The types are wrong. included_regions should be Range[][] not integer[][]
            ---@diagnostic disable-next-line: param-type-mismatch
            local start_row, start_col, _, end_row, end_col, _ = unpack(region)
            accum_range(regions, { lang, start_row + 1, start_col, end_row + 1, end_col })
          end
        end
      end
    end
    if ctx.range then
      regions = vim.tbl_filter(function(region)
        return in_range(ctx.range, region[2], region[4])
      end, regions)
    end
    -- Sort from largest start_lnum to smallest, so replacements (which are stored in this order)
    -- are applied bottom-up without shifting the line numbers of those not yet applied.
    table.sort(regions, function(a, b)
      return a[2] > b[2]
    end)
    log.trace("Injected formatter regions %s", regions)

    -- replacements[idx] is either { start_lnum, start_col, end_lnum, end_col, new_lines } or the
    -- error value reported for that region (recognized later by its `code` field).
    local replacements = {}
    local format_error = nil

    ---Splice the successful replacements into a copy of `lines` and invoke `callback`.
    local function apply_format_results()
      if format_error then
        -- Find all of the conform errors in the replacements table and remove them
        local i = 1
        while i <= #replacements do
          if replacements[i].code then
            table.remove(replacements, i)
          else
            i = i + 1
          end
        end
        if options.ignore_errors then
          format_error = nil
        end
      end
      local formatted_lines = vim.deepcopy(lines)
      for _, replacement in ipairs(replacements) do
        local start_lnum, start_col, end_lnum, end_col, new_lines = unpack(replacement)
        -- Keep the text before the region on its first line and after it on its last line
        local prefix = formatted_lines[start_lnum]:sub(1, start_col)
        local suffix = formatted_lines[end_lnum]:sub(end_col + 1)
        new_lines[1] = prefix .. new_lines[1]
        new_lines[#new_lines] = new_lines[#new_lines] .. suffix
        for _ = start_lnum, end_lnum do
          table.remove(formatted_lines, start_lnum)
        end
        for i = #new_lines, 1, -1 do
          table.insert(formatted_lines, start_lnum, new_lines[i])
        end
      end
      callback(format_error, formatted_lines)
    end

    local num_format = 0
    local tmp_bufs = {}
    ---Record the result for one region. When the last outstanding region reports back, delete
    ---the temporary buffers and apply all results.
    local formatter_cb = function(err, idx, region, input_lines, new_lines)
      if err then
        format_error = errors.coalesce(format_error, err)
        replacements[idx] = err
      else
        -- If the original lines started/ended with a newline, preserve that newline.
        -- Many formatters will trim them, but they're important for the document structure.
        if input_lines[1] == "" and new_lines[1] ~= "" then
          table.insert(new_lines, 1, "")
        end
        if input_lines[#input_lines] == "" and new_lines[#new_lines] ~= "" then
          table.insert(new_lines, "")
        end
        -- A formatter may return an empty list of lines. Keep one empty line so the prefix/suffix
        -- splicing in apply_format_results does not concatenate onto nil.
        if #new_lines == 0 then
          table.insert(new_lines, "")
        end
        replacements[idx] = { region[2], region[3], region[4], region[5], new_lines }
      end
      num_format = num_format - 1
      if num_format == 0 then
        for buf in pairs(tmp_bufs) do
          vim.api.nvim_buf_delete(buf, { force = true })
        end
        apply_format_results()
      end
    end

    -- Regions are visited bottom-up. A region is formatted only when it ends above the start of
    -- the previously accepted region, so regions overlapping (containing) an accepted one are
    -- skipped. Starting at #lines + 1 also skips any region ending past the last input line.
    local last_start_lnum = #lines + 1
    for i, region in ipairs(regions) do
      local lang = region[1]
      local start_lnum = region[2]
      local start_col = region[3]
      local end_lnum = region[4]
      local end_col = region[5]
      -- Ignore regions that overlap (contain) other regions
      if end_lnum < last_start_lnum then
        num_format = num_format + 1
        last_start_lnum = start_lnum
        -- Cut out exactly the region's text: truncate the last line first, then drop the
        -- leading columns of the first line (they may be the same line).
        local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
        input_lines[#input_lines] = input_lines[#input_lines]:sub(1, end_col)
        if start_col > 0 then
          input_lines[1] = input_lines[1]:sub(start_col + 1)
        end
        local ft_formatters = assert(get_formatters(lang))
        ---@type string[]
        local formatter_names
        if type(ft_formatters) == "function" then
          formatter_names = ft_formatters(ctx.buf)
        else
          local formatters = conform.resolve_formatters(ft_formatters, ctx.buf, false)
          formatter_names = vim.tbl_map(function(f)
            return f.name
          end, formatters)
        end
        local idx = num_format
        log.debug("Injected format %s:%d:%d: %s", lang, start_lnum, end_lnum, formatter_names)
        log.trace("Injected format lines %s", input_lines)
        local indent = remove_indent(input_lines, buf_lang)
        -- Create a temporary buffer. This is only needed because some formatters rely on the file
        -- extension to determine a run mode (see https://github.com/stevearc/conform.nvim/issues/194)
        -- This is using the language name as the file extension, but that is a reasonable
        -- approximation for now. We can add special cases as the need arises.
        local extension = options.lang_to_ext[lang] or lang
        local buf =
          vim.fn.bufadd(string.format("%s.%d.%s", vim.api.nvim_buf_get_name(ctx.buf), i, extension))
        -- Actually load the buffer to set the buffer context which is required by some formatters such as `filetype`
        vim.fn.bufload(buf)
        tmp_bufs[buf] = true
        local format_opts = { async = true, bufnr = buf, quiet = true }
        conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
          log.trace("Injected %s:%d:%d formatted lines %s", lang, start_lnum, end_lnum, new_lines)
          -- Preserve indentation in case the code block is indented
          apply_indent(new_lines, indent)
          vim.schedule_wrap(formatter_cb)(err, idx, region, input_lines, new_lines)
        end)
      end
    end
    if num_format == 0 then
      apply_format_results()
    end
  end,
  -- TODO this is kind of a hack. It's here to ensure all_support_range_formatting is set properly.
  -- Should figure out a better way to do this.
  range_args = true,
}
|