Skip to content

Commit

Permalink
Fix for changes to nvim upstream TS/hl API, fix #241 (#242)
Browse files Browse the repository at this point in the history
* Fix for changes to nvim upstream TS/hl API, fix #241

re neovim/neovim#19931

* Work around bug in TS playground util

* Work around nvim TS highlighter.hl_map change

* Fix lint
  • Loading branch information
andymass authored Aug 27, 2022
1 parent e59d5c7 commit 950ef5d
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 42 deletions.
107 changes: 65 additions & 42 deletions lua/treesitter-matchup/third-party/hl-info.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,88 @@
-- licensed under the Apache License 2.0
-- See nvim-treesitter.LICENSE-APACHE-2.0

local highlighter = require("vim.treesitter.highlighter")
local utils = require "treesitter-matchup.third-party.utils"
local highlighter = require "vim.treesitter.highlighter"
local ts_utils = require "nvim-treesitter.ts_utils"

local M = {}

function M.get_treesitter_hl(cursor)
local buf = vim.api.nvim_get_current_buf()
local row, col = unpack(cursor or vim.api.nvim_win_get_cursor(0))
row = row - 1
if vim.treesitter.highlighter.hl_map then
function M.get_treesitter_hl(cursor)
local buf = vim.api.nvim_get_current_buf()
local row, col = unpack(cursor or vim.api.nvim_win_get_cursor(0))
row = row - 1

local self = highlighter.active[buf]
local self = highlighter.active[buf]

if not self then
return {}
end
if not self then
return {}
end

local matches = {}
local matches = {}

self.tree:for_each_tree(function(tstree, tree)
if not tstree then
return
end
self.tree:for_each_tree(function(tstree, tree)
if not tstree then
return
end

local root = tstree:root()
local root_start_row, _, root_end_row, _ = root:range()
local root = tstree:root()
local root_start_row, _, root_end_row, _ = root:range()

-- Only worry about trees within the line range
if root_start_row > row or root_end_row < row then
return
end
-- Only worry about trees within the line range
if root_start_row > row or root_end_row < row then
return
end

local query = self:get_query(tree:lang())
local query = self:get_query(tree:lang())

-- Some injected languages may not have highlight queries.
if not query:query() then
return
end
-- Some injected languages may not have highlight queries.
if not query:query() then
return
end

local iter = query:query():iter_captures(root, self.bufnr, row, row + 1)

for capture, node, _ in iter do
if ts_utils.is_in_node_range(node, row, col) then
local c = query._query.captures[capture] -- name of the capture in the query
if c ~= nil then
local general_hl, is_vim_hl = query:_get_hl_from_capture(capture)
local local_hl = not is_vim_hl and (tree:lang() .. general_hl)
local line = { c }
if local_hl then
table.insert(line, local_hl)
end
if general_hl and general_hl ~= local_hl then
table.insert(line, general_hl)
local iter = query:query():iter_captures(root, self.bufnr, row, row + 1)

for capture, node, _ in iter do
if ts_utils.is_in_node_range(node, row, col) then
local c = query._query.captures[capture] -- name of the capture in the query
if c ~= nil then
local general_hl, is_vim_hl = query:_get_hl_from_capture(capture)
local local_hl = not is_vim_hl and (tree:lang() .. general_hl)
local line = { c }
if local_hl then
table.insert(line, local_hl)
end
if general_hl and general_hl ~= local_hl then
table.insert(line, general_hl)
end
table.insert(matches, line)
end
table.insert(matches, line)
end
end
end, true)
return matches
end
else
function M.get_treesitter_hl(cursor)
local bufnr = vim.api.nvim_get_current_buf()
local row, col = unpack(cursor or vim.api.nvim_win_get_cursor(0))
row = row - 1

local results = utils.get_hl_groups_at_position(bufnr, row, col)
local highlights = {}
for _, hl in pairs(results) do
local line = { "@" .. hl.capture }
if hl.specific then
table.insert(line, hl.specific)
end
if hl.general then
table.insert(line, hl.general)
end
table.insert(highlights, line)
end
end, true)
return matches
return highlights
end
end

function M.active()
Expand Down
132 changes: 132 additions & 0 deletions lua/treesitter-matchup/third-party/utils.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
-- From https://github.com/nvim-treesitter/playground
-- Copyright 2021
-- licensed under the Apache License 2.0
-- See nvim-treesitter.LICENSE-APACHE-2.0

local api = vim.api
local ts_utils = require "nvim-treesitter.ts_utils"
local highlighter = require "vim.treesitter.highlighter"

local M = {}

function M.debounce(fn, debounce_time)
local timer = vim.loop.new_timer()
local is_debounce_fn = type(debounce_time) == "function"

return function(...)
timer:stop()

local time = debounce_time
local args = { ... }

if is_debounce_fn then
time = debounce_time()
end

timer:start(
time,
0,
vim.schedule_wrap(function()
fn(unpack(args))
end)
)
end
end

function M.get_hl_groups_at_position(bufnr, row, col)
local buf_highlighter = highlighter.active[bufnr]

if not buf_highlighter then
return {}
end

local matches = {}

buf_highlighter.tree:for_each_tree(function(tstree, tree)
if not tstree then
return
end

local root = tstree:root()
local root_start_row, _, root_end_row, _ = root:range()

-- Only worry about trees within the line range
if root_start_row > row or root_end_row < row then
return
end

local query = buf_highlighter:get_query(tree:lang())

-- Some injected languages may not have highlight queries.
if not query:query() then
return
end

local iter = query:query():iter_captures(root, buf_highlighter.bufnr, row, row + 1)

for capture, node, metadata in iter do
local hl = query.hl_cache[capture]

if hl and ts_utils.is_in_node_range(node, row, col) then
local c = query._query.captures[capture] -- name of the capture in the query
if c ~= nil then
local name = query._query.captures[capture]
local id = 0
if not vim.startswith(name, '_') then
id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. tree:lang())
end
table.insert(
matches,
{ capture = c, specific = id, general = id, priority = metadata.priority }
)
end
end
end
end, true)
return matches
end

function M.for_each_buf_window(bufnr, fn)
if not api.nvim_buf_is_loaded(bufnr) then
return
end

for _, window in ipairs(vim.fn.win_findbuf(bufnr)) do
fn(window)
end
end

function M.to_lookup_table(list, key_mapper)
local result = {}

for i, v in ipairs(list) do
local key = v

if key_mapper then
key = key_mapper(v, i)
end

result[key] = v
end

return result
end

function M.node_contains(node, range)
local start_row, start_col, end_row, end_col = node:range()
local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2])
local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4])

return start_fits and end_fits
end

--- Returns a tuple with the position of the last line and last column (0-indexed).
function M.get_end_pos(bufnr)
local bufnr = bufnr or api.nvim_get_current_buf()
local last_row = api.nvim_buf_line_count(bufnr) - 1
local last_line = api.nvim_buf_get_lines(bufnr, last_row, last_row + 1, true)[1]
local last_col = last_line and #last_line or 0
return last_row, last_col
end

return M

0 comments on commit 950ef5d

Please sign in to comment.