Skip to content

Commit

Permalink
Improve error handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
aduros committed Dec 11, 2022
1 parent 7463cec commit d9211e3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 30 deletions.
30 changes: 17 additions & 13 deletions lua/_ai/commands.lua
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,7 @@ function M.ai (args)
-- virt_text = {{"🤖", nil}},
})

local function on_result (result)
local text = result.choices[1].text
local lines = {}
for line in text:gmatch("[^\n]+") do
table.insert(lines, line)
end

-- -- Special case: prepend \n if we're dealing with a multi-line response
-- if #lines > 1 then
-- table.insert(lines, 1, "")
-- end

local function on_result (err, result)
local mark = vim.api.nvim_buf_get_extmark_by_id(buffer, ns_id, mark_id, { details = true })
local start_row = mark[1]
local start_col = mark[2]
Expand All @@ -67,7 +56,22 @@ function M.ai (args)

vim.api.nvim_buf_del_extmark(buffer, ns_id, mark_id)

vim.api.nvim_buf_set_text(buffer, start_row, start_col, end_row, end_col, lines)
if err then
vim.api.nvim_err_writeln("ai.vim: " .. err)
else
local text = result.choices[1].text
local lines = {}
for line in text:gmatch("[^\n]+") do
table.insert(lines, line)
end

-- -- Special case: prepend \n if we're dealing with a multi-line response
-- if #lines > 1 then
-- table.insert(lines, 1, "")
-- end

vim.api.nvim_buf_set_text(buffer, start_row, start_col, end_row, end_col, lines)
end
end

if visual_mode then
Expand Down
61 changes: 44 additions & 17 deletions lua/_ai/openai.lua
Original file line number Diff line number Diff line change
@@ -1,50 +1,77 @@
local M = {}

function exec (cmd, args, on_stdout)
local chunks = {}
local function on_read (err, data)
function exec (cmd, args, on_result)
local stdout = vim.loop.new_pipe()
local stdout_chunks = {}
local function on_stdout_read (err, data)
if data then
table.insert(chunks, data)
table.insert(stdout_chunks, data)
end
end

local stdout = vim.loop.new_pipe()
local stderr = vim.loop.new_pipe()
local stderr_chunks = {}
local function on_stderr_read (err, data)
if data then
table.insert(stderr_chunks, data)
end
end

-- print(cmd, vim.inspect(args))

local handle

handle = vim.loop.spawn(cmd, {
handle, error = vim.loop.spawn(cmd, {
args = args,
stdio = {nil, stdout, nil},
stdio = {nil, stdout, stderr},
}, function (code, signal)
stdout:close()
stderr:close()
handle:close()

vim.schedule(function ()
local output = table.concat(chunks, "")
on_stdout(output)
if code ~= 0 then
-- Lop off the trailing newline character
on_result(table.concat(stderr_chunks, ""):sub(0, -2))
else
on_result(nil, table.concat(stdout_chunks, ""))
end
end)
end)

stdout:read_start(on_read)
if not handle then
on_result(cmd .. " could not be started: " .. error)
else
stdout:read_start(on_stdout_read)
stderr:read_start(on_stderr_read)
end
end

function M.call (endpoint, body, on_result)
local api_key = os.getenv("OPENAI_API_KEY")
assert(api_key ~= nil, "$OPENAI_API_KEY environment variable must be set")
if not api_key then
on_result("$OPENAI_API_KEY environment variable must be set")
return
end

local curl_args = {
"-X", "POST", "--silent", "--show-error",
"-L", "https://api.openai.com/v1/" .. endpoint,
"-H", "Content-Type: application/json",
"-H", "Authorization: Bearer " .. api_key,
"-d", vim.json.encode(body),
}

-- print("Calling API:", endpoint, vim.json.encode(body))
exec("curl", curl_args, function (output)
local json = vim.json.decode(output)
if json.error then
print("API error:", json.error.message)
exec("curl", curl_args, function (err, output)
if err then
on_result(err)
else
on_result(json)
local json = vim.json.decode(output)
if json.error then
on_result(json.error.message)
else
on_result(nil, json)
end
end
end)
end
Expand Down

0 comments on commit d9211e3

Please sign in to comment.