aboutsummaryrefslogtreecommitdiffstats
path: root/lua/conform/formatters/injected.lua
blob: 4990e05720470aae0374d4232d963ea8322cd848 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
---Test whether the line span [start_lnum, end_lnum] intersects the requested range.
---Only the line component (index 1) of the range endpoints is consulted.
---@param range? conform.Range
---@param start_lnum integer
---@param end_lnum integer
---@return boolean
local function in_range(range, start_lnum, end_lnum)
  -- Without a range there is nothing to restrict: every span qualifies
  if not range then
    return true
  end
  -- The span begins after the range has already ended
  if not (start_lnum <= range["end"][1]) then
    return false
  end
  -- Otherwise they intersect as long as the range begins no later than the span ends
  return range["start"][1] <= end_lnum
end

---Find the shortest run of leading whitespace among the non-empty lines.
---Returns nil as soon as a non-empty line with no leading whitespace is seen, and also
---when there are no non-empty lines at all.
---@param lines string[]
---@return string?
local function get_indent(lines)
  local shortest = nil
  for _, text in ipairs(lines) do
    -- Empty lines carry no indentation information
    if text ~= "" then
      local leading = text:match("^%s*")
      if #leading == 0 then
        -- One unindented line means the block as a whole is not indented
        return nil
      end
      if shortest == nil or #leading < #shortest then
        shortest = leading
      end
    end
  end
  return shortest
end

---Prefix the formatted lines with the indentation found in the original lines.
---Mutates `new_lines` in place. Nothing happens when `new_lines` is absent, when the
---original lines have no indent, or when every non-empty formatted line already
---starts with whitespace.
---@param original_lines string[]
---@param new_lines? string[]
local function apply_indent(original_lines, new_lines)
  if new_lines then
    local prefix = get_indent(original_lines)
    local existing = get_indent(new_lines)
    if prefix and not existing then
      for idx, text in ipairs(new_lines) do
        -- Leave blank lines blank rather than adding trailing whitespace
        if text ~= "" then
          new_lines[idx] = prefix .. text
        end
      end
    end
  end
end

---Formatter that formats regions of the buffer written in injected treesitter
---languages, using the formatters configured for each injected language.
---@type conform.FileLuaFormatterConfig
return {
  meta = {
    url = "lua/conform/formatters/injected.lua",
    description = "Format treesitter injected languages.",
  },
  -- Usable only when a treesitter parser can be obtained for the buffer
  condition = function(ctx)
    local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
    -- Treat both a raised error and a nil result as "no parser"
    return ok and parser ~= nil
  end,
  ---@param ctx table Context; `ctx.buf` and `ctx.range` are read
  ---@param lines string[] Current buffer lines (not mutated)
  ---@param callback fun(err: any, new_lines?: string[]) Invoked exactly once with the result
  format = function(ctx, lines, callback)
    local conform = require("conform")
    local util = require("conform.util")
    local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
    if not ok or not parser then
      callback("No treesitter parser for buffer")
      return
    end

    -- Collect { lang, start_lnum, end_lnum } for every injected tree that has
    -- formatters configured and intersects the requested range
    local root_lang = parser:lang()
    local regions = {}
    for lang, child_lang in pairs(parser:children()) do
      local formatter_names = conform.formatters_by_ft[lang]
      if formatter_names and lang ~= root_lang then
        for _, tree in ipairs(child_lang:trees()) do
          local root = tree:root()
          -- The start row is shifted by one to index into `lines`; the end row is used
          -- as-is as the last line of the region.
          -- NOTE(review): presumably relies on 0-based rows with the node ending at the
          -- start of the following row — confirm.
          local start_lnum = root:start() + 1
          local end_lnum = root:end_()
          if start_lnum <= end_lnum and in_range(ctx.range, start_lnum, end_lnum) then
            table.insert(regions, { lang, start_lnum, end_lnum })
          end
        end
      end
    end
    -- Sort from largest start_lnum to smallest so that replacements can be applied
    -- bottom-up without invalidating the line numbers of regions above them
    table.sort(regions, function(a, b)
      return a[2] > b[2]
    end)

    -- Select the regions to format up front. Ignore regions that overlap (contain)
    -- other regions.
    local jobs = {}
    local last_start_lnum = #lines + 1
    for _, region in ipairs(regions) do
      if region[3] < last_start_lnum then
        last_start_lnum = region[2]
        jobs[#jobs + 1] = region
      end
    end

    local replacements = {}
    local format_error = nil

    -- Report the outcome: the recorded error if any formatter failed, otherwise a copy
    -- of `lines` with every region swapped for its formatted text
    local function apply_format_results()
      if format_error then
        callback(format_error)
        return
      end

      local formatted_lines = vim.deepcopy(lines)
      for _, replacement in ipairs(replacements) do
        local start_lnum, end_lnum, new_lines = unpack(replacement)
        for _ = start_lnum, end_lnum do
          table.remove(formatted_lines, start_lnum)
        end
        for i = #new_lines, 1, -1 do
          table.insert(formatted_lines, start_lnum, new_lines[i])
        end
      end
      callback(nil, formatted_lines)
    end

    -- Nothing to format: still report back, otherwise the caller would never hear
    -- from this formatter
    if #jobs == 0 then
      apply_format_results()
      return
    end

    -- The number of outstanding jobs is fixed before anything is dispatched, so a
    -- formatter that completes synchronously cannot drive the counter to zero while
    -- later regions are still waiting to be started
    local num_format = #jobs
    local function formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
      if err then
        format_error = err
      else
        replacements[idx] = { start_lnum, end_lnum, new_lines }
      end
      num_format = num_format - 1
      if num_format == 0 then
        apply_format_results()
      end
    end

    for idx, region in ipairs(jobs) do
      local lang, start_lnum, end_lnum = unpack(region)
      local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
      local formatter_names = conform.formatters_by_ft[lang]
      local format_opts = { async = true, bufnr = ctx.buf, quiet = true }
      conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
        -- Preserve indentation in case the code block is indented
        apply_indent(input_lines, new_lines)
        formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
      end)
    end
  end,
}