Skip to content

Commit

Permalink
perf: skip over column separators while collecting metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
SCJangra committed Sep 1, 2024
1 parent fa34cb6 commit b343c69
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 107 deletions.
2 changes: 1 addition & 1 deletion lua/table-nvim/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

---@type TableNvimConfig
local uconf = {
padd_column_separators = true, -- Insert a space around column separators.
padd_column_separators = false, -- Insert a space around column separators.
mappings = {
next = '<TAB>', -- Go to next cell.
prev = '<S-TAB>', -- Go to previous cell.
Expand Down
167 changes: 62 additions & 105 deletions lua/table-nvim/md_table.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@ local api, ts = vim.api, vim.treesitter
---@field cursor_col number The current column position of the cursor.
---@field cursor_row number The currow row position of the cursor.
---@field cols MdTableColInfo[] Information about all columns in the table.
---@field root TSNode the root node of the table.
---@field root TSNode The root node of the table.
---@field pipes boolean Whether the table is surrounded by pipes.
local MdTable = {}

---@param root TSNode The root node of a table.
---@return MdTable
function MdTable:new(root)
assert(utils.is_tbl_root(root), 'not a table root node')

local config = conf.get_config()

local cursor_pos = api.nvim_win_get_cursor(0)
local cursor_row, cursor_col = cursor_pos[1] - 1, cursor_pos[2]

Expand All @@ -41,12 +40,21 @@ function MdTable:new(root)
local rows = {}
local cursor_col_index = nil
local cursor_row_index = nil
local pipes = false

for r, row in utils.iter_children(root) do
local c_count = row:child_count()
for r, row in utils.iter_named_children(root) do
rows[r] = {}

for c, col in utils.iter_children(row) do
if r == 1 then
local col = row:child(0)

if col then
_, indent = col:start()
if col:type() == '|' then pipes = true end
end
end

for c, col in utils.iter_named_children(row) do
cols[c] = cols[c] or {}

local text = ts.get_node_text(col, 0):match('^%s*(.-)%s*$')
Expand All @@ -59,22 +67,7 @@ function MdTable:new(root)
cursor_row_index, cursor_col_index = r, c
end

cols[c].is_delimiter = type == C.CELL_PIPE

if config.padd_column_separators and type == C.CELL_PIPE then
if c == 1 then
text = '| '
elseif c == c_count then
text = ' |'
else
text = ' | '
end

width = #text
end

if r == 1 then
if c == 1 then _, indent = col:start() end
cols[c].max_width = width
elseif r == C.DELIMITER_ROW then
if type == C.CELL_LEFT then
Expand All @@ -101,9 +94,10 @@ function MdTable:new(root)
indent = indent,
cols = cols,
rows = rows,
cursor_col = cursor_col_index or 1,
cursor_row = cursor_row_index or 1,
cursor_col = cursor_col_index or #cols,
cursor_row = cursor_row_index or #rows,
root = root,
pipes = pipes,
}

---@diagnostic disable-next-line: inject-field
Expand All @@ -125,22 +119,42 @@ end
---@param index number The index of the row to render.
---@return string
function MdTable:render_row(index)
local padd = conf.get_config().padd_column_separators
local line = {}
local row = self.rows[index]

for c, cell in ipairs(self.rows[index]) do
if c == 1 then table.insert(line, string.rep(' ', self.indent)) end
table.insert(line, string.rep(' ', self.indent))

if self.pipes then
local del = padd and '| ' or '|'
table.insert(line, del)
end

local cell = row[1]
table.insert(line, self:cell_text(cell.type, cell.text, 1))

local len = #row

for c = 2, len do
cell = row[c]
local del = padd and ' | ' or '|'

table.insert(line, del)
table.insert(line, self:cell_text(cell.type, cell.text, c))
end

if self.pipes then
local del = padd and ' |' or '|'
table.insert(line, del)
end

return table.concat(line)
end

---Get the type of cell for a given string.
---@param text string
---@return CellType
function MdTable:cell_type(text)
if text == '|' then return C.CELL_PIPE end

if text:match('^:%-+$') then return C.CELL_LEFT end
if text:match('^%-+:$') then return C.CELL_RIGHT end
if text:match('^:%-+:$') then return C.CELL_CENTER end
Expand All @@ -155,8 +169,6 @@ end
---@param cell number Cell index in the row.
---@return string
function MdTable:cell_text(type, text, cell)
if type == C.CELL_PIPE then return self:delimiter(cell) end

if type == C.CELL_LEFT then
local hyphens = string.rep('-', self.cols[cell].max_width - 1)
return ':' .. hyphens
Expand Down Expand Up @@ -196,79 +208,27 @@ function MdTable:cell_text(type, text, cell)
return table.concat({ string.rep(' ', left_padding), text, string.rep(' ', right_padding) })
end

---Get delimiter for the given cell
---@param cell_index number
function MdTable:delimiter(cell_index)
local padd = conf.get_config().padd_column_separators

if cell_index == 1 then return padd and '| ' or '|' end
if cell_index == #self.cols then return padd and ' |' or '|' end

return padd and ' | ' or '|'
end

---Extend a row (to a given length) by inserting new cells at the end.
---@param row MdTableCell[] The row to extend.
---@param len number The length to extend to.
function MdTable:extend_row_to(row, len)
for index = #row + 1, len do
local text = row[index - 1].type == C.CELL_PIPE and ' ' or self:delimiter(index)
row[index] = { type = C.CELL_TEXT, text = text }
row[index] = { type = C.CELL_TEXT, text = ' ' }
end
end

---Generate a new cell for the given row and column index.
---@param row_index number
---@param col_index number
---@return MdTableCell, MdTableCell
---@return MdTableCell
function MdTable:gen_cell_for(row_index, col_index)
local left = self.rows[row_index][col_index - 1]
local current = self.rows[row_index][col_index]

local left_is_pipe = left and left.type == C.CELL_PIPE
local current_is_pipe = current and current.type == C.CELL_PIPE

local text = function()
local cell_delimiter = { text = '-', type = C.CELL_DELIMITER }
local cell_x = { text = 'x', type = C.CELL_TEXT }
local cell_space = { text = ' ', type = C.CELL_TEXT }
local cell_delimiter = { text = '-', type = C.CELL_DELIMITER }
local cell_x = { text = 'x', type = C.CELL_TEXT }
local cell_space = { text = ' ', type = C.CELL_TEXT }

if col_index == 1 then
if row_index == C.DELIMITER_ROW then return cell_delimiter else return cell_x end
end

if col_index == 2 and left_is_pipe then
if row_index == C.DELIMITER_ROW then return cell_delimiter else return cell_x end
end

if col_index == #self.cols and current_is_pipe then
if row_index == C.DELIMITER_ROW then return cell_delimiter else return cell_x end
end

if col_index == #self.cols + 1 then
if row_index == C.DELIMITER_ROW then return cell_delimiter else return cell_x end
end

if row_index == 1 then return cell_x end
if row_index == C.DELIMITER_ROW then return cell_delimiter end

return cell_space
end

local delimiter = { text = self:delimiter(col_index), type = C.CELL_PIPE }

if left == nil and current_is_pipe then
return text(), delimiter
elseif left == nil and not current_is_pipe then
return delimiter, text()
elseif left and left_is_pipe then
return delimiter, text()
elseif left and not left_is_pipe then
return text(), delimiter
else
-- This branch should be unreachable.
---@diagnostic disable-next-line: missing-return
end
if row_index == C.DELIMITER_ROW then return cell_delimiter end
if row_index == 1 or col_index == 1 or col_index == #self.cols + 1 then return cell_x end
return cell_space
end

---Insert a column to the table at the given index.
Expand All @@ -277,19 +237,16 @@ function MdTable:insert_column_at(index)
for i, row in ipairs(self.rows) do
self:extend_row_to(row, index - 1)

local first, second = self:gen_cell_for(i, index)
local cell = self:gen_cell_for(i, index)

table.insert(row, index, first)
table.insert(row, index, second)
table.insert(row, index, cell)
end

local first, second = self:gen_cell_for(1, index)
local cell = self:gen_cell_for(1, index)

local a = { is_delimiter = first.type == C.CELL_PIPE, max_width = #first.text, alighment = C.ALIGN_NONE }
local b = { is_delimiter = second.type == C.CELL_PIPE, max_width = #second.text, alighment = C.ALIGN_NONE }
local col_info = { max_width = #cell.text, alighment = C.ALIGN_NONE }

table.insert(self.cols, index, a)
table.insert(self.cols, index, b)
table.insert(self.cols, index, col_info)
end

---Insert a column to the left of current column.
Expand All @@ -310,17 +267,17 @@ function MdTable:insert_row_at(index)
local row = {}

local col_count = #self.cols
local cell_space = { type = C.CELL_TEXT, text = ' ' }
local cell_x = { type = C.CELL_TEXT, text = 'x' }

for c, col in ipairs(self.cols) do
if (c == 1 or c == 2 or c == col_count or c == col_count - 1) and not col.is_delimiter then
row[c] = { type = C.CELL_TEXT, text = 'x' }
elseif col.is_delimiter then
row[c] = { type = C.CELL_PIPE, text = '|' }
else
row[c] = { type = C.CELL_TEXT, text = ' ' }
end
row[1] = cell_x

for c = 2, col_count - 1 do
row[c] = cell_space
end

row[col_count] = cell_x

table.insert(self.rows, index, row)

return index
Expand Down
16 changes: 15 additions & 1 deletion lua/table-nvim/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ local tbl_align_left = 'pipe_table_align_left'
local tbl_align_right = 'pipe_table_align_right'
local tbl_node_len = #tbl_node

local api = vim.api
local conf = require('table-nvim.config')

---Returns `true` if the node is the root of a markdown table and `false` otherwise.
Expand All @@ -26,6 +25,8 @@ local get_tbl_root = function(node)
if node == nil then return nil end
if string.sub(node:type(), 1, tbl_node_len) ~= tbl_node then return nil end

if is_tbl_root(node) then return node end

while true do
node = node:parent()
if node == nil then return nil end
Expand Down Expand Up @@ -93,6 +94,18 @@ local iter_children = function(node)
end
end

---Iterate of all named children of a treesitter node.
---@param node TSNode
---@return fun(): integer?, TSNode?
local iter_named_children = function(node)
local n = node:named_child_count()
local i = -1
return function()
i = i + 1
if i < n then return i + 1, node:named_child(i) end
end
end

return {
get_tbl_root = get_tbl_root,
is_tbl_root = is_tbl_root,
Expand All @@ -102,4 +115,5 @@ return {
gen_table = gen_table,
gen_table_alt = gen_table_alt,
iter_children = iter_children,
iter_named_children = iter_named_children,
}

0 comments on commit b343c69

Please sign in to comment.