aboutsummaryrefslogtreecommitdiffstats
path: root/lua/conform/formatters/injected.lua
blob: c25f2f95bc19b3dc989b404453c2ff83571f1c35 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
---Report whether the inclusive, 1-indexed line span [start_lnum, end_lnum]
---intersects the requested formatting range. When no range is given, every
---span qualifies.
---@param range? conform.Range
---@param start_lnum integer
---@param end_lnum integer
---@return boolean
local function in_range(range, start_lnum, end_lnum)
  if not range then
    -- No range restriction: the whole buffer is being formatted.
    return true
  end
  -- Two closed intervals intersect iff each one starts no later than the other ends.
  local starts_before_range_end = start_lnum <= range["end"][1]
  return starts_before_range_end and range["start"][1] <= end_lnum
end

---@type conform.FileLuaFormatterConfig
return {
  meta = {
    url = "lua/conform/formatters/injected.lua",
    description = "Format treesitter injected languages.",
  },
  -- Only available when a treesitter parser can be obtained for the buffer.
  condition = function(ctx)
    local ok = pcall(vim.treesitter.get_parser, ctx.buf)
    return ok
  end,
  ---Format every injected-language region of the buffer with the formatters
  ---configured for that language, then splice the results back into `lines`.
  ---`callback` is invoked exactly once: `callback(err)` on failure, or
  ---`callback(nil, formatted_lines)` on success (the input copied unchanged
  ---when there is nothing to format).
  format = function(ctx, lines, callback)
    local conform = require("conform")
    local util = require("conform.util")
    local ok, parser = pcall(vim.treesitter.get_parser, ctx.buf)
    if not ok then
      callback("No treesitter parser for buffer")
      return
    end

    -- Collect the injected regions whose language has formatters configured
    -- and which intersect the requested range (if any).
    local root_lang = parser:lang()
    local regions = {}
    for lang, child_lang in pairs(parser:children()) do
      local formatter_names = conform.formatters_by_ft[lang]
      if formatter_names and lang ~= root_lang then
        for _, tree in ipairs(child_lang:trees()) do
          local root = tree:root()
          local start_lnum = root:start() + 1
          local end_lnum = root:end_()
          if start_lnum <= end_lnum and in_range(ctx.range, start_lnum, end_lnum) then
            table.insert(regions, { lang, start_lnum, end_lnum })
          end
        end
      end
    end
    -- Sort from largest start_lnum to smallest, so replacements can be applied
    -- bottom-up without shifting the line numbers of regions still to apply.
    table.sort(regions, function(a, b)
      return a[2] > b[2]
    end)

    -- Ignore regions that overlap (contain) a region that was already selected.
    local jobs = {}
    local last_start_lnum = #lines + 1
    for _, region in ipairs(regions) do
      local end_lnum = region[3]
      if end_lnum < last_start_lnum then
        last_start_lnum = region[2]
        jobs[#jobs + 1] = region
      end
    end

    local replacements = {}
    local format_error = nil

    -- Splice each formatted region back into a copy of the input, in job order
    -- (bottom-up), then report the result.
    local function apply_format_results()
      if format_error then
        callback(format_error)
        return
      end

      local formatted_lines = vim.deepcopy(lines)
      for _, replacement in ipairs(replacements) do
        local start_lnum, end_lnum, new_lines = unpack(replacement)
        for _ = start_lnum, end_lnum do
          table.remove(formatted_lines, start_lnum)
        end
        for i = #new_lines, 1, -1 do
          table.insert(formatted_lines, start_lnum, new_lines[i])
        end
      end
      callback(nil, formatted_lines)
    end

    -- Nothing to format: still report back, otherwise the request never completes.
    if #jobs == 0 then
      apply_format_results()
      return
    end

    -- Fix the pending count before dispatching anything. If a formatter calls
    -- back synchronously, it then cannot drive the count to zero early (which
    -- would report partial results and call `callback` more than once). Each
    -- job is keyed by its position in `jobs`, so replacement slots never collide.
    local num_pending = #jobs
    local function formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
      if err then
        format_error = err
      else
        replacements[idx] = { start_lnum, end_lnum, new_lines }
      end
      num_pending = num_pending - 1
      if num_pending == 0 then
        apply_format_results()
      end
    end

    for idx, job in ipairs(jobs) do
      local lang, start_lnum, end_lnum = unpack(job)
      local input_lines = util.tbl_slice(lines, start_lnum, end_lnum)
      local formatter_names = conform.formatters_by_ft[lang]
      local format_opts = { async = true, bufnr = ctx.buf, quiet = true }
      conform.format_lines(formatter_names, input_lines, format_opts, function(err, new_lines)
        formatter_cb(err, idx, start_lnum, end_lnum, new_lines)
      end)
    end
  end,
}