Last active
July 7, 2025 12:57
-
-
Save arnm/41774efdeb49bd377e6865a800fca3bc to your computer and use it in GitHub Desktop.
CodeCompanion Chat Strategies
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Directory Structure: | |
└── ./ | |
└── lua | |
└── codecompanion | |
└── strategies | |
└── chat | |
├── agents | |
│ ├── tools | |
│ │ ├── helpers | |
│ │ │ ├── diff.lua | |
│ │ │ ├── patch.lua | |
│ │ │ └── wait.lua | |
│ │ ├── cmd_runner.lua | |
│ │ ├── create_file.lua | |
│ │ ├── file_search.lua | |
│ │ ├── grep_search.lua | |
│ │ ├── insert_edit_into_file.lua | |
│ │ ├── next_edit_suggestion.lua | |
│ │ ├── read_file.lua | |
│ │ └── web_search.lua | |
│ ├── init.lua | |
│ └── tool_filter.lua | |
├── slash_commands | |
│ ├── buffer.lua | |
│ ├── fetch.lua | |
│ ├── file.lua | |
│ ├── help.lua | |
│ ├── image.lua | |
│ ├── init.lua | |
│ ├── keymaps.lua | |
│ ├── now.lua | |
│ ├── quickfix.lua | |
│ ├── symbols.lua | |
│ ├── terminal.lua | |
│ └── workspace.lua | |
├── variables | |
│ ├── buffer.lua | |
│ ├── init.lua | |
│ ├── lsp.lua | |
│ ├── user.lua | |
│ └── viewport.lua | |
├── debug.lua | |
├── helpers.lua | |
├── init.lua | |
├── references.lua | |
├── subscribers.lua | |
└── tools.lua | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/helpers/diff.lua | |
--- | |
local config = require("codecompanion.config") | |
local keymaps = require("codecompanion.utils.keymaps") | |
local ui = require("codecompanion.utils.ui") | |
local api = vim.api | |
local M = {} | |
---Create a diff for a buffer and set up keymaps
---
---No diff is created (nil is returned) when `M.should_create` rejects the
---buffer, when the configured provider module cannot be loaded, or when
---`ui.buf_get_win` finds no window for the buffer.
---@param bufnr number The buffer to create diff for
---@param diff_id number|string Unique identifier for this diff
---@param opts? table Optional configuration (forwarded to `M.setup_keymaps`)
---@return table|nil diff The diff object, or nil if no diff was created
function M.create(bufnr, diff_id, opts)
  opts = opts or {}

  -- Auto mode, disabled diffs and terminal buffers are all handled by the
  -- shared predicate so the two functions can never drift apart
  if not M.should_create(bufnr) then
    return nil
  end

  local provider = config.display.diff.provider
  if type(provider) ~= "string" then
    -- Without this guard the concatenation below would raise outside the pcall
    return nil
  end

  local ok, diff_module = pcall(require, "codecompanion.providers.diff." .. provider)
  if not ok then
    return nil
  end

  local winnr = ui.buf_get_win(bufnr)
  if not winnr then
    return nil
  end

  local diff = diff_module.new({
    bufnr = bufnr,
    contents = api.nvim_buf_get_lines(bufnr, 0, -1, true),
    -- vim.bo replaces the deprecated nvim_buf_get_option()
    filetype = vim.bo[bufnr].filetype,
    id = diff_id,
    winnr = winnr,
  })

  M.setup_keymaps(diff, opts)

  return diff
end
---Register the inline strategy's keymaps on the diff's buffer
---
---Does nothing when the inline strategy, or its keymaps, are not configured.
---@param diff table The diff object; its `bufnr` selects the target buffer
---@param opts? table Optional configuration (currently unused)
function M.setup_keymaps(diff, opts)
  opts = opts or {}

  local inline = config.strategies.inline
  local configured = inline and inline.keymaps
  if not configured then
    return
  end

  local mapper = keymaps.new({
    bufnr = diff.bufnr,
    callbacks = require("codecompanion.strategies.inline.keymaps"),
    data = { diff = diff },
    keymaps = inline.keymaps,
  })
  mapper:set()
end
---Decide whether a diff should be created for the given buffer
---
---The checks run in this order: auto tool mode, diff disabled in the config,
---terminal buffer. The first one that applies determines the reason.
---@param bufnr number
---@return boolean should_create
---@return string|nil reason Why diff creation was skipped (nil when allowed)
function M.should_create(bufnr)
  local reason
  if vim.g.codecompanion_auto_tool_mode then
    reason = "auto_tool_mode"
  elseif not config.display.diff.enabled then
    reason = "diff_disabled"
  elseif vim.bo[bufnr].buftype == "terminal" then
    reason = "terminal_buffer"
  end

  if reason then
    return false, reason
  end
  return true, nil
end
return M | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/helpers/patch.lua | |
--- | |
local M = {} | |
M.FORMAT_PROMPT = [[*** Begin Patch | |
[PATCH] | |
*** End Patch | |
The `[PATCH]` is the series of diffs to be applied for each change in the file. Each diff should be in this format: | |
[3 lines of pre-context] | |
-[old code] | |
+[new code] | |
[3 lines of post-context] | |
The context blocks are 3 lines of existing code, immediately before and after the modified lines of code. | |
Lines to be modified should be prefixed with a `+` or `-` sign. | |
Unmodified lines used in context should begin with an empty space ` `. | |
For example, to add a subtract method to a calculator class in Python: | |
*** Begin Patch | |
def add(self, value): | |
self.result += value | |
return self.result | |
+def subtract(self, value): | |
+ self.result -= value | |
+ return self.result | |
+ | |
def multiply(self, value): | |
self.result *= value | |
return self.result | |
*** End Patch | |
Multiple blocks of diffs should be separated by an empty line and `@@[identifier]` as detailed below. | |
The immediately preceding and after context lines are enough to locate the lines to edit. DO NOT USE line numbers anywhere in the patch. | |
You can use `@@[identifier]` to define a larger context in case the immediately before and after context is not sufficient to locate the edits. Example: | |
@@class BaseClass(models.Model): | |
[3 lines of pre-context] | |
- pass | |
+ raise NotImplementedError() | |
[3 lines of post-context] | |
You can also use multiple `@@[identifiers]` to provide the right context if a single `@@` is not sufficient. | |
Example with multiple blocks of changes and `@@` identifiers: | |
*** Begin Patch | |
@@class BaseClass(models.Model): | |
@@ def search(): | |
- pass | |
+ raise NotImplementedError() | |
@@class Subclass(BaseClass): | |
@@ def search(): | |
- pass | |
+ raise NotImplementedError() | |
*** End Patch | |
This format is similar to the `git diff` format; the difference is that `@@[identifiers]` uses the unique line identifiers from the preceding code instead of line numbers. We don't use line numbers anywhere since the before and after context, and `@@` identifiers are enough to locate the edits. | |
IMPORTANT: Be mindful that the user may have shared attachments that contain line numbers, but these should NEVER be used in your patch. Always use the contextual format described above.]] | |
---@class Change | |
---@field focus string[] Identifiers or lines for providing large context before a change | |
---@field pre string[] Unchanged lines immediately before edits | |
---@field old string[] Lines to be removed | |
---@field new string[] Lines to be added | |
---@field post string[] Unchanged lines just after edits | |
---Build a fresh Change table with empty edit and post-context lists.
---The `focus` and `pre` tables are stored by reference, not copied.
---@param focus? string[] Optional focus lines for context
---@param pre? string[] Optional pre-context lines
---@return Change New change object
local function get_new_change(focus, pre)
  local change = { old = {}, new = {}, post = {} }
  change.focus = focus or {}
  change.pre = pre or {}
  return change
end
---Split a single patch body into Change blocks.
---
---Line classification:
---  * `@@name`   adds a focus identifier (starting a new block if edits were already seen)
---  * `-text`    line to remove; `+text` line to add
---  * other      pre-context before any edit in the block, post-context afterwards
---An empty line directly followed by an `@@` line ends the current block.
---@param patch string Patch containing the changes
---@return Change[] List of parsed change blocks
local function parse_changes_from_patch(patch)
  local changes = {}
  local current = get_new_change()
  local lines = vim.split(patch, "\n", { plain = true })

  -- Store the block being built and begin a new one
  local function flush(focus, pre)
    changes[#changes + 1] = current
    current = get_new_change(focus, pre)
  end

  for idx, line in ipairs(lines) do
    local has_edits = #current.old > 0 or #current.new > 0
    local marker = line:sub(1, 1)
    local next_line = lines[idx + 1]

    if vim.startswith(line, "@@") then
      -- @@ after any edits is a new change block
      if has_edits then
        flush()
      end
      -- the identifier may be empty, in which case it only separates blocks
      local focus_name = vim.trim(line:sub(3))
      if focus_name and #focus_name > 0 then
        table.insert(current.focus, focus_name)
      end
    elseif line == "" and next_line and next_line:match("^@@") then
      -- empty lines are normally context; one right before an @@ line is a
      -- block separator instead
      flush()
    elseif marker == "-" or marker == "+" then
      -- an edit after post-context starts a new block that keeps the same
      -- focus and uses that post-context as its pre-context
      if #current.post > 0 then
        flush(current.focus, current.post)
      end
      local target = (marker == "-") and current.old or current.new
      target[#target + 1] = line:sub(2)
    elseif has_edits then
      table.insert(current.post, line)
    else
      table.insert(current.pre, line)
    end
  end

  changes[#changes + 1] = current
  return changes
end
---Extract every `*** Begin Patch` / `*** End Patch` section from the LLM
---output and parse each one into Change objects.
---
---LLMs sometimes omit the markers; in that case the whole input is treated as
---one patch and the second return value is false so callers can report it if
---the patch later fails to apply.
---@param raw string Raw text containing patch blocks
---@return Change[], boolean All parsed Change objects, and whether begin/end markers were found
function M.parse_changes(raw)
  local patches = {}
  for body in raw:gmatch("%*%*%* Begin Patch%s+(.-)%s+%*%*%* End Patch") do
    patches[#patches + 1] = body
  end

  local had_begin_end_markers = #patches > 0
  if not had_begin_end_markers then
    patches[1] = raw
  end

  local all_changes = {}
  for p = 1, #patches do
    local parsed = parse_changes_from_patch(patches[p])
    for c = 1, #parsed do
      all_changes[#all_changes + 1] = parsed[c]
    end
  end

  return all_changes, had_begin_end_markers
end
---Score how well `needle` lines up with `haystack` starting at `pos`.
---Positions outside the haystack simply contribute nothing.
---@param haystack string[] All file lines
---@param pos integer Starting index to check (1-based)
---@param needle string[] Lines to match
---@return integer Score: 10 per exact line, 9 per line equal after trimming
local function get_score(haystack, pos, needle)
  local total = 0
  for offset = 1, #needle do
    local wanted = needle[offset]
    local actual = haystack[pos + offset - 1]
    if actual == wanted then
      total = total + 10
    elseif actual and vim.trim(actual) == vim.trim(wanted) then
      total = total + 9
    end
  end
  return total
end
---Score the focus identifiers found above a position.
---
---Each focus line is searched for (exact or trimmed match) among the lines
---before `before_pos`; the search for the next focus line resumes at the line
---where the previous one matched.
---@param lines string[] Lines of source file
---@param before_pos integer Scan up to this line (exclusive; 1-based)
---@param focus string[] Focus lines/context
---@return integer Score: 20 per matching focus line before position
local function get_focus_score(lines, before_pos, focus)
  local search_from = 1
  local total = 0
  for _, wanted in ipairs(focus) do
    local k = search_from
    local last = before_pos - 1
    while k <= last do
      local candidate = lines[k]
      if wanted == candidate or (vim.trim(wanted) == vim.trim(candidate)) then
        total = total + 20
        search_from = k
        break
      end
      k = k + 1
    end
  end
  return total
end
---Rate how well a change fits when its removed lines start at line `i`.
---
---The result is the achieved score divided by the maximum achievable one
---(focus lines count double). A change with no focus, context or removed
---lines has a maximum of 0, so the division yields NaN for it.
---@param lines string[] File lines
---@param change Change To match
---@param i integer Line position
---@return number Score from 0.0 to 1.0
local function get_match_score(lines, change, i)
  local focus, pre, old, post = change.focus, change.pre, change.old, change.post

  local achievable = (#focus * 2 + #pre + #old + #post) * 10

  local achieved = get_focus_score(lines, i, focus)
  achieved = achieved + get_score(lines, i - #pre, pre)
  achieved = achieved + get_score(lines, i, old)
  achieved = achieved + get_score(lines, i + #old, post)

  return achieved / achievable
end
---Find the line at which a Change fits best, together with its score.
---
---Every position from 1 to `#lines + 1` is scored; a perfect score returns
---immediately, otherwise the first position with the highest score wins.
---If nothing scores above 0 the result is position 1 with score 0.
---@param lines string[] File lines
---@param change Change Patch block
---@return integer, number location (1-based), Score (0-1)
local function get_best_location(lines, change)
  -- Matching tolerates whitespace differences because there is no standard
  -- for spaces in diffs: Python's differ puts one space after +/-, GNU udiff
  -- puts none, and LLMs (especially Claude) sometimes strip long left-hand
  -- indentation in deeply nested code (e.g. HTML). The trimmed comparison
  -- used by the scoring helpers covers all of these.
  local best_location, best_score = 1, 0

  for candidate = 1, #lines + 1 do
    local score = get_match_score(lines, change, candidate)
    if score == 1 then
      return candidate, 1
    end
    if score > best_score then
      best_location, best_score = candidate, score
    end
  end

  return best_location, best_score
end
---Report where a change would be applied, without modifying anything.
---@param lines string[] File lines
---@param change Change Edit description
---@return integer|nil location The 1-based line where the change would go, or nil when the best match scores below 0.5
function M.get_change_location(lines, change)
  local location, score = get_best_location(lines, change)
  local too_uncertain = score < 0.5
  if too_uncertain then
    return nil
  end
  return location
end
---Apply a Change object to the file lines. Returns nil if not confident.
---
---The input table is not modified; a new list is built from the lines before
---the match, the (possibly re-indented) added lines, and the lines following
---the removed ones.
---@param lines string[] Lines before patch
---@param change Change Edit description
---@return string[]|nil New file lines (or nil if the best match scores below 0.5)
function M.apply_change(lines, change)
  local location, score = get_best_location(lines, change)
  if score < 0.5 then
    return
  end

  local new_lines = {}

  -- lines before the edit are kept untouched
  for k = 1, location - 1 do
    new_lines[#new_lines + 1] = lines[k]
  end

  -- On an imperfect match, infer an indentation adjustment from the first
  -- removed line. `location` can be `#lines + 1` (context matched at the end
  -- of the file but the removed lines were not found), in which case there is
  -- no file line to compare against and no adjustment is made.
  local fix_spaces
  local anchor = lines[location]
  local first_old = change.old[1]
  if score ~= 1 and first_old and anchor then
    if first_old == " " .. anchor then
      -- the patch added an extra space on the left
      fix_spaces = function(ln)
        return ln:sub(2)
      end
    elseif #first_old < #anchor then
      -- the patch dropped spaces on the left
      local prefix = string.rep(" ", #anchor - #first_old)
      fix_spaces = function(ln)
        return prefix .. ln
      end
    end
  end

  -- the added lines
  for _, ln in ipairs(change.new) do
    if fix_spaces then
      ln = fix_spaces(ln)
    end
    new_lines[#new_lines + 1] = ln
  end

  -- everything after the removed lines
  for k = location + #change.old, #lines do
    new_lines[#new_lines + 1] = lines[k]
  end

  return new_lines
end
---Join a list of lines, optionally prefixing each one.
---
---The input list is never modified: prefixed lines are written to a scratch
---table, so rendering a Change repeatedly (or rendering Changes that share a
---list) always yields the same text.
---@param list string[] List of lines
---@param sep string Separator (e.g., "\n")
---@param prefix? string Optional prefix for each line
---@return string|false Result string or false if list is empty
local function prefix_join(list, sep, prefix)
  if #list == 0 then
    return false
  end

  if not prefix then
    return table.concat(list, sep)
  end

  local prefixed = {}
  for i = 1, #list do
    prefixed[i] = prefix .. list[i]
  end
  return table.concat(prefixed, sep)
end
---Render a Change as patch text for output or logs.
---
---Sections appear in the order focus (`@@`), pre-context, removed lines (`-`),
---added lines (`+`), post-context; empty sections are left out.
---@param change Change To render
---@return string Formatted string
function M.get_change_string(change)
  local rendered = {}

  local function add(list, prefix)
    local text = prefix_join(list, "\n", prefix)
    if text then
      rendered[#rendered + 1] = text
    end
  end

  add(change.focus, "@@")
  add(change.pre)
  add(change.old, "-")
  add(change.new, "+")
  add(change.post)

  return table.concat(rendered, "\n")
end
return M | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/helpers/wait.lua | |
--- | |
local config = require("codecompanion.config") | |
local ui = require("codecompanion.utils.ui") | |
local utils = require("codecompanion.utils") | |
local api = vim.api | |
local M = {} | |
---Wait for user decision on specific events
---
---Listens for `User` autocmds matching `events` whose `data.id` equals `id`.
---The decision is reported through `callback` exactly once: either for the
---first matching event, or with `{ accepted = false, timeout = true }` when
---the timeout elapses first. In auto tool mode the callback is invoked
---immediately with `{ accepted = true }`.
---@param id string|number Unique identifier for the decision context
---@param events table Name of events to wait for. First one is considered "accept"
---@param callback function Callback to execute when decision is made
---@param opts? table Optional configuration (`chat_bufnr`, `notify`, `sub_text`, `timeout`)
function M.for_decision(id, events, callback, opts)
  opts = opts or {}

  -- Auto-approve if in auto mode
  -- Generally, most tools will avoid us reaching this point, but it's a good fallback
  if vim.g.codecompanion_auto_tool_mode then
    return callback({ accepted = true })
  end

  local aug = api.nvim_create_augroup("codecompanion_wait_" .. tostring(id), { clear = true })

  -- Show waiting indicator in the chat buffer
  local chat_extmark_id = nil
  if opts.chat_bufnr then
    chat_extmark_id = M.show_waiting_indicator(opts.chat_bufnr, opts)
  end

  -- The event handler and the timeout race each other; whichever settles the
  -- decision first wins and the other becomes a no-op. Without this guard the
  -- timer would report a timeout after the user had already decided.
  local resolved = false
  local function resolve(result)
    if resolved then
      return
    end
    resolved = true

    if chat_extmark_id and opts.chat_bufnr then
      M.clear_waiting_indicator(opts.chat_bufnr)
    end
    api.nvim_clear_autocmds({ group = aug })
    callback(result)
  end

  api.nvim_create_autocmd("User", {
    group = aug,
    pattern = events,
    callback = function(event)
      local event_data = event.data or {}
      -- Events for other decision contexts are ignored
      if id ~= event_data.id then
        return
      end

      resolve({
        accepted = (event.match == events[1]),
        event = event.match,
        data = event_data,
      })
    end,
  })

  if opts.notify then
    utils.notify(opts.notify)
  end

  opts.timeout = opts.timeout or config.strategies.chat.tools.opts.wait_timeout or 30000

  vim.defer_fn(function()
    resolve({
      accepted = false,
      timeout = true,
    })
  end, opts.timeout)
end
---Display the "waiting" notification in the footer of a chat buffer
---@param bufnr number The buffer number to show the indicator in
---@param opts table Options for the indicator (`notify` text, `sub_text`)
---@return number The extmark ID for cleanup
function M.show_waiting_indicator(bufnr, opts)
  opts = opts or {}

  local notification = {
    namespace = "codecompanion_waiting_" .. tostring(bufnr),
    footer = true,
    text = opts.notify or "Waiting for user decision ...",
    sub_text = opts.sub_text,
    main_hl = "CodeCompanionChatWarn",
    sub_hl = "CodeCompanionChatSubtext",
  }

  return ui.show_buffer_notification(bufnr, notification)
end
---Remove the waiting notification from a buffer.
---The namespace is derived from the buffer number exactly as in
---`M.show_waiting_indicator`.
---@param bufnr number The buffer number to clear the indicator from
---@return nil
function M.clear_waiting_indicator(bufnr)
  local namespace = "codecompanion_waiting_" .. tostring(bufnr)
  ui.clear_notification(bufnr, { namespace = namespace })
end
return M | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/cmd_runner.lua | |
--- | |
--[[ | |
*Command Runner Tool* | |
This tool is used to run shell commands on your system | |
--]] | |
local util = require("codecompanion.utils") | |
---@class CodeCompanion.Tool.CmdRunner: CodeCompanion.Agent.Tool | |
return { | |
name = "cmd_runner", | |
cmds = { | |
-- This is dynamically populated via the setup function | |
}, | |
schema = { | |
type = "function", | |
["function"] = { | |
name = "cmd_runner", | |
description = "Run shell commands on the user's system, sharing the output with the user before then sharing with you.", | |
parameters = { | |
type = "object", | |
properties = { | |
cmd = { | |
type = "string", | |
description = "The command to run, e.g. `pytest` or `make test`", | |
}, | |
flag = { | |
anyOf = { | |
{ type = "string" }, | |
{ type = "null" }, | |
}, | |
description = 'If running tests, set to `"testing"`; null otherwise', | |
}, | |
}, | |
required = { | |
"cmd", | |
"flag", | |
}, | |
additionalProperties = false, | |
}, | |
strict = true, | |
}, | |
}, | |
system_prompt = string.format( | |
[[# Command Runner Tool (`cmd_runner`) | |
## CONTEXT | |
- You have access to a command runner tool running within CodeCompanion, in Neovim. | |
- You can use it to run shell commands on the user's system. | |
- You may be asked to run a specific command or to determine the appropriate command to fulfil the user's request. | |
- All tool executions take place in the current working directory %s. | |
## OBJECTIVE | |
- Follow the tool's schema. | |
- Respond with a single command, per tool execution. | |
## RESPONSE | |
- Only invoke this tool when the user specifically asks. | |
- If the user asks you to run a specific command, do so to the letter, paying great attention. | |
- Use this tool strictly for command execution; but file operations must NOT be executed in this tool unless the user explicitly approves. | |
- To run multiple commands, you will need to call this tool multiple times. | |
## SAFETY RESTRICTIONS | |
- Never execute the following dangerous commands under any circumstances: | |
- `rm -rf /` or any variant targeting root directories | |
- `rm -rf ~` or any command that could wipe out home directories | |
- `rm -rf .` without specific context and explicit user confirmation | |
- Any command with `:(){:|:&};:` or similar fork bombs | |
- Any command that would expose sensitive information (keys, tokens, passwords) | |
- Commands that intentionally create infinite loops | |
- For any destructive operation (delete, overwrite, etc.), always: | |
1. Warn the user about potential consequences | |
2. Request explicit confirmation before execution | |
3. Suggest safer alternatives when available | |
- If unsure about a command's safety, decline to run it and explain your concerns | |
## POINTS TO NOTE | |
- This tool can be used alongside other tools within CodeCompanion | |
## USER ENVIRONMENT | |
- Shell: %s | |
- Operating System: %s | |
- Neovim Version: %s]], | |
vim.fn.getcwd(), | |
vim.o.shell, | |
util.os(), | |
vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch | |
), | |
handlers = { | |
---Turn the LLM's arguments into a command entry and queue it on `self.cmds`.
---The command string is split on single spaces; `flag` is only set when given.
---@param self CodeCompanion.Tool.CmdRunner
---@param agent CodeCompanion.Agent The tool object
setup = function(self, agent)
  local args = self.args
  local entry = {
    cmd = vim.split(args.cmd, " "),
    flag = args.flag or nil,
  }
  self.cmds[#self.cmds + 1] = entry
end,
}, | |
output = { | |
---Build the approval question shown to the user before the command runs
---@param self CodeCompanion.Tool.CmdRunner
---@param agent CodeCompanion.Agent
---@return string
prompt = function(self, agent)
  local command = table.concat(self.cmds[1].cmd, " ")
  return string.format("Run the command `%s`?", command)
end,
---Rejection message back to the LLM
---@param self CodeCompanion.Tool.CmdRunner
---@param agent CodeCompanion.Agent
---@param cmd table
---@return nil
rejected = function(self, agent, cmd)
  local command = table.concat(self.cmds[1].cmd, " ")
  -- This is a statement to the LLM, so it must not end with the question
  -- mark used by the approval prompt
  agent.chat:add_tool_output(
    self,
    string.format("The user rejected the execution of the command `%s`", command)
  )
end,
---Report a failed command to the chat: first the error output, then any
---regular output the command produced as a second message.
---@param self CodeCompanion.Tool.CmdRunner
---@param agent CodeCompanion.Agent
---@param cmd table
---@param stderr table The error output from the command
---@param stdout? table The output from the command
error = function(self, agent, cmd, stderr, stdout)
  local chat = agent.chat

  local function flatten(chunks)
    return vim.iter(chunks):flatten():join("\n")
  end

  local command = table.concat(cmd.cmd, " ")
  local error_text = flatten(stderr)
  chat:add_tool_output(
    self,
    string.format(
      [[**Cmd Runner Tool**: There was an error running the `%s` command:
```txt
%s
```]],
      command,
      error_text
    )
  )

  local has_output = stdout and not vim.tbl_isempty(stdout)
  if has_output then
    chat:add_tool_output(
      self,
      string.format(
        [[**Cmd Runner Tool**: There was also some output from the command:
```txt
%s
```]],
        flatten(stdout)
      )
    )
  end
end,
---Share the command's output with the chat.
---Only the last entry of `stdout` is reported; a missing or empty `stdout`
---produces a "no output" message instead.
---@param self CodeCompanion.Tool.CmdRunner
---@param agent CodeCompanion.Agent
---@param cmd table The command that was executed
---@param stdout table The output from the command
success = function(self, agent, cmd, stdout)
  local chat = agent.chat

  -- A nil stdout must be treated like an empty one, otherwise the indexing
  -- below raises
  if not stdout or vim.tbl_isempty(stdout) then
    return chat:add_tool_output(self, "There was no output from the cmd_runner tool")
  end

  local output = vim.iter(stdout[#stdout]):flatten():join("\n")
  local message = string.format(
    [[**Cmd Runner Tool**: The output from the command `%s` was:
```txt
%s
```]],
    table.concat(cmd.cmd, " "),
    output
  )
  chat:add_tool_output(self, message)
end,
}, | |
} | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/create_file.lua | |
--- | |
local Path = require("plenary.path") | |
local log = require("codecompanion.utils.log") | |
local fmt = string.format | |
---Create a new file (and any missing parent folders) below the current
---working directory and write the given content to it.
---Existing files and directories are never overwritten; they yield an error result.
---@param action {filepath: string, content: string} The action containing the filepath and content
---@return {status: "success"|"error", data: string}
local function create(action)
  local target = Path:new(vim.fs.joinpath(vim.fn.getcwd(), action.filepath))
  target.filename = target:expand()

  local function failure(message)
    return { status = "error", data = message }
  end

  if target:exists() then
    if target:is_dir() then
      return failure(fmt("**Create File Tool**: `%s` already exists as a directory", action.filepath))
    end
    return failure(fmt("**Create File Tool**: File `%s` already exists", action.filepath))
  end

  -- touch/write raise on failure, so run them protected and report the message
  local ok, err = pcall(function()
    target:touch({ parents = true })
    target:write(action.content, "w")
  end)
  if not ok then
    return failure(fmt("**Create File Tool**: Failed to create file `%s` - %s", action.filepath, err))
  end

  return {
    status = "success",
    data = fmt("**Create File Tool**: `%s` was created successfully", action.filepath),
  }
end
---@class CodeCompanion.Tool.CreateFile: CodeCompanion.Agent.Tool | |
return { | |
name = "create_file", | |
cmds = { | |
---Run the file creation with the arguments supplied by the LLM
---@param self CodeCompanion.Tool.CreateFile
---@param args table The arguments from the LLM's tool call
---@param input? any The output from the previous function call (unused)
---@return { status: "success"|"error", data: string }
function(self, args, input)
  local result = create(args)
  return result
end,
}, | |
schema = { | |
type = "function", | |
["function"] = { | |
name = "create_file", | |
description = "This is a tool for creating a new file on the user's machine. The file will be created with the specified content, creating any necessary parent directories.", | |
parameters = { | |
type = "object", | |
properties = { | |
filepath = { | |
type = "string", | |
description = "The relative path to the file to create, including its filename and extension.", | |
}, | |
content = { | |
type = "string", | |
description = "The content to write to the file.", | |
}, | |
}, | |
required = { | |
"filepath", | |
"content", | |
}, | |
}, | |
}, | |
}, | |
handlers = { | |
---Exit hook for the tool; it only writes a trace log entry
---@param agent CodeCompanion.Agent The tool object
---@return nil
on_exit = function(agent)
  local message = "[Create File Tool] on_exit handler executed"
  log:trace(message)
end,
}, | |
output = { | |
---Build the approval question shown to the user, with the path displayed
---relative to the current directory where possible
---@param self CodeCompanion.Agent.Tool
---@param agent CodeCompanion.Agent
---@return nil|string
prompt = function(self, agent)
  local display_path = vim.fn.fnamemodify(self.args.filepath, ":.")
  return fmt("Create a file at %s?", display_path)
end,
---Forward the tool's output, flattened and newline-joined, to the chat
---@param self CodeCompanion.Tool.CreateFile
---@param agent CodeCompanion.Agent
---@param cmd table The command that was executed
---@param stdout table The output from the command
success = function(self, agent, cmd, stdout)
  local joined = vim.iter(stdout):flatten():join("\n")
  agent.chat:add_tool_output(self, joined)
end,
---Log the error output and report it to the chat inside a text code fence
---@param self CodeCompanion.Tool.CreateFile
---@param agent CodeCompanion.Agent
---@param cmd table
---@param stderr table The error output from the command
---@param stdout? table The output from the command
error = function(self, agent, cmd, stderr, stdout)
  local chat = agent.chat
  local errors = vim.iter(stderr):flatten():join("\n")
  log:debug("[Create File Tool] Error output: %s", stderr)

  local error_output = fmt(
    [[**Create File Tool**: Ran with an error:
```txt
%s
```]],
    errors
  )
  chat:add_tool_output(self, error_output)
end,
---Tell the LLM that the user declined the file creation
---@param self CodeCompanion.Tool.CreateFile
---@param agent CodeCompanion.Agent
---@param cmd table
---@return nil
rejected = function(self, agent, cmd)
  local message = "**Create File Tool**: The user declined to execute"
  agent.chat:add_tool_output(self, message)
end,
}, | |
} | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/file_search.lua | |
--- | |
local log = require("codecompanion.utils.log") | |
local fmt = string.format | |
---Find files under the current working directory whose path, relative to
---that directory, matches a glob pattern.
---
---The result limit is taken from `action.max_results`, then `opts.max_results`,
---then defaults to 500. On success `data` is a list of paths (relative where
---possible), or a message string when nothing matched.
---@param action { query: string, max_results: number }
---@param opts table
---@return { status: "success"|"error", data: string|string[] }
local function search(action, opts)
  opts = opts or {}

  local query = action.query
  -- Default limit to prevent overwhelming results
  local max_results = action.max_results or opts.max_results or 500

  if not query or query == "" then
    return {
      status = "error",
      data = "Query parameter is required and cannot be empty",
    }
  end

  local cwd = vim.fn.getcwd()

  -- Compile the glob to an lpeg pattern; an invalid glob raises, hence pcall
  local ok, glob_pattern = pcall(vim.glob.to_lpeg, query)
  if not ok then
    return {
      status = "error",
      data = fmt("Invalid glob pattern '%s': %s", query, glob_pattern),
    }
  end

  -- Predicate for vim.fs.find: match the glob against the cwd-relative path
  local function matches(name, path)
    local relative_path = vim.fs.relpath(cwd, vim.fs.joinpath(path, name))
    if not relative_path then
      return false
    end
    return glob_pattern:match(relative_path) ~= nil
  end

  local found_files = vim.fs.find(matches, {
    limit = max_results,
    type = "file",
    path = cwd,
  })

  if #found_files == 0 then
    return {
      status = "success",
      data = fmt("No files found matching pattern '%s'", query),
    }
  end

  -- Report relative paths so the LLM doesn't have full knowledge of the
  -- filesystem; fall back to the path as found when it can't be relativized
  local relative_files = {}
  for idx, file in ipairs(found_files) do
    relative_files[idx] = vim.fs.relpath(cwd, file) or file
  end

  return {
    status = "success",
    data = relative_files,
  }
end
---@class CodeCompanion.Tool.FileSearch: CodeCompanion.Agent.Tool | |
return { | |
name = "file_search", | |
cmds = { | |
---Run the file search with the LLM's arguments and the tool's configured options
---@param self CodeCompanion.Tool.FileSearch
---@param args table The arguments from the LLM's tool call
---@param input? any The output from the previous function call (unused)
---@return { status: "success"|"error", data: string }
function(self, args, input)
  local tool_opts = self.tool.opts
  return search(args, tool_opts)
end,
}, | |
schema = { | |
type = "function", | |
["function"] = { | |
name = "file_search", | |
description = "Search for files in the workspace by glob pattern. This only returns the paths of matching files. Use this tool when you know the exact filename pattern of the files you're searching for. Glob patterns match from the root of the workspace folder. Examples:\n- **/*.{js,ts} to match all js/ts files in the workspace.\n- src/** to match all files under the top-level src folder.\n- **/foo/**/*.js to match all js files under any foo folder in the workspace.", | |
parameters = { | |
type = "object", | |
properties = { | |
query = { | |
type = "string", | |
description = "Search for files with names or paths matching this glob pattern.", | |
}, | |
max_results = { | |
type = "number", | |
description = "The maximum number of results to return. Do not use this unless necessary, it can slow things down. By default, only some matches are returned. If you use this and don't see what you're looking for, you can try again with a more specific query or a larger max_results.", | |
}, | |
}, | |
required = { | |
"query", | |
}, | |
}, | |
}, | |
}, | |
handlers = { | |
---@param agent CodeCompanion.Agent The tool object | |
---@return nil | |
on_exit = function(agent) | |
log:trace("[File Search Tool] on_exit handler executed") | |
end, | |
}, | |
output = { | |
---The message which is shared with the user when asking for their approval | |
---@param self CodeCompanion.Agent.Tool | |
---@param agent CodeCompanion.Agent | |
---@return nil|string | |
prompt = function(self, agent) | |
local args = self.args | |
local query = args.query or "" | |
return fmt("Search the cwd for %s?", query) | |
end, | |
---@param self CodeCompanion.Tool.FileSearch | |
---@param agent CodeCompanion.Agent | |
---@param cmd table The command that was executed | |
---@param stdout table The output from the command | |
success = function(self, agent, cmd, stdout) | |
local chat = agent.chat | |
local data = stdout[1] | |
local llm_output = "<fileSearchTool>%s</fileSearchTool>" | |
local user_message = "**File Search Tool**: %s" | |
local output = vim.iter(stdout):flatten():join("\n") | |
if type(data) == "table" then | |
-- Files were found - data is an array of file paths | |
local files = #data | |
chat:add_tool_output( | |
self, | |
fmt(llm_output, fmt("Returning %d files matching the query:\n%s", files, output)), | |
fmt(user_message, fmt("Returned %d files", files)) | |
) | |
else | |
-- No files found - data is a string message | |
chat:add_tool_output(self, fmt(llm_output, "No files found"), fmt(user_message, "No files found")) | |
end | |
end, | |
---@param self CodeCompanion.Tool.FileSearch | |
---@param agent CodeCompanion.Agent | |
---@param cmd table | |
---@param stderr table The error output from the command | |
---@param stdout? table The output from the command | |
error = function(self, agent, cmd, stderr, stdout) | |
local chat = agent.chat | |
local errors = vim.iter(stderr):flatten():join("\n") | |
log:debug("[File Search Tool] Error output: %s", stderr) | |
local error_output = fmt( | |
[[**File Search Tool**: Ran with an error: | |
```txt | |
%s | |
```]], | |
errors | |
) | |
chat:add_tool_output(self, error_output) | |
end, | |
---Rejection message back to the LLM | |
---@param self CodeCompanion.Tool.FileSearch | |
---@param agent CodeCompanion.Agent | |
---@param cmd table | |
---@return nil | |
rejected = function(self, agent, cmd) | |
local chat = agent.chat | |
chat:add_tool_output(self, "**File Search Tool**: The user declined to execute") | |
end, | |
}, | |
} | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/grep_search.lua | |
--- | |
local log = require("codecompanion.utils.log") | |
local fmt = string.format | |
---Search the current working directory for text using ripgrep
---
---The query is handed to `rg` after a `--` separator so that it is always
---treated as the search pattern and never parsed as a command line option.
---@param action { query: string, is_regexp: boolean?, include_pattern: string? }
---@param opts table Tool options: `max_results` (default 100) and `respect_gitignore` (default true)
---@return { status: "success"|"error", data: string|table } On success `data` is a list of "filename:line directory" entries, or a message string when nothing matched
local function grep_search(action, opts)
  opts = opts or {}

  local query = action.query
  if not query or query == "" then
    return {
      status = "error",
      data = "Query parameter is required and cannot be empty",
    }
  end

  -- Check if ripgrep is available
  if vim.fn.executable("rg") ~= 1 then
    return {
      status = "error",
      data = "ripgrep (rg) is not installed or not in PATH",
    }
  end

  local cwd = vim.fn.getcwd()
  local max_results = opts.max_results or 100
  local is_regexp = action.is_regexp or false
  -- Ignore files are respected unless the option is explicitly set to false
  local respect_gitignore = opts.respect_gitignore ~= false

  -- Use JSON output for structured parsing (one JSON object per output line)
  local cmd = { "rg", "--json", "--line-number", "--no-heading", "--with-filename" }

  -- Regex vs fixed string
  if not is_regexp then
    table.insert(cmd, "--fixed-strings")
  end

  -- Case sensitivity
  table.insert(cmd, "--ignore-case")

  -- Gitignore handling
  if not respect_gitignore then
    table.insert(cmd, "--no-ignore")
  end

  -- File pattern filtering
  if action.include_pattern and action.include_pattern ~= "" then
    table.insert(cmd, "--glob")
    table.insert(cmd, action.include_pattern)
  end

  -- Limit results per file - the total is limited in post-processing
  table.insert(cmd, "--max-count")
  table.insert(cmd, tostring(math.min(max_results, 50)))

  -- End of options: the query comes from the LLM, so a value starting with "-"
  -- must not be interpreted by ripgrep as a flag
  table.insert(cmd, "--")
  table.insert(cmd, query)

  -- Add the search path
  table.insert(cmd, cwd)

  log:debug("[Grep Search Tool] Running command: %s", table.concat(cmd, " "))

  -- Execute synchronously
  local result = vim
    .system(cmd, {
      text = true,
      timeout = 30000, -- 30 second timeout
    })
    :wait()

  if result.code ~= 0 then
    local error_msg = result.stderr or "Unknown error"
    if result.code == 1 then
      -- No matches found - this is not an error for ripgrep
      return {
        status = "success",
        data = "No matches found for the query",
      }
    elseif result.code == 2 then
      log:warn("[Grep Search Tool] Invalid arguments or regex: %s", error_msg)
      return {
        status = "error",
        -- Only the first line of stderr is passed back
        data = fmt("Invalid search pattern or arguments: %s", error_msg:match("^[^\n]*") or "Unknown error"),
      }
    else
      log:error("[Grep Search Tool] Command failed with code %d: %s", result.code, error_msg)
      return {
        status = "error",
        data = fmt("Search failed: %s", error_msg:match("^[^\n]*") or "Unknown error"),
      }
    end
  end

  local output = result.stdout or ""
  if output == "" then
    return {
      status = "success",
      data = "No matches found for the query",
    }
  end

  -- Parse JSON output from ripgrep
  local matches = {}
  for line in output:gmatch("[^\n]+") do
    if #matches >= max_results then
      break
    end

    local ok, json_data = pcall(vim.json.decode, line)
    if ok and type(json_data) == "table" and json_data.type == "match" then
      local match_data = type(json_data.data) == "table" and json_data.data or {}
      local file_path = type(match_data.path) == "table" and match_data.path.text or nil
      local line_number = match_data.line_number

      -- Entries without a textual path or a line number cannot be formatted, so they are skipped
      if type(file_path) == "string" and type(line_number) == "number" then
        -- Convert absolute path to relative path from cwd
        local relative_path = vim.fs.relpath(cwd, file_path) or file_path

        -- Extract just the filename and directory
        local filename = vim.fn.fnamemodify(relative_path, ":t")
        local dir_path = vim.fn.fnamemodify(relative_path, ":h")

        -- Format: "filename:line directory_path"
        table.insert(matches, fmt("%s:%d %s", filename, line_number, dir_path == "." and "" or dir_path))
      end
    end
  end

  if #matches == 0 then
    return {
      status = "success",
      data = "No matches found for the query",
    }
  end

  return {
    status = "success",
    data = matches,
  }
end
---@class CodeCompanion.Tool.GrepSearch: CodeCompanion.Agent.Tool
return {
  name = "grep_search",
  cmds = {
    ---Run the text search requested by the LLM
    ---@param self CodeCompanion.Tool.GrepSearch
    ---@param args table The arguments from the LLM's tool call
    ---@param input? any The output from the previous function call (not used)
    ---@return { status: "success"|"error", data: string|table }
    function(self, args, input)
      local tool_opts = self.tool.opts
      return grep_search(args, tool_opts)
    end,
  },
  schema = {
    ["function"] = {
      name = "grep_search",
      description = "Do a text search in the workspace. Use this tool when you know the exact string you're searching for.",
      parameters = {
        type = "object",
        properties = {
          query = {
            type = "string",
            description = "The pattern to search for in files in the workspace. Can be a regex or plain text pattern",
          },
          is_regexp = {
            type = "boolean",
            description = "Whether the pattern is a regex. False by default.",
          },
          include_pattern = {
            type = "string",
            description = "Search files matching this glob pattern. Will be applied to the relative path of files within the workspace.",
          },
        },
        required = {
          "query",
        },
      },
    },
    type = "function",
  },
  handlers = {
    ---Called when the tool has finished; only writes a trace log entry
    ---@param agent CodeCompanion.Agent The tool object
    ---@return nil
    on_exit = function(agent)
      log:trace("[Grep Search Tool] on_exit handler executed")
    end,
  },
  output = {
    ---Build the approval question shown to the user
    ---@param self CodeCompanion.Agent.Tool
    ---@param agent CodeCompanion.Agent
    ---@return nil|string
    prompt = function(self, agent)
      local pattern = self.args.query or ""
      return fmt("Perform a grep search for %s?", pattern)
    end,

    ---Report the search result to the chat buffer
    ---@param self CodeCompanion.Tool.GrepSearch
    ---@param agent CodeCompanion.Agent
    ---@param cmd table The command that was executed
    ---@param stdout table The output from the command
    success = function(self, agent, cmd, stdout)
      local chat = agent.chat
      local first = stdout[1]
      local wrap_for_llm = [[<grepSearchTool>%s
NOTE: The output format is {filename}:{line number} {filepath}. For example:
init.lua:335 lua/codecompanion/strategies/chat/agents
Refers to line 335 of the init.lua file in the lua/codecompanion/strategies/chat/agents directory.
</grepSearchTool>]]
      local wrap_for_user = "**Grep Search Tool**: %s"
      local listing = vim.iter(stdout):flatten():join("\n")

      if type(first) ~= "table" then
        -- A plain string message means that nothing matched
        chat:add_tool_output(self, fmt(wrap_for_llm, "No results found"), fmt(wrap_for_user, "No results found"))
        return
      end

      -- A table is the array of formatted match entries
      local total = #first
      chat:add_tool_output(
        self,
        fmt(wrap_for_llm, fmt("Returning %d results matching the query:\n%s", total, listing)),
        fmt(wrap_for_user, fmt("Returned %d results", total))
      )
    end,

    ---Report a failed search to the chat buffer
    ---@param self CodeCompanion.Tool.GrepSearch
    ---@param agent CodeCompanion.Agent
    ---@param cmd table
    ---@param stderr table The error output from the command
    ---@param stdout? table The output from the command
    error = function(self, agent, cmd, stderr, stdout)
      local chat = agent.chat
      local joined = vim.iter(stderr):flatten():join("\n")
      log:debug("[Grep Search Tool] Error output: %s", stderr)

      local template = "**Grep Search Tool**: Ran with an error:\n```txt\n%s\n```"
      chat:add_tool_output(self, fmt(template, joined))
    end,

    ---Rejection message back to the LLM
    ---@param self CodeCompanion.Tool.GrepSearch
    ---@param agent CodeCompanion.Agent
    ---@param cmd table
    ---@return nil
    rejected = function(self, agent, cmd)
      agent.chat:add_tool_output(self, "**Grep Search Tool**: The user declined to execute")
    end,
  },
}
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/insert_edit_into_file.lua | |
--- | |
local Path = require("plenary.path") | |
local buffers = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
local diff = require("codecompanion.strategies.chat.agents.tools.helpers.diff") | |
local log = require("codecompanion.utils.log") | |
local patch = require("codecompanion.strategies.chat.agents.tools.helpers.patch") | |
local ui = require("codecompanion.utils.ui") | |
local wait = require("codecompanion.strategies.chat.agents.tools.helpers.wait") | |
local api = vim.api | |
local fmt = string.format | |
local PROMPT = [[<editFileInstructions> | |
Before editing a file, ensure you have its content via the provided context or read_file tool. | |
Use the insert_edit_into_file tool to modify files. | |
NEVER show the code edits to the user - only call the tool. The system will apply and display the changes. | |
For each file, give a short description of what needs to be changed, then use the insert_edit_into_file tools. You can use the tool multiple times in a response, and you can keep writing text after using a tool. | |
The insert_edit_into_file tool is very smart and can understand how to apply your edits to the user's files, you just need to follow the patch format instructions carefully and to the letter. | |
## Patch Format | |
]] .. patch.FORMAT_PROMPT .. [[ | |
The system uses fuzzy matching and confidence scoring so focus on providing enough context to uniquely identify the location. | |
</editFileInstructions>]] | |
---Edit code in a file on disk
---
---Every failure is raised with `error()`; the caller is expected to run this
---function through `pcall` and turn a raised message into an error result.
---@param action {filepath: string, code: string, explanation: string} The arguments from the LLM's tool call
---@return string message A confirmation message describing the edit
local function edit_file(action)
  local filepath = vim.fs.joinpath(vim.fn.getcwd(), action.filepath)
  local p = Path:new(filepath)
  p.filename = p:expand()

  if not p:exists() or not p:is_file() then
    -- Raise instead of returning the message: a returned string would be
    -- reported by the caller as a successful edit. Level 0 keeps the message
    -- free of a "file:line:" prefix.
    error(fmt("**Insert Edit Into File Tool Error**: File '%s' does not exist or is not a file", action.filepath), 0)
  end

  -- 1. extract list of changes from the code
  local raw = action.code or ""
  local changes, had_begin_end_markers = patch.parse_changes(raw)

  -- 2. read file into lines
  local content = p:read()
  local lines = vim.split(content, "\n", { plain = true })

  -- 3. apply changes; the first change that cannot be applied aborts the edit
  --    before anything is written back
  for _, change in ipairs(changes) do
    local new_lines = patch.apply_change(lines, change)
    if new_lines == nil then
      if had_begin_end_markers then
        error(fmt("Bad/Incorrect diff:\n\n%s\n\nNo changes were applied", patch.get_change_string(change)))
      else
        error("Invalid patch format: missing Begin/End markers")
      end
    else
      lines = new_lines
    end
  end

  -- 4. write back
  p:write(table.concat(lines, "\n"), "w")

  -- 5. refresh the buffer if the file is open
  local bufnr = vim.fn.bufnr(p.filename)
  if bufnr ~= -1 and api.nvim_buf_is_loaded(bufnr) then
    api.nvim_command("checktime " .. bufnr)
  end

  return fmt("**Insert Edit Into File Tool**: `%s` - %s", action.filepath, action.explanation)
end
---Apply a patch from the LLM to a buffer
---
---A diff is set up first (when `diff.should_create` allows it), then the patch
---is applied to the buffer's lines. When a diff exists and `opts.user_confirmation`
---is set, the result is only reported once the user has accepted or rejected it.
---@param bufnr number The buffer number to edit
---@param chat_bufnr number The chat buffer number
---@param action {filepath: string, code: string, explanation: string} The arguments from the LLM's tool call
---@param output_handler function The callback to call when done
---@param opts? table Additional options
---@return any result Whatever `output_handler` or `wait.for_decision` returns
local function edit_buffer(bufnr, chat_bufnr, action, output_handler, opts)
  opts = opts or {}

  local diff_id = math.random(10000000)
  local active_diff
  if diff.should_create(bufnr) then
    active_diff = diff.create(bufnr, diff_id)
  end

  -- Parse the patch and fetch the current buffer content
  local changes, had_begin_end_markers = patch.parse_changes(action.code or "")
  local lines = api.nvim_buf_get_lines(bufnr, 0, -1, false)

  -- Apply the changes one by one, remembering where the first one landed
  local first_change_line = nil
  for _, change in ipairs(changes) do
    local patched = patch.apply_change(lines, change)
    if patched == nil and had_begin_end_markers then
      error(fmt("Bad/Incorrect diff:\n\n%s\n\nNo changes were applied", patch.get_change_string(change)))
    elseif patched == nil then
      error("Invalid patch format: missing Begin/End markers")
    end

    if not first_change_line then
      first_change_line = patch.get_change_location(lines, change)
    end
    lines = patched
  end

  -- Write the patched lines back and show the edited location
  api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
  if first_change_line then
    ui.scroll_to_line(bufnr, first_change_line)
  end

  -- Auto-save if enabled
  if vim.g.codecompanion_auto_tool_mode then
    log:info("[Insert Edit Into File Tool] Auto-saving buffer")
    api.nvim_buf_call(bufnr, function()
      vim.cmd("silent write")
    end)
  end

  local success = {
    status = "success",
    data = fmt("**Insert Edit Into File Tool**: `%s` - %s", action.filepath, action.explanation),
  }

  -- Without a diff, or without user confirmation, the edit is reported right away
  if not (active_diff and opts.user_confirmation) then
    return output_handler(success)
  end

  local keymaps = config.strategies.inline.keymaps
  local accept = keymaps.accept_change.modes.n
  local reject = keymaps.reject_change.modes.n

  -- Wait for the user to accept or reject the edit
  return wait.for_decision(diff_id, { "CodeCompanionDiffAccepted", "CodeCompanionDiffRejected" }, function(result)
    if result.accepted then
      return output_handler(success)
    end

    local reason
    if result.timeout then
      reason = "User failed to accept the changes in time"
    else
      reason = "User rejected the changes"
    end
    return output_handler({ status = "error", data = reason })
  end, {
    chat_bufnr = chat_bufnr,
    notify = config.display.icons.warning .. " Waiting for diff approval ...",
    sub_text = fmt("`%s` - Accept changes / `%s` - Reject changes", accept, reject),
  })
end
---@class CodeCompanion.Tool.InsertEditIntoFile: CodeCompanion.Agent.Tool
return {
  name = "insert_edit_into_file",
  cmds = {
    ---Execute the edit commands
    ---
    ---Files that are open in a buffer are edited through the buffer, all
    ---others are edited on disk.
    ---@param self CodeCompanion.Agent
    ---@param args table The arguments from the LLM's tool call
    ---@param input? any The output from the previous function call
    ---@param output_handler function Async callback for completion
    ---@return nil
    function(self, args, input, output_handler)
      local bufnr = buffers.get_bufnr_from_filepath(args.filepath)
      if bufnr then
        return edit_buffer(bufnr, self.chat.bufnr, args, output_handler, self.tool.opts)
      else
        -- edit_file raises on failure, so run it protected
        local ok, outcome = pcall(edit_file, args)
        if not ok then
          return output_handler({ status = "error", data = outcome })
        end
        return output_handler({ status = "success", data = outcome })
      end
    end,
  },
  schema = {
    type = "function",
    ["function"] = {
      name = "insert_edit_into_file",
      description = "Insert new code or modify existing code in a file. Use this tool once per file that needs to be modified, even if there are multiple changes for a file. The system is very smart and can understand how to apply your edits to the user's files if you follow the instructions.",
      parameters = {
        type = "object",
        properties = {
          explanation = {
            type = "string",
            description = "A short explanation of the code edit being made",
          },
          filepath = {
            type = "string",
            description = "The path to the file to edit, including its filename and extension",
          },
          code = {
            type = "string",
            description = "The code which follows the patch format",
          },
        },
        required = {
          "explanation",
          "filepath",
          "code",
        },
        additionalProperties = false,
      },
      strict = true,
    },
  },
  system_prompt = PROMPT,
  handlers = {
    ---The handler to determine whether to prompt the user for approval
    ---
    ---Reads `requires_approval.buffer` or `requires_approval.file` from the
    ---tool's options, depending on whether the target file is open in a buffer.
    ---@param self CodeCompanion.Tool.InsertEditIntoFile
    ---@param agent CodeCompanion.Agent
    ---@param config table The tool configuration
    ---@return boolean
    prompt_condition = function(self, agent, config)
      local tool_config = config["insert_edit_into_file"] or {}
      local opts = tool_config.opts or {}
      local requires_approval = opts.requires_approval

      -- A plain boolean applies to buffers and files alike
      if type(requires_approval) == "boolean" then
        return requires_approval
      end
      -- Without usable approval settings, err on the side of asking the user
      if type(requires_approval) ~= "table" then
        return true
      end

      local bufnr = buffers.get_bufnr_from_filepath(self.args.filepath)
      if bufnr then
        return requires_approval.buffer and true or false
      end
      return requires_approval.file and true or false
    end,
    ---@param agent CodeCompanion.Agent The tool object
    ---@return nil
    on_exit = function(agent)
      log:trace("[Insert Edit Into File Tool] on_exit handler executed")
    end,
  },
  output = {
    ---The message which is shared with the user when asking for their approval
    ---@param self CodeCompanion.Tool.InsertEditIntoFile
    ---@param agent CodeCompanion.Agent
    ---@return nil|string
    prompt = function(self, agent)
      local args = self.args
      -- ":." makes the path relative to the current directory where possible
      local filepath = vim.fn.fnamemodify(args.filepath, ":.")
      return fmt("Edit the file at %s?", filepath)
    end,
    ---Pass the outcome of the edit on to the chat buffer
    ---@param self CodeCompanion.Tool.InsertEditIntoFile
    ---@param agent CodeCompanion.Agent
    ---@param cmd table The command that was executed
    ---@param stdout table The output from the command
    success = function(self, agent, cmd, stdout)
      local llm_output = vim.iter(stdout):flatten():join("\n")
      agent.chat:add_tool_output(self, llm_output)
    end,
    ---Report a failed edit to the chat buffer
    ---@param self CodeCompanion.Tool.InsertEditIntoFile
    ---@param agent CodeCompanion.Agent
    ---@param cmd table
    ---@param stderr table The error output from the command
    ---@param stdout? table The output from the command
    error = function(self, agent, cmd, stderr, stdout)
      local chat = agent.chat
      local errors = vim.iter(stderr):flatten():join("\n")
      log:debug("[Insert Edit Into File Tool] Error output: %s", stderr)
      local error_output = fmt(
        [[**Insert Edit Into File Tool**: Ran with an error:
```txt
%s
```]],
        errors
      )
      chat:add_tool_output(self, error_output)
    end,
    ---Rejection message back to the LLM
    ---@param self CodeCompanion.Tool.InsertEditIntoFile
    ---@param agent CodeCompanion.Agent
    ---@param cmd table
    ---@return nil
    rejected = function(self, agent, cmd)
      local chat = agent.chat
      chat:add_tool_output(self, "**Insert Edit Into File Tool**: The user declined to execute")
    end,
  },
}
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/next_edit_suggestion.lua | |
--- | |
local log = require("codecompanion.utils.log") | |
---@class CodeCompanion.Tool.NextEditSuggestion.Args
---@field filepath string
---@field line integer
---@alias jump_action fun(path: string):integer?
---@class CodeCompanion.Tool.NextEditSuggestion: CodeCompanion.Agent.Tool
return {
  opts = {
    ---Either a function that opens the path and returns a window id, or the
    ---name of an Ex command (e.g. "edit") that is run with the path as argument
    ---@type jump_action|string
    jump_action = require("codecompanion.utils.ui").tabnew_reuse,
  },
  name = "next_edit_suggestion",
  schema = {
    type = "function",
    ["function"] = {
      name = "next_edit_suggestion",
      description = "Suggest a possible position in a file for the next edit.",
      parameters = {
        type = "object",
        properties = {
          filepath = {
            type = "string",
            description = "The relative path to the file to edit, including its filename and extension.",
          },
          line = {
            type = "integer",
            description = "Line number for the next edit (0-based). Use -1 if you're not sure about it.",
          },
        },
        required = { "filepath", "line" },
        additionalProperties = false,
      },
      strict = true,
    },
  },
  system_prompt = function(_)
    return [[# Next Edit Suggestion Tool
## CONTEXT
When you suggest a change to the codebase, you may call this tool to jump to the position in the file.
## OBJECTIVE
- Follow the tool's schema.
- Respond with a single command, per tool execution.
## RESPONSE
- Only use this tool when you have been given paths to the files
- DO NOT make up paths that you are not given
- Only use this tool when there's an unambiguous position to jump to
- If there are multiple possible edits, ask the users to make a choice before jumping
- Pass -1 as the line number if you are not sure about the correct line number
- Consider the paths as **CASE SENSITIVE**
]]
  end,
  cmds = {
    ---Open the requested file and, when a line is given, move the cursor there
    ---@param self CodeCompanion.Agent
    ---@param args CodeCompanion.Tool.NextEditSuggestion.Args
    ---@return {status: "success"|"error", data: string}
    function(self, args, _)
      -- A path that is not a string cannot be normalized or stat'ed
      if type(args.filepath) ~= "string" then
        log:error("failed to jump to %s", args.filepath)
        return { status = "error", data = "Invalid path: " .. tostring(args.filepath) }
      end
      args.filepath = vim.fs.normalize(args.filepath)

      local stat = vim.uv.fs_stat(args.filepath)
      if stat == nil or stat.type ~= "file" then
        log:error("failed to jump to %s", args.filepath)
        if stat then
          log:error("file stat:\n%s", vim.inspect(stat))
        end
        return { status = "error", data = "Invalid path: " .. tostring(args.filepath) }
      end

      -- A command name is wrapped once into a function and stored back on the
      -- tool options, so later calls reuse it
      if type(self.tool.opts.jump_action) == "string" then
        local action_command = self.tool.opts.jump_action
        ---@type jump_action
        self.tool.opts.jump_action = function(path)
          -- Escape the path so spaces and special characters survive the Ex command line
          vim.cmd(action_command .. " " .. vim.fn.fnameescape(path))
          return vim.api.nvim_get_current_win()
        end
      end

      local winnr = self.tool.opts.jump_action(args.filepath)

      -- A negative (or missing) line means "no specific position"
      if type(args.line) == "number" and args.line >= 0 and winnr then
        -- The tool takes 0-based lines, nvim_win_set_cursor takes 1-based rows
        local ok = pcall(vim.api.nvim_win_set_cursor, winnr, { args.line + 1, 0 })
        if not ok then
          local bufnr = vim.api.nvim_win_get_buf(winnr)
          return {
            status = "error",
            data = string.format(
              "The jump to the file was successful, but This file only has %d lines. Unable to jump to line %d",
              vim.api.nvim_buf_line_count(bufnr),
              args.line
            ),
          }
        end
      end

      return { status = "success", data = "Jump successful!" }
    end,
  },
}
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/read_file.lua | |
--- | |
local Path = require("plenary.path") | |
local log = require("codecompanion.utils.log") | |
local fmt = string.format | |
---Read a range of lines from a file
---
---The path is resolved against the current working directory. Line numbers
---are 0-based and the end line is inclusive; an end line of -1 means "until
---the end of the file".
---@param action {filepath: string, start_line_number_base_zero: number, end_line_number_base_zero: number} The action containing the filepath
---@return {status: "success"|"error", data: string}
local function read(action)
  local filepath = vim.fs.joinpath(vim.fn.getcwd(), action.filepath)
  local p = Path:new(filepath)
  p.filename = p:expand()

  local exists, _ = p:exists()
  if not exists then
    return {
      status = "error",
      data = fmt("**Read File Tool**: File `%s` does not exist", action.filepath),
    }
  end

  -- Anything that is not a regular file (e.g. a directory) cannot be read line by line
  if not p:is_file() then
    return {
      status = "error",
      data = fmt("**Read File Tool**: `%s` is not a file", action.filepath),
    }
  end

  local lines = p:readlines()

  local start_line_zero = tonumber(action.start_line_number_base_zero)
  local end_line_zero = tonumber(action.end_line_number_base_zero)

  -- Validate the requested range; the first failing check wins
  local error_msg = nil
  if not start_line_zero then
    error_msg =
      fmt("start_line_number_base_zero must be a valid number, got: %s", tostring(action.start_line_number_base_zero))
  elseif not end_line_zero then
    error_msg =
      fmt("end_line_number_base_zero must be a valid number, got: %s", tostring(action.end_line_number_base_zero))
  elseif start_line_zero % 1 ~= 0 or end_line_zero % 1 ~= 0 then
    -- Fractional line numbers would index between lines and silently return nothing
    error_msg = fmt(
      "line numbers must be whole numbers, got: %s and %s",
      tostring(action.start_line_number_base_zero),
      tostring(action.end_line_number_base_zero)
    )
  elseif start_line_zero < 0 then
    error_msg = fmt("start_line_number_base_zero cannot be negative, got: %d", start_line_zero)
  elseif end_line_zero < -1 then
    error_msg = fmt("end_line_number_base_zero cannot be less than -1, got: %d", end_line_zero)
  elseif start_line_zero >= #lines then
    error_msg = fmt(
      "start_line_number_base_zero (%d) is beyond file length. File `%s` has %d lines (0-%d)",
      start_line_zero,
      action.filepath,
      #lines,
      math.max(0, #lines - 1)
    )
  elseif end_line_zero ~= -1 and end_line_zero >= #lines then
    error_msg = fmt(
      "end_line_number_base_zero (%d) is beyond file length. File `%s` has %d lines (0-%d)",
      end_line_zero,
      action.filepath,
      #lines,
      math.max(0, #lines - 1)
    )
  elseif end_line_zero ~= -1 and start_line_zero > end_line_zero then
    error_msg = fmt(
      "Invalid line range - start_line_number_base_zero (%d) comes after end_line_number_base_zero (%d)",
      start_line_zero,
      end_line_zero
    )
  end

  if error_msg then
    return {
      status = "error",
      data = "**Read File Tool**: " .. error_msg,
    }
  end

  -- Convert to 1-based indexing
  local start_line = start_line_zero + 1
  local end_line = end_line_zero == -1 and #lines or end_line_zero + 1

  -- Extract the specified lines
  local selected_lines = {}
  for i = start_line, end_line do
    table.insert(selected_lines, lines[i])
  end

  local content = table.concat(selected_lines, "\n")
  local file_ext = vim.fn.fnamemodify(p.filename, ":e")

  -- The header repeats the range as it was requested (0-based, -1 for "end of file")
  local output = fmt(
    [[**Read File Tool**: Lines %d to %d of `%s`:
````%s
%s
````]],
    start_line_zero,
    end_line_zero,
    filepath,
    file_ext,
    content
  )

  return {
    status = "success",
    data = output,
  }
end
---@class CodeCompanion.Tool.ReadFile: CodeCompanion.Agent.Tool
return {
  name = "read_file",
  cmds = {
    ---Read the line range requested by the LLM
    ---@param self CodeCompanion.Tool.ReadFile
    ---@param args table The arguments from the LLM's tool call
    ---@param input? any The output from the previous function call (not used)
    ---@return { status: "success"|"error", data: string }
    function(self, args, input)
      local result = read(args)
      return result
    end,
  },
  schema = {
    type = "function",
    ["function"] = {
      name = "read_file",
      description = "Read the contents of a file.\n\nYou must specify the line range you're interested in. If the file contents returned are insufficient for your task, you may call this tool again to retrieve more content.",
      parameters = {
        type = "object",
        properties = {
          filepath = {
            type = "string",
            description = "The relative path to the file to read, including its filename and extension.",
          },
          start_line_number_base_zero = {
            type = "number",
            description = "The line number to start reading from, 0-based.",
          },
          end_line_number_base_zero = {
            type = "number",
            description = "The inclusive line number to end reading at, 0-based. Use -1 to read until the end of the file.",
          },
        },
        required = {
          "filepath",
          "start_line_number_base_zero",
          "end_line_number_base_zero",
        },
      },
    },
  },
  handlers = {
    ---Called when the tool has finished; only writes a trace log entry
    ---@param agent CodeCompanion.Agent The tool object
    ---@return nil
    on_exit = function(agent)
      log:trace("[Read File Tool] on_exit handler executed")
    end,
  },
  output = {
    ---Build the approval question shown to the user
    ---@param self CodeCompanion.Agent.Tool
    ---@param agent CodeCompanion.Agent
    ---@return nil|string
    prompt = function(self, agent)
      -- ":." makes the path relative to the current directory where possible
      local shown_path = vim.fn.fnamemodify(self.args.filepath, ":.")
      return fmt("Read %s?", shown_path)
    end,

    ---Pass the file content on to the chat buffer
    ---@param self CodeCompanion.Tool.ReadFile
    ---@param agent CodeCompanion.Agent
    ---@param cmd table The command that was executed
    ---@param stdout table The output from the command
    success = function(self, agent, cmd, stdout)
      local chat = agent.chat
      local joined = vim.iter(stdout):flatten():join("\n")
      chat:add_tool_output(self, joined)
    end,

    ---Report a failed read to the chat buffer
    ---@param self CodeCompanion.Tool.ReadFile
    ---@param agent CodeCompanion.Agent
    ---@param cmd table
    ---@param stderr table The error output from the command
    ---@param stdout? table The output from the command
    error = function(self, agent, cmd, stderr, stdout)
      local chat = agent.chat
      local joined = vim.iter(stderr):flatten():join("\n")
      log:debug("[Read File Tool] Error output: %s", stderr)

      local template = "**Read File Tool**: Ran with an error:\n```txt\n%s\n```"
      chat:add_tool_output(self, fmt(template, joined))
    end,

    ---Rejection message back to the LLM
    ---@param self CodeCompanion.Tool.ReadFile
    ---@param agent CodeCompanion.Agent
    ---@param cmd table
    ---@return nil
    rejected = function(self, agent, cmd)
      agent.chat:add_tool_output(self, "**Read File Tool**: The user declined to execute")
    end,
  },
}
--- | |
File: /lua/codecompanion/strategies/chat/agents/tools/web_search.lua | |
--- | |
local adapters = require("codecompanion.adapters") | |
local client = require("codecompanion.http") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local fmt = string.format | |
---@class CodeCompanion.Tool.WebSearch: CodeCompanion.Agent.Tool | |
return { | |
name = "web_search", | |
cmds = { | |
---Run a web search through the configured adapter
---
---The result is always delivered through `cb`; every failure path passes a
---`data` message along with the "error" status.
---@param self CodeCompanion.Agent The Editor tool
---@param args table The arguments from the LLM's tool call
---@param cb function Callback for asynchronous calls
---@return nil|{ status: "success"|"error", data: string }
function(self, args, _, cb)
  if not self.tool then
    log:error("There is no tool configured for the Agent")
    return cb({ status = "error", data = "There is no tool configured for the Agent" })
  end

  local opts = self.tool.opts
  if not opts then
    log:error("There is no adapter configured for the `web_search` Tool")
    return cb({ status = "error", data = "There is no adapter configured for the `web_search` Tool" })
  end

  -- The query must be a string, otherwise the gsub below would raise
  if not args or type(args.query) ~= "string" then
    log:error("There was no search query provided for the `web_search` Tool")
    return cb({ status = "error", data = "There was no search query provided for the `web_search` Tool" })
  end

  -- Drop the first stand-alone occurrence of the tool's own name from the query
  args.query = string.gsub(args.query, "%f[%w_]web_search%f[^%w_]", "", 1)

  local tool_adapter = config.strategies.chat.tools.web_search.opts.adapter
  local adapter = adapters.resolve(config.adapters[tool_adapter])
  if not adapter then
    log:error("Failed to load the adapter for the web_search Tool")
    return cb({ status = "error", data = "Failed to load the adapter for the web_search Tool" })
  end

  client
    .new({
      adapter = adapter,
    })
    :request({
      url = adapter.url,
      query = args.query,
    }, {
      callback = function(err, data)
        if err then
          log:error("Web Search Tool failed to fetch the URL, with error %s", err)
          return cb({
            status = "error",
            data = "Web Search Tool failed to fetch the URL, with error " .. tostring(err),
          })
        end

        -- No response object at all: there is no status to report
        if not data then
          log:error("Web Search Tool Error: No data received")
          return cb({
            status = "error",
            data = "Web Search Tool Error: No data received, status unknown",
          })
        end

        local http_ok, body = pcall(vim.json.decode, data.body)
        if not http_ok then
          log:error("Web Search Tool Could not parse the JSON response")
          return cb({ status = "error", data = "Web Search Tool Could not parse the JSON response" })
        end

        if data.status == 200 then
          local output = adapter.methods.tools.web_search.output(adapter, body)
          return cb({ status = "success", data = output })
        end

        log:error("Error %s - %s", data.status, body)
        -- The decoded body is usually a table, so render it readably for the message
        return cb({
          status = "error",
          data = fmt("Web Search Tool Error %s - %s", data.status, vim.inspect(body)),
        })
      end,
    })
end,
}, | |
schema = { | |
type = "function", | |
["function"] = { | |
name = "web_search", | |
description = "Search for recent information on the web", | |
parameters = { | |
type = "object", | |
properties = { | |
query = { | |
type = "string", | |
description = "Search query optimized for keyword searching.", | |
}, | |
}, | |
required = { "query" }, | |
additionalProperties = false, | |
}, | |
strict = true, | |
}, | |
}, | |
system_prompt = [[# Web Search Tool (`web_search`) | |
## CONTEXT | |
- You are connected to a Neovim instance via CodeCompanion. | |
- Using this tool you can search for recent information on the web. | |
- The user will allow this tool to be executed, so you do not need to ask for permission. | |
## OBJECTIVE | |
- Invoke this tool when up to date information is required. | |
## RESPONSE | |
- Return a single JSON-based function call matching the schema. | |
## POINTS TO NOTE | |
- This tool can be used alongside other tools within CodeCompanion. | |
- To make a web search, you can provide a search string optimized for keyword searching. | |
- Carefully craft your websearch to retrieve relevant and up to date information. | |
]], | |
output = { | |
---Share successful search results with the chat
---@param self CodeCompanion.Tool.Files
---@param agent CodeCompanion.Agent
---@param cmd table The arguments of the call; `cmd.query` is echoed back to the user
---@param output string[][] -- The chat_output returned from the adapter will be in the first position in the table
success = function(self, agent, cmd, output)
  local chat = agent.chat
  local results = output[1]

  -- Default to the outer table's size; prefer the first entry when it is a list
  local count, text = #output, ""
  if type(results) == "table" then
    text = table.concat(results, "")
    count = #results
  end

  local summary = fmt([[**Web Search Tool**: Returned %d results for the query "%s"]], count, cmd.query)
  chat:add_tool_output(self, text, summary)
end,
---Tell the chat that the search failed for the given query
---@param self CodeCompanion.Tool.Files
---@param agent CodeCompanion.Agent
---@param stderr table The error output from the command
error = function(self, agent, _, stderr, _)
  local chat = agent.chat
  local args = self.args
  log:debug("[Web Search Tool] Error output: %s", stderr)

  -- The query is upper-cased in the message shown to the chat
  local shouted_query = string.upper(args.query)
  chat:add_tool_output(
    self,
    fmt([[**Web Search Tool**: There was an error for the following query: `%s`]], shouted_query)
  )
end,
}, | |
} | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/init.lua | |
--- | |
---@class CodeCompanion.Agent | |
---@field tools_config table The available tools for the agent | |
---@field aug number The augroup for the tool | |
---@field bufnr number The buffer of the chat buffer | |
---@field constants table<string, string> The constants for the tool | |
---@field chat CodeCompanion.Chat The chat buffer that initiated the tool | |
---@field extracted table The extracted tools from the LLM's response | |
---@field messages table The messages in the chat buffer | |
---@field status string The status of the tool | |
---@field stdout table The stdout of the tool | |
---@field stderr table The stderr of the tool | |
---@field tool CodeCompanion.Agent.Tool The current tool that's being run | |
---@field tools_ns integer The namespace for the virtual text that appears in the header | |
local Executor = require("codecompanion.strategies.chat.agents.executor") | |
local ToolFilter = require("codecompanion.strategies.chat.agents.tool_filter") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local regex = require("codecompanion.utils.regex") | |
local ui = require("codecompanion.utils.ui") | |
local util = require("codecompanion.utils") | |
local api = vim.api | |
-- Whether a "processing" notification is shown in the chat buffer while tools run
local show_tools_processing = config.display.chat.show_tools_processing

-- Values shared by every Agent (also exposed on instances as `constants`)
local CONSTANTS = {
  -- Character that precedes a tool or group name in a message
  PREFIX = "@",

  -- Base names for the namespace and augroup created per agent
  NS_TOOLS = "CodeCompanion-agents",
  AUTOCMD_GROUP = "codecompanion.agent",

  -- Values compared against `Agent.status`
  STATUS_ERROR = "error",
  STATUS_SUCCESS = "success",

  -- Text of the notification shown while tools are running
  PROCESSING_MSG = config.display.icons.loading .. " Tools processing ...",
}
---@class CodeCompanion.Agent
local Agent = {}

---Create an Agent bound to a chat buffer
---@param args table Reads `args.bufnr` and `args.messages`
---@return CodeCompanion.Agent
function Agent.new(args)
  local bufnr = args.bufnr

  local fields = {
    -- One augroup per buffer, cleared on creation
    aug = api.nvim_create_augroup(CONSTANTS.AUTOCMD_GROUP .. ":" .. bufnr, { clear = true }),
    bufnr = bufnr,
    chat = {},
    constants = CONSTANTS,
    extracted = {},
    messages = args.messages,
    stdout = {},
    stderr = {},
    tool = {},
    -- Only tools that pass the enabled filter are visible to the agent
    tools_config = ToolFilter.filter_enabled_tools(config.strategies.chat.tools),
    tools_ns = api.nvim_create_namespace(CONSTANTS.NS_TOOLS),
  }

  return setmetatable(fields, { __index = Agent })
end
---Register the `User` autocmd that reacts to `CodeCompanionAgent*` events for
---this agent's buffer: it shows the processing notification when the agent
---starts and decides whether to re-submit the chat when it finishes.
---@return nil
function Agent:set_autocmds()
  ---Re-submit the chat and reset the agent once the submission calls back
  local function auto_submit()
    return self.chat:submit({
      auto_submit = true,
      callback = function()
        self:reset({ auto_submit = true })
      end,
    })
  end

  ---Decide what happens once the tools have finished
  local function on_finished()
    if vim.g.codecompanion_auto_tool_mode then
      return auto_submit()
    end

    -- `opts` may be missing from the tools config (add_error_to_chat guards
    -- for the same case), so never index it blindly
    local opts = self.tools_config.opts or {}
    if self.status == CONSTANTS.STATUS_ERROR and opts.auto_submit_errors then
      return auto_submit()
    end
    if self.status == CONSTANTS.STATUS_SUCCESS and opts.auto_submit_success then
      return auto_submit()
    end

    self:reset({ auto_submit = false })
  end

  api.nvim_create_autocmd("User", {
    desc = "Handle responses from an Agent",
    group = self.aug,
    pattern = "CodeCompanionAgent*",
    callback = function(request)
      -- Ignore events without a payload and events that belong to another buffer
      if not request.data or request.data.bufnr ~= self.bufnr then
        return
      end

      if request.match == "CodeCompanionAgentStarted" then
        log:info("[Agent] Initiated")
        if show_tools_processing then
          local namespace = CONSTANTS.NS_TOOLS .. "_" .. tostring(self.bufnr)
          ui.show_buffer_notification(self.bufnr, {
            namespace = namespace,
            text = CONSTANTS.PROCESSING_MSG,
            main_hl = "CodeCompanionChatInfo",
            spacer = true,
          })
        end
      elseif request.match == "CodeCompanionAgentFinished" then
        -- Deferred with vim.schedule rather than run inside the autocmd callback
        return vim.schedule(on_finished)
      end
    end,
  })
end
---Execute the tools requested by the LLM: each one is resolved from the
---config, given its arguments and pushed onto an executor queue, after which
---the executor is set up.
---@param chat CodeCompanion.Chat
---@param tools table The tools requested by the LLM
---@return nil
function Agent:execute(chat, tools)
  self.chat = chat

  ---Resolve a single tool and push it onto the executor's queue
  ---@param executor CodeCompanion.Agent.Executor The executor instance
  ---@param tool table The tool to run
  local function enqueue_tool(executor, tool)
    -- If an error occurred, don't run any more tools
    if self.status == CONSTANTS.STATUS_ERROR then
      return
    end

    local name = tool["function"].name
    local tool_config = self.tools_config[name]

    ---Report a tool that can't be found or resolved back to the chat
    ---@param tool_call table A copy of the tool call, used as the output's owner
    ---@param err_message string
    local function handle_missing_tool(tool_call, err_message)
      tool_call.name = name
      tool_call.function_call = tool_call
      log:error(err_message)

      local available_tools_msg = "No tools available"
      if next(chat.tools.in_use or {}) then
        local quoted = vim.tbl_map(function(t)
          return "`" .. t .. "`"
        end, vim.tbl_keys(chat.tools.in_use))
        available_tools_msg = "The available tools are: " .. table.concat(quoted, ", ")
      end

      self.chat:add_tool_output(
        tool_call,
        string.format("Tool `%s` not found. %s", name, available_tools_msg),
        string.format("**%s Tool Error**: %s", name, err_message)
      )
      return util.fire("AgentFinished", { bufnr = self.bufnr })
    end

    if not tool_config then
      return handle_missing_tool(vim.deepcopy(tool), string.format("Couldn't find the tool `%s`", name))
    end

    local ok, resolved_tool = pcall(Agent.resolve, tool_config)
    if not ok or not resolved_tool then
      return handle_missing_tool(vim.deepcopy(tool), string.format("Couldn't resolve the tool `%s`", name))
    end

    self.tool = vim.deepcopy(resolved_tool)
    self.tool.name = name
    self.tool.function_call = tool

    local args = tool["function"].arguments
    if args then
      -- For some adapter's that aren't streaming, the args are strings rather than tables
      if type(args) == "string" then
        local decoded_ok, decoded = pcall(vim.json.decode, args)
        if not decoded_ok then
          local err = decoded
          log:error("Couldn't decode the tool arguments: %s", args)
          self.chat:add_tool_output(
            self.tool,
            string.format('You made an error in calling the %s tool: "%s"', name, err),
            string.format("**%s Tool Error**: %s", util.capitalize(name), err)
          )
          -- Stop here: a tool whose arguments could not be decoded must not be
          -- queued, otherwise it would run with `args = nil`
          return util.fire("AgentFinished", { bufnr = self.bufnr })
        end
        args = decoded
      end
      self.tool.args = args
    end

    -- Options from the user's config win over the tool's own
    self.tool.opts = vim.tbl_extend("force", self.tool.opts or {}, tool_config.opts or {})

    if self.tool.env then
      local env = type(self.tool.env) == "function" and self.tool.env(vim.deepcopy(self.tool)) or {}
      util.replace_placeholders(self.tool.cmds, env)
    end

    return executor.queue:push(self.tool)
  end

  local id = math.random(10000000)
  local executor = Executor.new(self, id)

  for _, tool in ipairs(tools) do
    enqueue_tool(executor, tool)
  end

  self:set_autocmds()
  util.fire("AgentStarted", { id = id, bufnr = self.bufnr })

  -- Any error raised while setting up the executor still ends the agent run
  xpcall(function()
    executor:setup()
  end, function(err)
    log:error("Agent execution error:\n%s", err)
    util.fire("AgentFinished", { id = id, bufnr = self.bufnr })
  end)
end
---Build the regex source that matches a prefixed tool name followed by
---whitespace or the end of the text
---@param tool string The tool name to create a pattern for
---@return string The pattern (an uncompiled string)
function Agent:_pattern(tool)
  local boundary = "\\(\\s\\|$\\)"
  return CONSTANTS.PREFIX .. tool .. boundary
end
---Collect the tools and tool groups that are referenced in a message
---@param chat CodeCompanion.Chat
---@param message table Only `message.content` is inspected
---@return table?, table? The tool names and the group names, or nil, nil when nothing matched
function Agent:find(chat, message)
  local content = message.content
  if not content then
    return nil, nil
  end

  ---@param name string The tool or group name to search for
  ---@return number?,number? Whatever `regex.find` returns for the name's pattern
  local function is_found(name)
    return regex.find(content, self:_pattern(name))
  end

  -- Groups
  local groups = {}
  for group_name in pairs(self.tools_config.groups) do
    if is_found(group_name) then
      table.insert(groups, group_name)
    end
  end

  -- Tools: every key of the config except the two special ones
  local tools = {}
  for tool_name in pairs(self.tools_config) do
    local is_tool = tool_name ~= "opts" and tool_name ~= "groups"
    if is_tool and is_found(tool_name) and not vim.tbl_contains(tools, tool_name) then
      table.insert(tools, tool_name)
    end
  end

  if #tools == 0 and #groups == 0 then
    return nil, nil
  end
  return tools, groups
end
---Add every tool and tool group mentioned in a message to the chat
---@param chat CodeCompanion.Chat
---@param message table
---@return boolean True when at least one tool or group was found
function Agent:parse(chat, message)
  local tools, groups = self:find(chat, message)
  if not (tools or groups) then
    return false
  end

  for _, tool in ipairs(tools or {}) do
    chat.tools:add(tool, self.tools_config[tool])
  end

  for _, group in ipairs(groups or {}) do
    chat.tools:add_group(group, self.tools_config)
  end

  return true
end
---Replace tool and group tags in a message: a tool tag becomes the tool's
---name and a group tag becomes a comma separated list of the group's tools
---@param message string
---@return string
function Agent:replace(message)
  ---Swap every match of a name's pattern for the replacement, then trim
  local function substitute(name, replacement)
    message = vim.trim(regex.replace(message, self:_pattern(name), replacement))
  end

  for tool in pairs(self.tools_config) do
    if tool ~= "opts" and tool ~= "groups" then
      substitute(tool, tool)
    end
  end

  for group, group_config in pairs(self.tools_config.groups) do
    substitute(group, table.concat(group_config.tools, ", "))
  end

  return message
end
---Return the agent to its idle state: clear the notification and autocmds,
---empty the collected output and tell the chat that the tools are done
---@param opts? table Passed on to `chat:tools_done`
---@return nil
function Agent:reset(opts)
  opts = opts or {}

  if show_tools_processing then
    local namespace = CONSTANTS.NS_TOOLS .. "_" .. tostring(self.bufnr)
    ui.clear_notification(self.bufnr, { namespace = namespace })
  end

  api.nvim_clear_autocmds({ group = self.aug })

  self.extracted, self.stderr, self.stdout = {}, {}, {}
  self.status = CONSTANTS.STATUS_SUCCESS

  self.chat:tools_done(opts)
  log:info("[Agent] Completed")
end
---Share an error with the LLM as a hidden user message and add a visible note
---about it to the chat buffer, submitting the chat when configured to do so
---@param error string
---@return CodeCompanion.Agent
function Agent:add_error_to_chat(error)
  local chat = self.chat
  local user_role = config.constants.USER_ROLE

  -- The error itself is sent but not shown
  chat:add_message({ role = user_role, content = error }, { visible = false })

  --- Alert the user that the error message has been shared
  chat:add_buf_message({
    role = user_role,
    content = "Please correct for the error message I've shared",
  })

  local tool_opts = self.tools_config.opts
  if tool_opts and tool_opts.auto_submit_errors then
    chat:submit()
  end

  return self
end
---Resolve a tool from the config. `tool.callback` may be the tool itself (a
---table), a function returning the tool, or a string. A string is tried, in
---order, as a module under `codecompanion.`, as a module path and as a file path.
---Raises an error (with a message) when the callback has an unsupported type
---or the file cannot be loaded.
---@param tool table The tool from the config
---@return CodeCompanion.Agent.Tool|nil
function Agent.resolve(tool)
  local callback = tool.callback

  if type(callback) == "table" then
    return callback --[[@as CodeCompanion.Agent.Tool]]
  end

  if type(callback) == "function" then
    return callback() --[[@as CodeCompanion.Agent.Tool]]
  end

  -- Anything else must be a string; fail with a readable message rather than
  -- a string concatenation error further down
  if type(callback) ~= "string" then
    error(string.format("Tool callback must be a table, function or string, got %s", type(callback)))
  end

  local ok, module = pcall(require, "codecompanion." .. callback)
  if ok then
    log:debug("[Tools] %s identified", callback)
    return module
  end

  -- Try loading the tool from the user's config using a module path
  ok, module = pcall(require, callback)
  if ok then
    log:debug("[Tools] %s identified", callback)
    return module
  end

  -- Try loading the tool from the user's config using a file path
  local chunk, err = loadfile(callback)
  if not chunk then
    -- Keep loadfile's reason instead of raising a nil error value
    error(string.format("Could not load the tool `%s`: %s", callback, tostring(err)))
  end

  log:debug("[Tools] %s identified", callback)
  return chunk()
end
return Agent | |
--- | |
File: /lua/codecompanion/strategies/chat/agents/tool_filter.lua | |
--- | |
local log = require("codecompanion.utils.log") | |
---@class CodeCompanion.Agent.ToolFilter
local ToolFilter = {}

-- How long a computed result stays valid, compared against `vim.loop.now()`
-- (presumably milliseconds; confirm against the libuv docs)
local CACHE_TTL = 30000

-- Map of tool name -> enabled state, and the time at which it was computed
local _enabled_cache = {}
local _cache_timestamp = 0
---Forget the cached enabled states and reset the cache's timestamp
---@return nil
local function clear_cache()
  _enabled_cache, _cache_timestamp = {}, 0
  log:trace("[Tool Filter] Cache cleared")
end
---Whether the cache was computed less than `CACHE_TTL` ago
---@return boolean
local function is_cache_valid()
  local age = vim.loop.now() - _cache_timestamp
  return age < CACHE_TTL
end
---Get enabled tools from the cache or compute them.
---A tool is enabled unless its `enabled` field is `false`, or is a function
---that returns a falsy value or raises. Every stored state is a real boolean
---so that all consumers (filtering, group filtering and `is_tool_enabled`'s
---`== true` check) agree on the outcome.
---NOTE(review): the cache is shared by all callers and is not keyed by
---`tools_config`; a fresh cache is returned regardless of the config passed in.
---@param tools_config table The tools configuration
---@return table<string, boolean> Map of tool names to enabled status
local function get_enabled_tools(tools_config)
  if is_cache_valid() and next(_enabled_cache) then
    log:trace("[Tool Filter] Using cached enabled tools")
    return _enabled_cache
  end

  log:trace("[Tool Filter] Computing enabled tools")
  _enabled_cache = {}
  _cache_timestamp = vim.loop.now()

  for tool_name, tool_config in pairs(tools_config) do
    -- Skip special keys
    if tool_name ~= "opts" and tool_name ~= "groups" then
      local is_enabled = true
      local enabled = tool_config.enabled

      if type(enabled) == "function" then
        local ok, result = pcall(enabled)
        if ok then
          -- Coerce to a boolean: a truthy non-boolean (e.g. 1) or nil would
          -- otherwise be treated differently by different consumers
          is_enabled = result and true or false
        else
          log:error("[Tool Filter] Error evaluating enabled function for tool '%s': %s", tool_name, result)
          is_enabled = false
        end
      elseif type(enabled) == "boolean" then
        is_enabled = enabled
      end

      _enabled_cache[tool_name] = is_enabled
      log:trace("[Tool Filter] Tool '%s' enabled: %s", tool_name, is_enabled)
    end
  end

  return _enabled_cache
end
---Return a deep copy of the tools configuration without the disabled tools.
---Groups only keep their enabled tools and are dropped when none remain.
---@param tools_config table The tools configuration
---@return table The filtered tools configuration
function ToolFilter.filter_enabled_tools(tools_config)
  local enabled_tools = get_enabled_tools(tools_config)
  local filtered = vim.deepcopy(tools_config)

  -- Remove disabled tools
  for tool_name, is_enabled in pairs(enabled_tools) do
    if not is_enabled then
      filtered[tool_name] = nil
      log:trace("[Tool Filter] Filtered out disabled tool: %s", tool_name)
    end
  end

  local groups = filtered.groups
  if not groups then
    return filtered
  end

  -- Filter groups to only include enabled tools
  for group_name, group_config in pairs(groups) do
    if group_config.tools then
      local kept = {}
      for _, tool_name in ipairs(group_config.tools) do
        if enabled_tools[tool_name] then
          table.insert(kept, tool_name)
        end
      end

      if #kept == 0 then
        -- Remove group if no tools are enabled
        groups[group_name] = nil
        log:trace("[Tool Filter] Filtered out group with no enabled tools: %s", group_name)
      else
        group_config.tools = kept
      end
    end
  end

  return filtered
end
---Whether a tool's computed state is exactly `true`
---@param tool_name string The name of the tool
---@param tools_config table The tools configuration
---@return boolean
function ToolFilter.is_tool_enabled(tool_name, tools_config)
  local state = get_enabled_tools(tools_config)[tool_name]
  return state == true
end
---Clear the cache so that the next lookup recomputes the enabled tools
---(useful for testing or manual refresh)
---@return nil
function ToolFilter.refresh_cache()
  clear_cache()
  log:trace("[Tool Filter] Cache manually refreshed")
end
---Clear the cache when the `CodeCompanionChatRefreshCache` user event fires
local function on_refresh_cache()
  log:trace("[Tool Filter] Cache cleared via autocommand")
  clear_cache()
end

vim.api.nvim_create_autocmd("User", {
  pattern = "CodeCompanionChatRefreshCache",
  callback = on_refresh_cache,
})
return ToolFilter | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/buffer.lua | |
--- | |
local buf = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local util = require("codecompanion.utils") | |
local fmt = string.format | |
-- Fixed strings for the buffer slash command
local CONSTANTS = {
  NAME = "Buffer",
  -- Title/prompt handed to the pickers
  PROMPT = "Select buffer(s)",
}
-- One entry per supported picker. Each opens a buffer picker and hands the
-- chosen buffer to `SlashCommand:output`, as a table with `bufnr`, `name`
-- and `path` wherever the provider builds it here.
local providers = {
  ---The default provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  default = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.default").new({
      output = function(selection)
        SlashCommand:output(selection)
      end,
      SlashCommand = SlashCommand,
      title = CONSTANTS.PROMPT,
    })
    picker:buffers():display()
  end,

  ---The Snacks.nvim provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  snacks = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.snacks").new({
      title = CONSTANTS.PROMPT .. ": ",
      output = function(selection)
        local bufnr = selection.buf
        return SlashCommand:output({
          bufnr = bufnr,
          name = vim.fn.bufname(bufnr),
          path = selection.file,
        })
      end,
    })

    picker.provider.picker.pick({
      source = "buffers",
      prompt = picker.title,
      confirm = picker:display(),
      main = { file = false, float = true },
    })
  end,

  ---The Telescope provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  telescope = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.telescope").new({
      title = CONSTANTS.PROMPT,
      output = function(selection)
        return SlashCommand:output({
          bufnr = selection.bufnr,
          name = selection.filename,
          path = selection.path,
        })
      end,
    })

    picker.provider.buffers({
      prompt_title = picker.title,
      ignore_current_buffer = true, -- Ignore the codecompanion buffer when selecting buffers
      attach_mappings = picker:display(),
    })
  end,

  ---The Mini.Pick provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  mini_pick = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.mini_pick").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return SlashCommand:output(selected)
      end,
    })

    ---The picked item's text is used as both name and path
    local function to_selection(selected)
      return {
        bufnr = selected.bufnr,
        name = selected.text,
        path = selected.text,
      }
    end

    picker.provider.builtin.buffers({ include_current = false }, picker:display(to_selection))
  end,

  ---The fzf-lua provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  fzf_lua = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.fzf_lua").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return SlashCommand:output(selected)
      end,
    })

    ---Turn an fzf entry into a selection via the provider's path helper
    local function to_selection(selected, opts)
      local file = picker.provider.path.entry_to_file(selected, opts)
      return {
        bufnr = file.bufnr,
        name = file.path,
        path = file.bufname,
      }
    end

    picker.provider.buffers(picker:display(to_selection))
  end,
}
---@class CodeCompanion.SlashCommand.Buffer: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create the buffer slash command, keeping the chat, config and context it was given
---@param args CodeCompanion.SlashCommandArgs
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
  }
  return setmetatable(instance, { __index = SlashCommand })
end
---Open the configured picker, unless the command contains code and sending
---code has been disabled
---@param SlashCommands CodeCompanion.SlashCommands
---@return nil
function SlashCommand:execute(SlashCommands)
  local code_blocked = not config.can_send_code() and (self.config.opts and self.config.opts.contains_code)
  if code_blocked then
    return log:warn("Sending of code has been disabled")
  end

  return SlashCommands:set_provider(self, providers)
end
---Add a buffer's content to the chat as a hidden user message and, unless
---this is a pin refresh (`opts.pin`), register it as a reference and notify
---the user
---@param selected table The selected item from the provider; `bufnr` and `path` are read here
---@param opts? table
---@return nil
function SlashCommand:output(selected, opts)
  local code_blocked = not config.can_send_code() and (self.config.opts and self.config.opts.contains_code)
  if code_blocked then
    return log:warn("Sending of code has been disabled")
  end

  opts = opts or {}

  local message = opts.pin and "Here is the updated content from a file (including line numbers)"
    or "Here is the content from a file (including line numbers)"

  -- On failure, the second value is the error raised while formatting
  local ok, content, id, filename = pcall(buf.format_for_llm, selected, { message = message })
  if not ok then
    return log:warn(content)
  end

  self.Chat:add_message({
    role = config.constants.USER_ROLE,
    content = content,
  }, { reference = id, visible = false })

  -- A pin refresh only updates the message
  if opts.pin then
    return
  end

  -- The command's default parameter decides how the new reference is tracked
  local default_params = self.config.opts and self.config.opts.default_params
  if default_params == "pin" then
    opts.pinned = true
  elseif default_params == "watch" then
    opts.watched = true
  end

  self.Chat.references:add({
    bufnr = selected.bufnr,
    id = id,
    path = selected.path,
    opts = opts,
    source = "codecompanion.strategies.chat.slash_commands.buffer",
  })

  util.notify(fmt("Added buffer `%s` to the chat", filename))
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/fetch.lua | |
--- | |
local Path = require("plenary.path") | |
local adapters = require("codecompanion.adapters") | |
local client = require("codecompanion.http") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local util = require("codecompanion.utils") | |
local util_hash = require("codecompanion.utils.hash") | |
local fmt = string.format | |
-- Fixed values for the fetch slash command
local CONSTANTS = {
  NAME = "Fetch",
  -- Directory in which fetched URLs are cached, taken from the user's config
  CACHE_PATH = config.strategies.chat.slash_commands.fetch.opts.cache_path,
}
---Get the cached URLs from the cache directory, newest first.
---Cache files that cannot be read or decoded, or that lack a `url` or a
---numeric `timestamp`, are skipped with a warning rather than breaking the
---whole listing.
---@return table List of { filepath, content, filename, url, timestamp, display }
local function get_cached_files()
  local scan = require("plenary.scandir")

  local cache_dir = Path:new(CONSTANTS.CACHE_PATH):expand()
  if not Path:new(cache_dir):exists() then
    return {}
  end

  local files = scan.scan_dir(cache_dir, {
    depth = 1,
    search_pattern = "%.json$",
  })

  local urls = {}
  for _, filepath in ipairs(files) do
    local ok, entry = pcall(function()
      return vim.json.decode(Path:new(filepath):read())
    end)

    -- The sort below compares timestamps, so an entry without one cannot be kept
    if ok and type(entry) == "table" and type(entry.url) == "string" and type(entry.timestamp) == "number" then
      table.insert(urls, {
        filepath = filepath,
        content = entry.data,
        filename = vim.fn.fnamemodify(filepath, ":t"),
        url = entry.url,
        timestamp = entry.timestamp,
        display = string.format("[%s] %s", util.make_relative(entry.timestamp), entry.url),
      })
    else
      log:warn("Fetch Slash Command: Skipping invalid cache file %s", filepath)
    end
  end

  -- Sort by timestamp (newest first)
  table.sort(urls, function(a, b)
    return a.timestamp > b.timestamp
  end)

  return urls
end
---Wrap fetched text in `<content>` tags below a one-line description
---@param url string Used in the default description
---@param text string
---@param opts table `opts.description`, when set, replaces the default description
---@return string
local function format_output(url, text, opts)
  local description = opts and opts.description
  if not description then
    description = "Here is the output from " .. url .. " that I'm sharing with you:"
  end

  return fmt(
    [[%s
<content>
%s
</content>]],
    description,
    text
  )
end
---Output the contents of the URL to the chat buffer: the content is added as
---a hidden user message and the URL is registered as a reference
---@param chat CodeCompanion.Chat
---@param data table Reads `data.url` and `data.content`
---@param opts? table `opts.silent` suppresses the notification
---@return nil
local function output(chat, data, opts)
  opts = opts or {}

  local url = data.url
  local id = "<url>" .. url .. "</url>"

  chat:add_message({
    role = config.constants.USER_ROLE,
    content = format_output(url, data.content, opts),
  }, { reference = id, visible = false })

  chat.references:add({
    source = "slash_command",
    name = "fetch",
    id = id,
  })

  if not opts.silent then
    return util.notify(fmt("Added `%s` to the chat", data.url))
  end
end
---Wrap a picker so that it is only opened when cached URLs exist; otherwise
---the user is warned and the picker is never called
---@param pick fun(SlashCommand: CodeCompanion.SlashCommand, cached_files: table)
---@return fun(SlashCommand: CodeCompanion.SlashCommand)
local function with_cached_files(pick)
  return function(SlashCommand)
    local cached_files = get_cached_files()
    if #cached_files == 0 then
      return util.notify("No cached URLs found", vim.log.levels.WARN)
    end
    return pick(SlashCommand, cached_files)
  end
end

-- One entry per supported picker. Each lists the cached URLs and passes the
-- chosen entry to `output` for the command's chat.
local providers = {
  ---The default provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  default = with_cached_files(function(SlashCommand, cached_files)
    local picker = require("codecompanion.providers.slash_commands.default").new({
      output = function(selection)
        return output(SlashCommand.Chat, selection)
      end,
      SlashCommand = SlashCommand,
    })
    return picker:urls(cached_files):display()
  end),

  ---The snacks.nvim provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  snacks = with_cached_files(function(SlashCommand, cached_files)
    local picker = require("codecompanion.providers.slash_commands.snacks").new({
      output = function(selection)
        return output(SlashCommand.Chat, selection)
      end,
    })

    -- Transform cached files into picker items
    local items = vim.tbl_map(function(file)
      return {
        text = file.display,
        file = file.filepath,
        url = file.url,
        content = file.content,
        timestamp = file.timestamp,
      }
    end, cached_files)

    picker.provider.picker.pick({
      title = "Cached URLs",
      items = items,
      prompt = picker.title,
      format = function(item, _)
        return { { item.text } }
      end,
      confirm = picker:display(),
      main = { file = false, float = true },
    })
  end),

  ---The Telescope provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  telescope = with_cached_files(function(SlashCommand, cached_files)
    local picker = require("codecompanion.providers.slash_commands.telescope").new({
      output = function(selection)
        return output(SlashCommand.Chat, selection)
      end,
    })

    local pickers = require("telescope.pickers")
    local finders = require("telescope.finders")

    local finder = finders.new_table({
      results = cached_files,
      entry_maker = function(entry)
        return {
          value = entry,
          content = entry.content,
          url = entry.url,
          ordinal = entry.display,
          display = entry.display,
          filename = entry.filepath,
        }
      end,
    })

    pickers
      .new({
        finder = finder,
        attach_mappings = picker:display(),
      })
      :find()
  end),

  ---The Mini.Pick provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  mini_pick = with_cached_files(function(SlashCommand, cached_files)
    local picker = require("codecompanion.providers.slash_commands.mini_pick").new({
      output = function(selected)
        return output(SlashCommand.Chat, selected)
      end,
    })

    local items = vim.tbl_map(function(file)
      return {
        text = file.display,
        url = file.url,
        content = file.content,
      }
    end, cached_files)

    -- Picked items are passed through unchanged
    local display = picker:display(function(picked_item)
      return picked_item
    end)

    picker.provider.start({
      source = vim.tbl_deep_extend("force", display.source, { items = items }),
    })
  end),

  ---The FZF-Lua provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  fzf_lua = with_cached_files(function(SlashCommand, cached_files)
    local picker = require("codecompanion.providers.slash_commands.fzf_lua").new({
      output = function(selected)
        return output(SlashCommand.Chat, selected)
      end,
    })

    local items = vim.tbl_map(function(file)
      return file.display
    end, cached_files)

    ---Map the selected display line back to its cached entry
    local function to_cached_file(selected, _)
      for _, file_object in ipairs(cached_files) do
        if file_object.display == selected then
          return file_object
        end
      end
    end

    picker.provider.fzf_exec(items, picker:display(to_cached_file))
  end),
}
---Build the path of the cache file for a hash: `<cache_path>/<hash>.json`,
---expanded in the same way as when the cache is written
---@param hash string
---@return table The plenary Path of the cache file
local function cache_file(hash)
  local p = Path:new(CONSTANTS.CACHE_PATH .. "/" .. hash .. ".json")
  p.filename = p:expand()
  return p
end

---Determine if the URL has already been cached
---@param hash string
---@return boolean
local function is_cached(hash)
  return cache_file(hash):exists()
end

---Read the cache for the URL and output it to the chat
---@param chat CodeCompanion.Chat
---@param url string
---@param hash string
---@param opts table
---@return nil
local function read_cache(chat, url, hash, opts)
  local raw = cache_file(hash):read()

  -- The cache is written as a JSON object whose `data` field holds the page
  -- text; fall back to the raw file content if it cannot be decoded
  local content = raw
  local ok, cached = pcall(vim.json.decode, raw)
  if ok and type(cached) == "table" and cached.data ~= nil then
    content = cached.data
  end

  log:debug("Fetch Slash Command: Restoring from cache for %s", url)

  return output(chat, {
    content = content,
    url = url,
  }, opts)
end
---Write the data to `<cache_path>/<hash>.json`, creating the cache directory
---and the file when they don't exist
---@param hash string
---@param data string Written as-is; nil is written as an empty string
---@return nil
local function write_cache(hash, data)
  local cache_dir = CONSTANTS.CACHE_PATH

  local target = Path:new(cache_dir .. "/" .. hash .. ".json")
  target.filename = target:expand()

  vim.fn.mkdir(cache_dir, "p")
  target:touch({ parents = true })
  target:write(data or "", "w")
end
---Fetch the contents of a URL through the adapter, add them to the chat and
---offer to cache the response
---@param chat CodeCompanion.Chat
---@param adapter table
---@param url string
---@param opts table
---@return nil
local function fetch(chat, adapter, url, opts)
  log:debug("Fetch Slash Command: Fetching from %s", url)

  -- Make sure that we don't modify the original adapter
  adapter = vim.deepcopy(adapter)
  adapter.methods.slash_commands.fetch(adapter)

  ---Store the fetched text if the user agrees to it
  local function offer_to_cache(text)
    -- TODO: Get an LLM to create summary
    vim.ui.select({ "Yes", "No" }, {
      prompt = "Do you want to cache this URL?",
      kind = "codecompanion.nvim",
    }, function(selected)
      if selected ~= "Yes" then
        return
      end
      local hash = util_hash.hash(url)
      write_cache(
        hash,
        vim.json.encode({
          url = url,
          hash = hash,
          timestamp = os.time(),
          data = text,
        })
      )
    end)
  end

  ---Handle the response of the request
  local function on_response(err, data)
    if err then
      return log:error("Failed to fetch the URL, with error %s", err)
    end
    if not data then
      return
    end

    local ok, body = pcall(vim.json.decode, data.body)
    if not ok then
      return log:error("Could not parse the JSON response")
    end

    if data.status ~= 200 then
      return log:error("Error %s - %s", data.status, body.message or "No message provided")
    end

    output(chat, {
      content = body.data.text,
      url = url,
    }, opts)

    -- Cache the response
    offer_to_cache(body.data.text)
  end

  return client.new({ adapter = adapter }):request({ url = url }, { callback = on_response })
end
-- The different choices to load URLs in to the chat buffer
local choice = {
  ---Ask for a URL and output it; blank or cancelled input is ignored
  URL = function(SlashCommand, _)
    local function on_input(url)
      local trimmed = vim.trim(url or "")
      if #trimmed == 0 then
        return
      end
      return SlashCommand:output(url)
    end

    return vim.ui.input({ prompt = "Enter the URL: " }, on_input)
  end,

  ---Pick a cached URL with the configured provider
  Cache = function(SlashCommand, _)
    local cached_files = get_cached_files()
    if #cached_files == 0 then
      return util.notify("No cached URLs found", vim.log.levels.WARN)
    end

    local provider = providers[SlashCommand.config.opts.provider]
    return provider(SlashCommand, cached_files)
  end,
}
---@class CodeCompanion.SlashCommand.Fetch: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create the fetch slash command, keeping the chat, config and context it was given
---@param args CodeCompanion.SlashCommandArgs
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
  }
  return setmetatable(instance, { __index = SlashCommand })
end
---Execute the slash command
---Offers "URL" always, and "Cache" only when cached files exist. With a
---single option the selection prompt is skipped.
---@param SlashCommands CodeCompanion.SlashCommands
---@param opts? table
---@return nil|string
function SlashCommand:execute(SlashCommands, opts)
  local sources = { "URL" }
  if #get_cached_files() > 0 then
    sources[#sources + 1] = "Cache"
  end

  if #sources == 1 then
    return choice[sources[1]](self, SlashCommands)
  end

  local function on_select(selected)
    -- nil means the user dismissed the prompt
    if selected then
      return choice[selected](self, SlashCommands)
    end
  end

  vim.ui.select(sources, {
    prompt = "Select link source",
    kind = "codecompanion.nvim",
  }, on_select)
end
---Output the contents of the URL
---Reads `opts.ignore_cache` (always fetch) and `opts.auto_restore_cache`
---(use the cached copy when one exists for the URL's hash).
---@param url string
---@param opts? table
---@return nil
function SlashCommand:output(url, opts)
  opts = opts or {}

  local adapter = adapters.get_from_string(self.config.opts.adapter)
  if not adapter then
    return log:error("Could not resolve adapter for the fetch slash command")
  end

  local hash = util_hash.hash(url)

  if opts.ignore_cache then
    log:debug("Fetch Slash Command: Ignoring cache")
    return fetch(self.Chat, adapter, url, opts)
  end

  if opts.auto_restore_cache and is_cached(hash) then
    log:debug("Fetch Slash Command: Auto restoring from cache")
    return read_cache(self.Chat, url, hash, opts)
  end

  return fetch(self.Chat, adapter, url, opts)
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/file.lua | |
--- | |
local path = require("plenary.path") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local util = require("codecompanion.utils") | |
local fmt = string.format | |
-- Fixed labels for the file slash command
local CONSTANTS = {
  -- Used as the title/prompt of the pickers in `providers`
  PROMPT = "Select file(s)",
  NAME = "File",
}
-- File pickers keyed by provider name. Each one forwards the chosen file to
-- `SlashCommand:output`.
local providers = {
  ---The default provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  default = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.default").new({
      output = function(selection)
        return SlashCommand:output(selection)
      end,
      SlashCommand = SlashCommand,
      title = CONSTANTS.PROMPT,
    })
    return picker:find_files():display()
  end,

  ---The Snacks.nvim provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  snacks = function(SlashCommand)
    local function on_select(selection)
      return SlashCommand:output({
        relative_path = selection.file,
        path = vim.fs.joinpath(selection.cwd, selection.file),
      })
    end

    local picker = require("codecompanion.providers.slash_commands.snacks").new({
      title = CONSTANTS.PROMPT .. ": ",
      output = on_select,
    })

    picker.provider.picker.pick({
      source = "files",
      prompt = picker.title,
      confirm = picker:display(),
      main = { file = false, float = true },
    })
  end,

  ---The Telescope provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  telescope = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.telescope").new({
      title = CONSTANTS.PROMPT,
      output = function(selection)
        return SlashCommand:output(selection)
      end,
    })

    picker.provider.find_files({
      prompt_title = picker.title,
      attach_mappings = picker:display(),
      hidden = true,
    })
  end,

  ---The Mini.Pick provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  mini_pick = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.mini_pick").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return SlashCommand:output(selected)
      end,
    })

    -- Wrap the picked item in the `{ path = ... }` table that `output` receives
    local function to_selection(selected)
      return { path = selected }
    end

    picker.provider.builtin.files({}, picker:display(to_selection))
  end,

  ---The fzf-lua provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  fzf_lua = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.fzf_lua").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return SlashCommand:output(selected)
      end,
    })

    local function to_selection(selected, opts)
      local file = picker.provider.path.entry_to_file(selected, opts)
      return {
        relative_path = file.stripped,
        path = file.path,
      }
    end

    picker.provider.files(picker:display(to_selection))
  end,
}
---@class CodeCompanion.SlashCommand.File: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create a new instance of the file slash command
---@param args CodeCompanion.SlashCommandArgs
---@return CodeCompanion.SlashCommand.File
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
    opts = args.opts,
  }
  -- Methods are looked up on the SlashCommand table
  return setmetatable(instance, { __index = SlashCommand })
end
---Execute the slash command
---Refuses to run when sending code is disabled and this command is flagged
---as containing code; otherwise hands over to the provider.
---@param SlashCommands CodeCompanion.SlashCommands
---@return nil
function SlashCommand:execute(SlashCommands)
  if not config.can_send_code() then
    local cmd_opts = self.config.opts
    if cmd_opts and cmd_opts.contains_code then
      return log:warn("Sending of code has been disabled")
    end
  end

  return SlashCommands:set_provider(self, providers)
end
---Open and read the contents of the selected file
---@param selected { path: string, relative_path: string?, description: string? }
---@return string content The file content, or "" when reading raised an error
---@return string|nil ft The filetype matched from the filename
---@return string|nil id The reference id, `<file>{relative path}</file>`
---@return string|nil relative_path The path after applying the `:.` filename modifier
function SlashCommand:read(selected)
  local function read_file()
    return path.new(selected.path):read()
  end

  local ok, content = pcall(read_file)
  if not ok then
    return ""
  end

  local relative_path = vim.fn.fnamemodify(selected.path, ":.")
  local ft = vim.filetype.match({ filename = selected.path })

  return content, ft, "<file>" .. relative_path .. "</file>", relative_path
end
---Output from the slash command in the chat buffer
---Adds the file content as a hidden user message and, unless `opts.pin` is
---set, registers it as a reference on the chat.
---@param selected { relative_path: string?, path: string, description: string? }
---@param opts? { silent: boolean, pin: boolean }
---@return nil
function SlashCommand:output(selected, opts)
  -- `config` is optional: SlashCommands.references() builds this command with
  -- only a Chat, so guard the lookup instead of indexing nil.
  local cmd_opts = self.config and self.config.opts
  if not config.can_send_code() and (cmd_opts and cmd_opts.contains_code) then
    return log:warn("Sending of code has been disabled")
  end

  opts = opts or {}

  local content, ft, id, relative_path = self:read(selected)
  if content == "" then
    return log:warn("Could not read the file: %s", selected.path)
  end

  -- No filetype match: leave the code fence untagged rather than formatting nil
  ft = ft or ""

  -- Workspaces allow the user to set their own custom description which should take priority
  local description
  if selected.description then
    description = fmt(
      [[%s
```%s
%s
```]],
      selected.description,
      ft,
      content
    )
  else
    description = fmt(
      [[<attachment filepath="%s">%s:
```%s
%s
```
</attachment>]],
      relative_path,
      opts.pin and "Here is the updated content from the file" or "Here is the content from the file",
      ft,
      content
    )
  end

  self.Chat:add_message({
    role = config.constants.USER_ROLE,
    content = description or "",
  }, { reference = id, visible = false })

  -- With `pin` set, only the message is added
  if opts.pin then
    return
  end

  self.Chat.references:add({
    id = id or "",
    path = selected.path,
    source = "codecompanion.strategies.chat.slash_commands.file",
  })

  if opts.silent then
    return
  end

  util.notify(fmt("Added the `%s` file to the chat", vim.fn.fnamemodify(relative_path, ":t")))
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/help.lua | |
--- | |
local path = require("plenary.path") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local util = require("codecompanion.utils") | |
local ts = vim.treesitter | |
-- Line count of the most recently read help file (written by `output`)
local line_count = 0

-- Fixed labels and limits for the help slash command
local CONSTANTS = {
  -- Read from the user's config when this module is loaded
  MAX_LINES = config.strategies.chat.slash_commands.help.opts.max_lines,
  PROMPT = "Select a help tag",
  NAME = "Help",
}
---Find the tag row
---Parses the content as vimdoc and looks for a `tag` node whose text is `*{tag}*`.
---@param tag string The tag to find
---@param content string The content of the file
---@return integer|nil The row of the first matching tag; nil when nothing matches
local function get_tag_row(tag, content)
  local lang = "vimdoc"
  local root = vim.treesitter.get_string_parser(content, lang):parse()[1]:root()
  local pattern = '((tag) @tag (#eq? @tag "*' .. tag .. '*"))'
  local query = ts.query.parse(lang, pattern)

  for _, node in query:iter_captures(root, content) do
    -- Only the first value returned by node:range() is kept, and only for the first capture
    local row = node:range()
    return row
  end
end
---Trim the content around the tag
---Keeps a window of roughly `CONSTANTS.MAX_LINES` lines centred on the tag,
---clamped to the start or end of the file, and marks each cut with "...".
---@param content string The content of the file
---@param tag string The tag to find
---@return string The trimmed content
local function trim_content(content, tag)
  local lines = vim.split(content, "\n")

  -- Whole-number half window: a fractional index makes vim.list_slice return
  -- nothing when MAX_LINES is odd.
  local half = math.floor(CONSTANTS.MAX_LINES / 2)

  -- get_tag_row returns nil when the tag can't be found; fall back to the top
  -- of the file instead of failing on nil arithmetic.
  local tag_row = get_tag_row(tag, content) or 0

  local prefix = ""
  local suffix = ""
  local start_, end_

  if tag_row - half < 1 then
    -- Tag is near the top: keep the head of the file
    start_ = 1
    end_ = CONSTANTS.MAX_LINES
    suffix = "\n..."
  elseif tag_row + half > #lines then
    -- Tag is near the bottom: keep the tail of the file
    start_ = #lines - CONSTANTS.MAX_LINES
    end_ = #lines
    prefix = "...\n"
  else
    start_ = tag_row - half
    end_ = tag_row + half
    prefix = "...\n"
    suffix = "\n..."
  end

  content = table.concat(vim.list_slice(lines, start_, end_), "\n")
  return prefix .. content .. suffix
end
---Send the output to the chat buffer
---Adds the help text as a hidden user message, registers a matching
---reference on the chat and notifies the user.
---@param SlashCommand CodeCompanion.SlashCommand
---@param content string The content of the help file
---@param selected table The selected item from the provider { tag = string, path = string }
---@return nil
local function send_output(SlashCommand, content, selected)
  local chat = SlashCommand.Chat
  local reference_id = "<help>" .. selected.tag .. "</help>"

  local message = string.format(
    [[Help context for `%s`:
```%s
%s
```
Note the path to the help file is `%s`.
]],
    selected.tag,
    "vimdoc",
    content,
    selected.path
  )

  chat:add_message({
    role = config.constants.USER_ROLE,
    content = message,
  }, { reference = reference_id, visible = false })

  chat.references:add({
    source = "slash_command",
    name = "help",
    id = reference_id,
  })

  return util.notify(string.format("Added the `%s` help to the chat", selected.tag))
end
---Output from the slash command in the chat buffer
---Reads the help file. When it is longer than `CONSTANTS.MAX_LINES` the user
---is asked whether to trim it before it is sent.
---@param SlashCommand CodeCompanion.SlashCommand
---@param selected table The selected item from the provider { tag = string, path = string }
---@return nil
local function output(SlashCommand, selected)
  if not config.can_send_code() then
    local cmd_opts = SlashCommand.config.opts
    if cmd_opts and cmd_opts.contains_code then
      return log:warn("Sending of code has been disabled")
    end
  end

  local content = path.new(selected.path):read()
  line_count = #vim.split(content, "\n")

  if line_count <= CONSTANTS.MAX_LINES then
    return send_output(SlashCommand, content, selected)
  end

  local function on_choice(choice)
    -- nil means the prompt was dismissed: send nothing
    if not choice then
      return
    end
    if choice ~= "No" then
      content = trim_content(content, selected.tag)
    end
    return send_output(SlashCommand, content, selected)
  end

  vim.ui.select({ "Yes", "No" }, {
    kind = "codecompanion.nvim",
    prompt = "The help file is more than " .. CONSTANTS.MAX_LINES .. " lines. Do you want to trim it?",
  }, on_choice)
end
-- Help-tag pickers keyed by provider name. Each one converts the picker's
-- selection to `{ path, tag }` and passes it to `output`.
local providers = {
  ---The Snacks.nvim provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  snacks = function(SlashCommand)
    local function on_select(selection)
      return output(SlashCommand, {
        path = selection.file,
        tag = selection.tag,
      })
    end

    local picker = require("codecompanion.providers.slash_commands.snacks").new({
      title = CONSTANTS.PROMPT .. ": ",
      output = on_select,
    })

    picker.provider.picker.pick({
      source = "help",
      prompt = picker.title,
      confirm = picker:display(),
      main = { file = false, float = true },
    })
  end,

  ---The Telescope provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  telescope = function(SlashCommand)
    local function on_select(selection)
      return output(SlashCommand, {
        path = selection.filename,
        tag = selection.display,
      })
    end

    local picker = require("codecompanion.providers.slash_commands.telescope").new({
      title = CONSTANTS.PROMPT,
      output = on_select,
    })

    picker.provider.help_tags({
      prompt_title = picker.title,
      attach_mappings = picker:display(),
    })
  end,

  ---The Mini.Pick provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  mini_pick = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.mini_pick").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return output(SlashCommand, selected)
      end,
    })

    local function to_selection(selected)
      return {
        path = selected.filename,
        tag = selected.name,
      }
    end

    picker.provider.builtin.help({}, picker:display(to_selection))
  end,

  ---The fzf-lua provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  fzf_lua = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.fzf_lua").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return output(SlashCommand, selected)
      end,
    })

    local function to_selection(selected, opts)
      local file = picker.provider.path.entry_to_file(selected, opts)
      return {
        path = file.path,
        -- The tag is the first run of non-whitespace characters in the entry
        tag = selected:match("[^%s]+"),
      }
    end

    picker.provider.helptags(picker:display(to_selection))
  end,
}
---@class CodeCompanion.SlashCommand.Help: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create a new instance of the help slash command
---@param args CodeCompanion.SlashCommandArgs
---@return CodeCompanion.SlashCommand.Help
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
  }
  -- Methods are looked up on the SlashCommand table
  return setmetatable(instance, { __index = SlashCommand })
end
---Execute the slash command
---Refuses to run when sending code is disabled and this command is flagged
---as containing code; otherwise hands over to the provider.
---@param SlashCommands CodeCompanion.SlashCommands
---@return nil
function SlashCommand:execute(SlashCommands)
  if not config.can_send_code() then
    local cmd_opts = self.config.opts
    if cmd_opts and cmd_opts.contains_code then
      return log:warn("Sending of code has been disabled")
    end
  end

  return SlashCommands:set_provider(self, providers)
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/image.lua | |
--- | |
local Curl = require("plenary.curl") | |
local config = require("codecompanion.config") | |
local helpers = require("codecompanion.strategies.chat.helpers") | |
local log = require("codecompanion.utils.log") | |
-- Fixed labels and user-configured search settings for the image slash command
local CONSTANTS = {
  -- Both read from the user's config when this module is loaded
  IMAGE_TYPES = config.strategies.chat.slash_commands.image.opts.filetypes,
  IMAGE_DIRS = config.strategies.chat.slash_commands.image.opts.dirs,
  PROMPT = "Images",
  NAME = "Image",
}
---Prepares image search directories and filetypes for a single invocation.
---The directory list always starts with the current working directory and is
---rebuilt on every call.
---@return table, table|nil Returns search_dirs and filetypes (nil when none are configured)
local function prepare_image_search_options()
  local search_dirs = { vim.fn.getcwd() }

  local extra_dirs = CONSTANTS.IMAGE_DIRS
  if extra_dirs and vim.tbl_count(extra_dirs) > 0 then
    vim.list_extend(search_dirs, extra_dirs)
  end

  local filetypes = CONSTANTS.IMAGE_TYPES
  if filetypes and vim.tbl_count(filetypes) > 0 then
    return search_dirs, filetypes
  end

  return search_dirs, nil
end
-- Image pickers keyed by provider name. Each one forwards the chosen image to
-- `SlashCommand:output`.
local providers = {
  ---The default provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  default = function(SlashCommand)
    local dirs, ft = prepare_image_search_options()
    local default = require("codecompanion.providers.slash_commands.default")
    default
      .new({
        output = function(selection)
          SlashCommand:output(selection)
        end,
        SlashCommand = SlashCommand,
        title = CONSTANTS.PROMPT,
      })
      :images(dirs, ft)
      :display()
  end,

  ---The Snacks.nvim provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  snacks = function(SlashCommand)
    local snacks = require("codecompanion.providers.slash_commands.snacks")
    snacks = snacks.new({
      output = function(selection)
        return SlashCommand:output({
          relative_path = selection.file,
          path = selection.file,
        })
      end,
    })

    local dirs, ft = prepare_image_search_options()
    snacks.provider.picker.pick("files", {
      confirm = snacks:display(),
      dirs = dirs,
      ft = ft,
      main = { file = false, float = true },
    })
  end,

  ---The Telescope provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  telescope = function(SlashCommand)
    local telescope = require("codecompanion.providers.slash_commands.telescope")
    telescope = telescope.new({
      title = CONSTANTS.PROMPT,
      output = function(selection)
        return SlashCommand:output({
          path = selection[1],
        })
      end,
    })

    local dirs, img_fts = prepare_image_search_options()
    local find_command = { "fd", "--type", "f", "--follow", "--hidden" }
    -- prepare_image_search_options() returns nil filetypes when none are
    -- configured; in that case no extension filter is added.
    for _, ext in ipairs(img_fts or {}) do
      table.insert(find_command, "--extension")
      table.insert(find_command, ext)
    end

    telescope.provider.find_files({
      find_command = find_command,
      prompt_title = telescope.title,
      attach_mappings = telescope:display(),
      search_dirs = dirs,
    })
  end,
}
-- The ways a user can attach an image through this slash command
local choice = {
  ---Load the file picker
  ---@param SlashCommand CodeCompanion.SlashCommand.Image
  ---@param SlashCommands CodeCompanion.SlashCommands
  ---@return nil
  File = function(SlashCommand, SlashCommands)
    return SlashCommands:set_provider(SlashCommand, providers)
  end,

  ---Share the URL of an image
  ---Downloads the image to a temporary file and passes that file, the URL
  ---and the response's content type to `SlashCommand:output`.
  ---@param SlashCommand CodeCompanion.SlashCommand.Image
  ---@return nil
  URL = function(SlashCommand, _)
    ---Find the content-type among raw "Key: value" header lines
    local function content_type_from(headers)
      if not headers then
        return nil
      end
      for _, line in ipairs(headers) do
        local key, value = line:match("^([^:]+):%s*(.+)$")
        if key and value and key:lower() == "content-type" then
          -- Keep only the part before any '; charset=...'
          return vim.trim(value:match("^([^;]+)"))
        end
      end
      return nil
    end

    local function on_input(url)
      if #vim.trim(url or "") == 0 then
        return
      end

      if vim.fn.executable("base64") == 0 then
        return log:warn("The `base64` command could not be found")
      end

      -- Download the image to a temporary location
      local tmpfile = vim.fn.tempname()
      local curl_ok, result = pcall(function()
        return Curl.get(url, {
          insecure = config.adapters.opts.allow_insecure,
          proxy = config.adapters.opts.proxy,
          output = tmpfile,
        })
      end)

      if not curl_ok then
        vim.loop.fs_unlink(tmpfile)
        return log:error("Failed to execute curl: %s", tostring(result))
      end

      local response = result

      -- A missing response or a 4xx/5xx status counts as a failed download
      local failed = not response or (response.status and response.status >= 400)
      if failed then
        local parts = { "Could not download the image." }
        if response and response.status then
          parts[#parts + 1] = " HTTP Status: " .. response.status
        end
        if response and response.body and #response.body > 0 then
          parts[#parts + 1] = "\nServer response: " .. response.body:sub(1, 200)
        end
        vim.loop.fs_unlink(tmpfile)
        return log:error(table.concat(parts))
      end

      return SlashCommand:output({
        id = url,
        path = tmpfile,
        mimetype = content_type_from(response.headers),
      })
    end

    return vim.ui.input({ prompt = "Enter the URL: " }, on_input)
  end,
}
---@class CodeCompanion.SlashCommand.Image: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create a new instance of the image slash command
---@param args CodeCompanion.SlashCommandArgs
---@return CodeCompanion.SlashCommand.Image
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
  }
  -- Methods are looked up on the SlashCommand table
  return setmetatable(instance, { __index = SlashCommand })
end
---Execute the slash command
---Asks whether the image comes from a URL or a file, then dispatches to the
---matching entry in `choice`.
---@param SlashCommands CodeCompanion.SlashCommands
---@return nil
function SlashCommand:execute(SlashCommands)
  local function on_select(selected)
    -- nil means the prompt was dismissed
    if selected then
      return choice[selected](self, SlashCommands)
    end
  end

  vim.ui.select({ "URL", "File" }, { prompt = "Select an image source" }, on_select)
end
---Put a reference to the image in the chat buffer
---@param selected table The selected image { source = string, path = string }
---@param opts? table Currently unused
---@return nil
function SlashCommand:output(selected, opts)
  local result = helpers.encode_image(selected)

  -- A string result is treated as an error message
  if type(result) ~= "string" then
    -- NOTE(review): `selected` is passed on, not the value returned by
    -- encode_image — confirm that is intended.
    return helpers.add_image(self.Chat, selected)
  end

  return log:error("Could not encode image: %s", result)
end
---Is the slash command enabled?
---Returns the adapter's `opts.vision` value together with the message to
---show when the command is disabled.
---@param chat CodeCompanion.Chat
---@return boolean,string
function SlashCommand.enabled(chat)
  local supports_vision = chat.adapter.opts.vision
  return supports_vision, "The image Slash Command is not enabled for this adapter"
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/init.lua | |
--- | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
---Resolve the callback to the correct module
---Tries `require("codecompanion." .. callback)` first, then falls back to
---loading `callback` as a file path and running the resulting chunk.
---@param callback string The module to get
---@return table|nil
local function resolve(callback)
  local found, module = pcall(require, "codecompanion." .. callback)
  if found then
    log:debug("Calling slash command: %s", callback)
    return module
  end

  -- Not a built-in module: try loading it from the user's config as a file
  local chunk, err = loadfile(callback)
  if err then
    return log:error("Could not load the slash command: %s", callback)
  end

  if chunk then
    log:debug("Calling slash command: %s", callback)
    return chunk()
  end
end
---@class CodeCompanion.SlashCommands
local SlashCommands = {}

---Create an empty instance whose methods are looked up on SlashCommands
---@return CodeCompanion.SlashCommands
function SlashCommands.new()
  local instance = {}
  setmetatable(instance, { __index = SlashCommands })
  return instance
end
---Set the provider to use for the Slash Command
---Runs the provider named in `SlashCommand.config.opts.provider`, or the
---"default" provider when none is configured. An error is logged when the
---chosen provider is not in `providers`.
---@param SlashCommand CodeCompanion.SlashCommand
---@param providers table Provider functions keyed by name
---@return any The value returned by the provider, or by log:error when it is missing
function SlashCommands:set_provider(SlashCommand, providers)
  local opts = SlashCommand.config.opts
  local name = (opts and opts.provider) or "default"

  local provider = providers[name]
  if not provider then
    return log:error("Provider for the slash command could not be found: %s", name)
  end

  -- The default provider is invoked the same way as a configured one, rather
  -- than being returned uncalled.
  return provider(SlashCommand)
end
---Execute the selected slash command
---A function callback is called directly with the chat. A string callback is
---resolved to a module, checked with its optional `enabled` hook, then
---instantiated and executed.
---@param item table The selected item from the completion menu
---@param chat CodeCompanion.Chat
---@return nil
function SlashCommands:execute(item, chat)
  local label = item.label:sub(1)
  log:debug("Executing slash command: %s", label)

  -- If the user has provided a callback function, use that
  local user_callback = item.config.callback
  if type(user_callback) == "function" then
    return user_callback(chat)
  end

  local command = resolve(user_callback)
  if not command then
    return log:error("Slash command not found: %s", label)
  end

  if command.enabled then
    local enabled, reason = command.enabled(chat)
    -- Only an explicit `false` blocks the command
    if enabled == false then
      return log:warn(reason)
    end
  end

  local instance = command.new({
    Chat = chat,
    config = item.config,
    context = item.context,
  })
  return instance:execute(self)
end
---Function for external objects to add references via Slash Commands
---Supports the "file", "symbols" and "url" slash commands; any other name is
---ignored. Only the requested command's module is loaded.
---@param chat CodeCompanion.Chat
---@param slash_command string
---@param opts { path: string, url?: string, description: string, [any]: any }
---@return nil
function SlashCommands.references(chat, slash_command, opts)
  if slash_command == "file" then
    local file = require("codecompanion.strategies.chat.slash_commands.file").new({
      Chat = chat,
    })
    return file:output({ description = opts.description, path = opts.path }, { silent = true })
  end

  if slash_command == "symbols" then
    local symbols = require("codecompanion.strategies.chat.slash_commands.symbols").new({
      Chat = chat,
    })
    return symbols:output({ description = opts.description, path = opts.path }, { silent = true })
  end

  if slash_command == "url" then
    local fetch = require("codecompanion.strategies.chat.slash_commands.fetch").new({
      Chat = chat,
      config = config.strategies.chat.slash_commands["fetch"],
    })

    -- NOTE: To conform to the <path, description> interface, all other options
    -- arrive in a nested `opts.opts` table and are flattened onto `opts` here.
    -- The nested table is optional, so default it rather than indexing nil.
    local nested = opts.opts or {}
    opts.silent = true
    opts.url = opts.url or opts.path
    opts.auto_restore_cache = nested.auto_restore_cache
    opts.ignore_cache = nested.ignore_cache

    return fetch:output(opts.url, opts)
  end
end
return SlashCommands | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/keymaps.lua | |
--- | |
local config = require("codecompanion.config") | |
local slash_commands = require("codecompanion.strategies.chat.slash_commands") | |
-- One entry per configured chat slash command, each exposing a `callback`
-- that executes that command against the given chat.
local M = {}

for name, cmd in pairs(config.strategies.chat.slash_commands) do
  local function run(chat)
    local commands = slash_commands.new()
    return commands:execute({ label = name, config = cmd }, chat)
  end

  M[name] = { callback = run }
end

return M
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/now.lua | |
--- | |
---@class CodeCompanion.SlashCommand.Now: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create a new instance of the now slash command
---@param args CodeCompanion.SlashCommandArgs
---@return CodeCompanion.SlashCommand.Now
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
  }
  return setmetatable(instance, { __index = SlashCommand })
end

---Execute the slash command
---Adds the current date and time to the chat buffer.
---@return nil
function SlashCommand:execute()
  local timestamp = os.date("%a, %d %b %Y %H:%M:%S %z")
  self.Chat:add_buf_message({ content = timestamp })
end

return SlashCommand
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/quickfix.lua | |
--- | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local path = require("plenary.path") | |
local symbol_utils = require("codecompanion.strategies.chat.helpers") | |
local util = require("codecompanion.utils") | |
local fmt = string.format | |
---Get quickfix list with type detection
---Items whose buffer has no name are skipped. An item is flagged as a
---diagnostic unless its `nr` is -1 or its text ends with the filename.
---@return table[] entries Array of quickfix entries with has_diagnostic field
local function get_qflist_entries()
  local entries = {}

  for idx, item in ipairs(vim.fn.getqflist()) do
    local filename = vim.fn.bufname(item.bufnr)
    if filename ~= "" then
      local text = item.text or ""
      local nr = item.nr or 0

      -- nr == -1: search results, treated as files (show whole content)
      local has_diagnostic = false
      if nr ~= -1 then
        -- Text ending with the filename is a file entry; anything else is
        -- taken to be a diagnostic message
        has_diagnostic = text:match(vim.pesc(filename) .. "$") == nil
      end

      entries[#entries + 1] = {
        idx = idx,
        filename = filename,
        lnum = item.lnum,
        text = text,
        type = item.type or "",
        nr = nr,
        has_diagnostic = has_diagnostic,
        display = fmt("%s:%d: %s", vim.fn.fnamemodify(filename, ":."), item.lnum, text),
      }
    end
  end

  return entries
end
---Extract symbols from a file using TreeSitter
---Delegates to the chat helpers, restricted to function, method and class symbols.
---@param filepath string Path to the file
---@return table[]|nil symbols Array of symbols with start_line, end_line, name, kind
---@return string|nil content File content if successful
local function extract_file_symbols(filepath)
  return symbol_utils.extract_file_symbols(filepath, { "Function", "Method", "Class" })
end
---Find which symbol contains a diagnostic line
---When several symbols enclose the line, the one spanning the fewest lines
---wins; ties go to the symbol that appears first.
---@param diagnostic_line number Line number of the diagnostic
---@param symbols table[] Array of symbols to search through
---@return table|nil symbol The smallest containing symbol or nil
local function find_containing_symbol(diagnostic_line, symbols)
  local winner, winner_span = nil, math.huge

  for _, candidate in ipairs(symbols) do
    local encloses = candidate.start_line <= diagnostic_line and diagnostic_line <= candidate.end_line
    if encloses then
      local span = candidate.end_line - candidate.start_line
      if span < winner_span then
        winner, winner_span = candidate, span
      end
    end
  end

  return winner
end
---Group diagnostics by proximity (within 5 lines)
---Sorts `diagnostics` in place by line number, then starts a new group each
---time the gap to the previous diagnostic is more than 5 lines.
---@param diagnostics table[] Array of diagnostic entries
---@return table[] groups Array of diagnostic groups
local function group_by_proximity(diagnostics)
  -- Zero or one entry: the input itself is the only group
  if #diagnostics <= 1 then
    return { diagnostics }
  end

  table.sort(diagnostics, function(a, b)
    return a.lnum < b.lnum
  end)

  local groups = { { diagnostics[1] } }
  for i = 2, #diagnostics do
    local entry = diagnostics[i]
    local tail = groups[#groups]
    if entry.lnum - tail[#tail].lnum <= 5 then
      tail[#tail + 1] = entry
    else
      groups[#groups + 1] = { entry }
    end
  end

  return groups
end
---Group diagnostics by symbol they belong to
---Diagnostics inside a symbol are grouped per symbol; the rest are grouped by
---proximity. Symbol groups come first, in the order their symbol is first
---seen in `diagnostics`, followed by the proximity groups.
---@param filepath string Path to the file
---@param diagnostics table[] Array of diagnostic entries
---@param file_content? string Optional file content to avoid re-reading
---@return table[] diagnostic_groups Array of groups with diagnostics and symbol info
---@return string|nil content File content if available
local function group_diagnostics_by_symbol(filepath, diagnostics, file_content)
  local symbols, content = extract_file_symbols(filepath)

  -- Use provided file_content if available to avoid re-reading
  if not content and file_content then
    content = file_content
  end

  -- If no symbols found, fallback to proximity grouping
  if not symbols or #symbols == 0 then
    return group_by_proximity(diagnostics), content
  end

  local function by_line(a, b)
    return a.lnum < b.lnum
  end

  -- `groups_by_key` is for lookup only. `result_groups` records first-seen
  -- order, so the output does not depend on the unspecified order of pairs().
  local groups_by_key = {}
  local result_groups = {}
  local ungrouped_diagnostics = {}

  for _, diagnostic in ipairs(diagnostics) do
    local containing_symbol = find_containing_symbol(diagnostic.lnum, symbols)
    if containing_symbol then
      local symbol_key =
        fmt("%s_%d_%d", containing_symbol.name, containing_symbol.start_line, containing_symbol.end_line)
      local group = groups_by_key[symbol_key]
      if not group then
        group = { diagnostics = {}, symbol = containing_symbol }
        groups_by_key[symbol_key] = group
        result_groups[#result_groups + 1] = group
      end
      table.insert(group.diagnostics, diagnostic)
    else
      -- Diagnostic not in any symbol
      table.insert(ungrouped_diagnostics, diagnostic)
    end
  end

  -- Sort diagnostics within each symbol group by line number
  for _, group in ipairs(result_groups) do
    table.sort(group.diagnostics, by_line)
  end

  -- Handle ungrouped diagnostics with proximity grouping
  if #ungrouped_diagnostics > 0 then
    for _, group in ipairs(group_by_proximity(ungrouped_diagnostics)) do
      table.insert(result_groups, {
        diagnostics = group,
        symbol = nil,
      })
    end
  end

  return result_groups, content
end
---Generate context for a group of diagnostics
---With a symbol, the context is the symbol's lines plus 3 lines of padding on
---each side. Without one, it is 5 lines either side of the first and last
---diagnostic. Both are clamped to the file.
---@param group_info table Group containing diagnostics and symbol
---@param file_content string Content of the file
---@param group_index number Index of this group
---@param total_groups number Total number of groups
---@return string context Formatted context with line numbers
local function generate_context_for_group(group_info, file_content, group_index, total_groups)
  local lines = vim.split(file_content, "\n")
  local symbol = group_info.symbol
  local first, last, header

  if not symbol then
    local diags = group_info.diagnostics
    first = math.max(1, diags[1].lnum - 5)
    last = math.min(#lines, diags[#diags].lnum + 5)
    header = fmt("lines %d-%d", first, last)
  else
    first = math.max(1, symbol.start_line - 3)
    last = math.min(#lines, symbol.end_line + 3)
    header = fmt("%s: %s (lines %d-%d)", symbol.kind, symbol.name, symbol.start_line, symbol.end_line)
  end

  -- Prefix every line with its line number
  local numbered = {}
  for lnum = first, last do
    numbered[#numbered + 1] = fmt("%d: %s", lnum, lines[lnum])
  end
  local context = table.concat(numbered, "\n")

  -- The group header is only added when there is more than one group
  if total_groups > 1 then
    return fmt("--- Group %d (%s) ---\n%s", group_index, header, context)
  end
  return context
end
---@class CodeCompanion.SlashCommand.Qflist: CodeCompanion.SlashCommand
local SlashCommand = {}

---Create new quickfix slash command instance
---@param args CodeCompanion.SlashCommandArgs
---@return CodeCompanion.SlashCommand.Qflist
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
    opts = args.opts,
  }
  -- Methods are looked up on the SlashCommand table
  return setmetatable(instance, { __index = SlashCommand })
end
---Run the quickfix slash command: collect the quickfix entries and output them
---@return nil
function SlashCommand:execute()
  local code_blocked = not config.can_send_code() and (self.config.opts and self.config.opts.contains_code)
  if code_blocked then
    return log:warn("Sending of code has been disabled")
  end

  local qf_entries = get_qflist_entries()
  if #qf_entries > 0 then
    self:output_entries(qf_entries)
    return
  end

  return log:warn("Quickfix list is empty")
end
---Bucket quickfix entries by the file they point at
---@param entries table[] Array of quickfix entries
---@return table files Map of filename -> { diagnostics = table[], has_diagnostics = boolean }
local function group_entries_by_file(entries)
  local by_file = {}

  for _, qf in ipairs(entries) do
    local bucket = by_file[qf.filename]
    if bucket == nil then
      bucket = { diagnostics = {}, has_diagnostics = false }
      by_file[qf.filename] = bucket
    end

    -- Entries without a diagnostic still register the file, just with no details
    if qf.has_diagnostic then
      bucket.diagnostics[#bucket.diagnostics + 1] = {
        lnum = qf.lnum,
        text = qf.text,
        type = qf.type,
      }
      bucket.has_diagnostics = true
    end
  end

  return by_file
end
---Read one quickfix file and build the attachment text that is sent to the chat
---
---Three shapes are produced:
--- * no diagnostics: the whole file
--- * diagnostics and fewer than 100 lines: the whole file plus a flat summary
--- * diagnostics and 100+ lines: per-group excerpts plus a grouped summary
---@param filepath string Path to the file
---@param file_data table File data with `diagnostics` and `has_diagnostics`
---@return string|nil description Formatted description for chat or nil if the file could not be read
---@return string id Reference ID for the file
local function process_single_file(filepath, file_data)
  local relative_path = vim.fn.fnamemodify(filepath, ":.")
  -- `vim.filetype.match` returns nil for unrecognised files; use an untagged
  -- code fence instead of formatting nil into it
  local ft = vim.filetype.match({ filename = filepath }) or ""
  local id = "<quickfix>" .. relative_path .. "</quickfix>"

  -- Read file once
  local ok, file_content = pcall(function()
    return path.new(filepath):read()
  end)
  if not ok or type(file_content) ~= "string" then
    log:warn("Could not read file: %s", filepath)
    return nil, id
  end

  -- File-only entries: nothing to summarise
  if not file_data.has_diagnostics then
    local description = fmt(
      [[<attachment filepath="%s">Here is the content from the file:
```%s
%s
```
</attachment>]],
      relative_path,
      ft,
      file_content
    )
    return description, id
  end

  local lines = vim.split(file_content, "\n")

  -- Small file: show everything with simple diagnostic summary
  if #lines < 100 then
    local diagnostic_summary = {}
    for _, diagnostic in ipairs(file_data.diagnostics) do
      table.insert(diagnostic_summary, fmt("Line %d: %s", diagnostic.lnum, diagnostic.text))
    end

    local description = fmt(
      [[<attachment filepath="%s">Here is the content from the file with quickfix entries (small file, showing all content):
%s
```%s
%s
```
</attachment>]],
      relative_path,
      table.concat(diagnostic_summary, "\n"),
      ft,
      file_content
    )
    return description, id
  end

  -- Large file: use smart grouping and context extraction
  local diagnostic_groups = group_diagnostics_by_symbol(filepath, file_data.diagnostics, file_content)
  local multiple_groups = #diagnostic_groups > 1

  -- Generate diagnostic summary with groups
  local diagnostic_summary = {}
  for group_idx, group_info in ipairs(diagnostic_groups) do
    if multiple_groups then
      if group_info.symbol then
        table.insert(
          diagnostic_summary,
          fmt("## Group %d (%s: %s):", group_idx, group_info.symbol.kind, group_info.symbol.name)
        )
      else
        table.insert(diagnostic_summary, fmt("## Group %d:", group_idx))
      end
    end
    for _, diagnostic in ipairs(group_info.diagnostics) do
      table.insert(diagnostic_summary, fmt("Line %d: %s", diagnostic.lnum, diagnostic.text))
    end
    if multiple_groups then
      table.insert(diagnostic_summary, "") -- Empty line between groups
    end
  end

  -- Generate context for each group
  local contexts = {}
  for i, group_info in ipairs(diagnostic_groups) do
    table.insert(contexts, generate_context_for_group(group_info, file_content, i, #diagnostic_groups))
  end

  local description = fmt(
    [[<attachment filepath="%s">Here is the content from the file with quickfix entries:
%s
```%s
%s
```
</attachment>]],
    relative_path,
    table.concat(diagnostic_summary, "\n"),
    ft,
    table.concat(contexts, "\n\n")
  )
  return description, id
end
---Output quickfix entries to chat, one hidden user message and one reference per file
---@param entries table[] Array of quickfix entries
---@return nil
function SlashCommand:output_entries(entries)
  local files = group_entries_by_file(entries)

  -- `pairs` iterates in an unspecified order; sort the paths so files are
  -- always added to the chat in the same order
  local filepaths = vim.tbl_keys(files)
  table.sort(filepaths)

  -- Only count files that were really added: unreadable files yield no description
  local added = 0
  for _, filepath in ipairs(filepaths) do
    local description, id = process_single_file(filepath, files[filepath])
    if description then
      self.Chat:add_message({
        role = config.constants.USER_ROLE,
        content = description,
      }, { reference = id, visible = false })
      self.Chat.references:add({
        id = id,
        path = filepath,
        source = "codecompanion.strategies.chat.slash_commands.qflist",
      })
      added = added + 1
    end
  end

  util.notify(fmt("Added %d file(s) from quickfix list to chat", added))
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/symbols.lua | |
--- | |
--[[ | |
Uses Tree-sitter to parse a given file and extract symbol types and names. Then | |
displays those symbols in the chat buffer as references. To support tools | |
and agents, start and end lines for the symbols are also output. | |
Heavily modified from the awesome Aerial.nvim plugin by stevearc: | |
https://github.com/stevearc/aerial.nvim/blob/master/lua/aerial/backends/treesitter/init.lua | |
--]] | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local path = require("plenary.path") | |
local symbol_utils = require("codecompanion.strategies.chat.helpers") | |
local util = require("codecompanion.utils") | |
local fmt = string.format | |
local get_node_text = vim.treesitter.get_node_text --[[@type function]] | |
local CONSTANTS = { | |
NAME = "Symbols", | |
PROMPT = "Select symbol(s)", | |
} | |
---Compute the range spanned from the start of one node to the end of another
---@param start_node TSNode Node whose start position opens the range
---@param end_node TSNode Node whose end position closes the range
---@return table range `lnum`/`end_lnum` are the node rows plus one; `col`/`end_col` are passed through unchanged
local function range_from_nodes(start_node, end_node)
  local first_row, first_col = start_node:start()
  local last_row, last_col = end_node:end_()

  local range = {}
  range.lnum = first_row + 1
  range.end_lnum = last_row + 1
  range.col = first_col
  range.end_col = last_col
  return range
end
---Warn the user that no Tree-sitter symbol query exists for a filetype
---@param ft string|nil The filetype that has no query
local function no_query(ft)
  local message = fmt("There are no Tree-sitter symbol queries for `%s` files yet. Please consider making a PR", ft)
  util.notify(message, vim.log.levels.WARN)
end
---Warn the user that the file produced no symbols
local function no_symbols()
  local message = "No symbols found in the given file"
  util.notify(message, vim.log.levels.WARN)
end
-- File pickers, keyed by provider name. Each one lets the user choose a file
-- and forwards the choice to `SlashCommand:output` as a table with
-- `relative_path` and `path`.
local providers = {
  ---The default provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  default = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.default").new({
      output = function(selection)
        SlashCommand:output({ relative_path = selection.relative_path, path = selection.path })
      end,
      SlashCommand = SlashCommand,
      title = CONSTANTS.PROMPT,
    })
    return picker:find_files():display()
  end,

  ---The Snacks.nvim provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  snacks = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.snacks").new({
      title = CONSTANTS.PROMPT .. ": ",
      output = function(selection)
        -- Snacks reports the file relative to `cwd`; join them for the full path
        local full_path = vim.fs.joinpath(selection.cwd, selection.file)
        return SlashCommand:output({ relative_path = selection.file, path = full_path })
      end,
    })

    picker.provider.picker.pick({
      source = "files",
      prompt = picker.title,
      confirm = picker:display(),
      main = { file = false, float = true },
    })
  end,

  ---The Telescope provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  telescope = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.telescope").new({
      title = CONSTANTS.PROMPT,
      output = function(selection)
        return SlashCommand:output({ relative_path = selection[1], path = selection.path })
      end,
    })

    picker.provider.find_files({
      prompt_title = picker.title,
      attach_mappings = picker:display(),
    })
  end,

  ---The Mini.Pick provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  mini_pick = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.mini_pick").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return SlashCommand:output(selected)
      end,
    })

    -- The selected value is used as both the path and the relative path
    local on_choice = picker:display(function(selected)
      return { path = selected, relative_path = selected }
    end)
    picker.provider.builtin.files({}, on_choice)
  end,

  ---The fzf-lua provider
  ---@param SlashCommand CodeCompanion.SlashCommand
  ---@return nil
  fzf_lua = function(SlashCommand)
    local picker = require("codecompanion.providers.slash_commands.fzf_lua").new({
      title = CONSTANTS.PROMPT,
      output = function(selected)
        return SlashCommand:output(selected)
      end,
    })

    local on_choice = picker:display(function(selected, opts)
      local file = picker.provider.path.entry_to_file(selected, opts)
      return { relative_path = file.stripped, path = file.path }
    end)
    picker.provider.files(on_choice)
  end,
}
---@class CodeCompanion.SlashCommand.Symbols: CodeCompanion.SlashCommand
local SlashCommand = {}

---Construct a symbols slash command from the supplied arguments
---@param args CodeCompanion.SlashCommandArgs
---@return CodeCompanion.SlashCommand.Symbols
function SlashCommand.new(args)
  local instance = {}
  instance.Chat = args.Chat
  instance.config = args.config
  instance.context = args.context
  return setmetatable(instance, { __index = SlashCommand })
end
---Run the symbols slash command by handing the provider table to the slash command manager
---@param SlashCommands CodeCompanion.SlashCommands
---@return nil
function SlashCommand:execute(SlashCommands)
  local code_blocked = not config.can_send_code() and (self.config.opts and self.config.opts.contains_code)
  if code_blocked then
    return log:warn("Sending of code has been disabled")
  end

  return SlashCommands:set_provider(self, providers)
end
---Output from the slash command in the chat buffer
---
---Extracts the symbols of the selected file, keeps the outline-worthy kinds
---and adds them to the chat as a hidden user message plus a reference.
---@param selected table The selected item from the provider { relative_path = string?, path = string, description = string? }
---@param opts? table `silent = true` suppresses the final notification
---@return nil
function SlashCommand:output(selected, opts)
  if not config.can_send_code() and (self.config.opts and self.config.opts.contains_code) then
    return log:warn("Sending of code has been disabled")
  end
  opts = opts or {}

  local ft = vim.filetype.match({ filename = selected.path })
  local symbols = symbol_utils.extract_file_symbols(selected.path)
  if not symbols then
    return no_query(ft)
  end

  -- Set of symbol kinds that are included in the outline
  local kinds = {
    Import = true,
    Enum = true,
    Module = true,
    Class = true,
    Struct = true,
    Interface = true,
    Method = true,
    Function = true,
  }

  local symbol_descriptions = {}
  for _, symbol in ipairs(symbols) do
    if kinds[symbol.kind] then
      table.insert(
        symbol_descriptions,
        fmt("- %s: `%s` (from line %s to %s)", symbol.kind:lower(), symbol.name, symbol.start_line, symbol.end_line)
      )
    end
  end

  if #symbol_descriptions == 0 then
    return no_symbols()
  end

  -- Not every caller supplies `relative_path`, so fall back to the full path
  -- everywhere the file is named (id, prompt text and notification)
  local display_path = selected.relative_path or selected.path
  local id = "<symbols>" .. display_path .. "</symbols>"
  local content = table.concat(symbol_descriptions, "\n")

  -- Workspaces allow the user to set their own custom description which should take priority
  local description
  if selected.description then
    description = fmt(
      [[%s
```%s
%s
```]],
      selected.description,
      ft,
      content
    )
  else
    description = fmt(
      [[Here is a symbolic outline of the file `%s` (with filetype `%s`). I've also included the line numbers that each symbol starts and ends on in the file:
%s
Prompt the user if you need to see more than the symbolic outline.
]],
      display_path,
      ft,
      content
    )
  end

  self.Chat:add_message({
    role = config.constants.USER_ROLE,
    content = description,
  }, { reference = id, visible = false })

  self.Chat.references:add({
    source = "slash_command",
    name = "symbols",
    id = id,
  })

  if opts.silent then
    return
  end

  util.notify(fmt("Added the symbols for `%s` to the chat", vim.fn.fnamemodify(display_path, ":t")))
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/terminal.lua | |
--- | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local util = require("codecompanion.utils") | |
local CONSTANTS = { | |
NAME = "Terminal Output", | |
} | |
---@class CodeCompanion.SlashCommand.Terminal: CodeCompanion.SlashCommand
local SlashCommand = {}

---Construct a terminal slash command; copies `Chat`, `config` and `context` from `args`
---@param args CodeCompanion.SlashCommand
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
  }
  return setmetatable(instance, { __index = SlashCommand })
end
-- Per-terminal bookkeeping, keyed by buffer number: how many lines have been
-- accounted for so far and when the buffer was last read
local _terminal_data = {}

---Execute the slash command
---
---Adds the output of the most recent terminal buffer to the chat as a hidden
---user message. After the first run, only the tail of the buffer is read.
---@return nil
function SlashCommand:execute()
  local bufnr = _G.codecompanion_last_terminal
  if not bufnr then
    return util.notify("No recent terminal buffer found", vim.log.levels.WARN)
  end

  local seen = _terminal_data[bufnr]
  local start_line = 0
  if seen then
    -- Account for new prompt lines. Clamp at 0: `nvim_buf_get_lines` treats a
    -- negative start as an offset from the END of the buffer, which would
    -- silently drop the output when fewer than three lines were recorded
    start_line = math.max(0, seen.lines - 3)
  end

  local ok, content = pcall(vim.api.nvim_buf_get_lines, bufnr, start_line, -1, false)
  if not ok then
    return log:error("Failed to get terminal output")
  end

  _terminal_data[bufnr] = {
    lines = #content + (seen and seen.lines or 0),
    timestamp = os.time(),
  }

  self.Chat:add_message({
    role = config.constants.USER_ROLE,
    content = string.format(
      [[Here is the latest output from terminal `%s`:
```
%s
```]],
      bufnr,
      table.concat(content, "\n")
    ),
  }, { visible = false })

  util.notify("Terminal output added to chat")
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/slash_commands/workspace.lua | |
--- | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local slash_commands = require("codecompanion.strategies.chat.slash_commands") | |
local util = require("codecompanion.utils") | |
local fmt = string.format | |
local CONSTANTS = { | |
NAME = "Workspace", | |
PROMPT = "Select a workspace group", | |
WORKSPACE_FILE = vim.fs.joinpath(vim.fn.getcwd(), "codecompanion-workspace.json"), | |
} | |
---Render the paths of a group's resources as a bulleted list
---@param group table Group whose `data` names entries in `workspace.data`
---@param workspace table Workspace definition holding the `data` section
---@return string list One "- path" line per resource, or "" when there are none
local function get_file_list(group, workspace)
  local bullets = {}

  local can_resolve = group.data and workspace and workspace.data
  if can_resolve then
    for _, key in ipairs(group.data) do
      local resource = workspace.data[key]
      local resource_path = resource and resource.path
      if resource_path then
        bullets[#bullets + 1] = "- " .. resource_path
      end
    end
  end

  if #bullets == 0 then
    return ""
  end

  -- The return value is ignored here, so the bullets are only substituted if
  -- replace_placeholders updates the table in place — NOTE(review): confirm
  if group.vars then
    util.replace_placeholders(bullets, group.vars)
  end

  return table.concat(bullets, "\n")
end
---Substitute workspace and group variables into a string
---@param workspace table Workspace definition (`vars`, `name`)
---@param group table Group definition (`vars`, `name`)
---@param str string Text containing placeholders
---@return string
local function replace_vars(workspace, group, str)
  local merged = {}

  ---Copy every entry of `source` into `merged`, overwriting existing keys
  local function absorb(source)
    if source then
      vim.iter(source):each(function(key, value)
        merged[key] = value
      end)
    end
  end

  -- Workspace vars go first so that group vars can override them
  absorb(workspace.vars)
  absorb(group.vars)

  -- The builtin names are written last and therefore always win
  merged["workspace_name"] = workspace.name
  merged["group_name"] = group.name

  return util.replace_placeholders(str, merged)
end
---Send the group's description to the chat as a hidden user message
---@param chat CodeCompanion.Chat
---@param workspace table
---@param group { name: string, description: string, files: table?, symbols: table? }
local function add_group_description(chat, workspace, group)
  local message = {
    role = config.constants.USER_ROLE,
    content = replace_vars(workspace, group, group.description),
  }
  chat:add_message(message, { visible = false })
end
---@class CodeCompanion.SlashCommand.Workspace: CodeCompanion.SlashCommand
local SlashCommand = {}

---Construct a workspace slash command; `workspace` starts out as an empty table
---@param args CodeCompanion.SlashCommandArgs
function SlashCommand.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    context = args.context,
    opts = args.opts or {},
    workspace = {},
  }
  return setmetatable(instance, { __index = SlashCommand })
end
---Open and read the contents of the workspace file
---
---On failure (file missing, empty or invalid JSON) the problem is logged and
---the logger's return value is passed back instead of a table.
---@param path? string Path to the workspace file; defaults to `CONSTANTS.WORKSPACE_FILE`
---@return table|nil workspace The decoded JSON when the file was read and parsed
function SlashCommand:read_workspace_file(path)
  -- CONSTANTS.WORKSPACE_FILE is always set when the module loads, so a single
  -- fallback is enough
  path = path or CONSTANTS.WORKSPACE_FILE

  if not vim.uv.fs_stat(path) then
    return log:warn(fmt("Could not find a workspace file at `%s`", path))
  end

  local short_path = vim.fn.fnamemodify(path, ":t")

  -- Read the file
  local content
  local f = io.open(path, "r")
  if f then
    content = f:read("*a")
    f:close()
  end
  if content == "" or content == nil then
    return log:warn(fmt("No content to read in the `%s` file", short_path))
  end

  -- Parse the JSON
  local ok, json = pcall(vim.json.decode, content)
  if not ok then
    return log:error(fmt("Invalid JSON in the `%s` file", short_path))
  end

  return json
end
---Resolve one entry of the workspace's data section and add it to the chat
---@param group table The group the entry belongs to
---@param item string Key of the entry in `workspace.data`
function SlashCommand:add_to_chat(group, item)
  local workspace = self.workspace
  local resource = workspace.data[item]
  if not resource then
    return log:warn("Could not find '%s' in the workspace file", item)
  end

  -- The path may contain workspace/group variables
  local path = replace_vars(workspace, group, resource.path)

  local description = resource.description
  if description then
    -- Builtin placeholders are derived from the resolved path
    local builtin = {
      cwd = vim.fn.getcwd(),
      filename = vim.fn.fnamemodify(path, ":t"),
      path = path,
    }
    -- User-declared variables are substituted first, then the builtin ones
    local with_user_vars = replace_vars(workspace, group, description)
    description = util.replace_placeholders(with_user_vars, builtin)
  end

  local payload = {
    path = path,
    description = description,
    opts = resource.opts or {},
  }
  return slash_commands.references(self.Chat, resource.type, payload)
end
---Execute the slash command
---
---Loads the workspace file, asks the user to pick one of its groups and
---passes the choice to `SlashCommand:output`.
---@param SlashCommands CodeCompanion.SlashCommands
---@param opts? table Passed through to `SlashCommand:output`
---@return nil
function SlashCommand:execute(SlashCommands, opts)
  if not config.can_send_code() and (self.config.opts and self.config.opts.contains_code) then
    return log:warn("Sending of code has been disabled")
  end

  -- read_workspace_file logs and returns a non-table when the file is missing,
  -- empty or invalid; stop here instead of indexing into it
  local workspace = self:read_workspace_file()
  if type(workspace) ~= "table" then
    return
  end
  self.workspace = workspace

  -- Get the group names
  local groups = {}
  for _, group in ipairs(workspace.groups or {}) do
    table.insert(groups, group.name)
  end
  if #groups == 0 then
    return log:warn("No groups found in the workspace file")
  end

  --TODO: Add option to add all groups
  -- if vim.tbl_count(groups) > 1 then
  --   table.insert(groups, 1, "All")
  -- end

  -- Let the user select a group
  vim.ui.select(groups, { kind = "codecompanion.nvim", prompt = "Select a Group to load" }, function(choice)
    if not choice then
      return nil
    end
    return self:output(choice, opts)
  end)
end
---Load a workspace group into the chat: system prompts, description and data items
---@param selected_group string Name of the group to load
---@param opts? table
function SlashCommand:output(selected_group, opts)
  local chat = self.Chat
  local workspace = self.workspace

  local matches = vim.tbl_filter(function(candidate)
    return candidate.name == selected_group
  end, workspace.groups)
  local group = matches[1]

  -- A group can ask for the system prompt tagged "from_config" to be removed
  if group.opts and group.opts.remove_config_system_prompt then
    chat:remove_tagged_message("from_config")
  end

  -- Workspace-level system prompt first, then the group's own
  if workspace.system_prompt then
    local prompt = replace_vars(workspace, group, workspace.system_prompt)
    chat:add_system_prompt(prompt, { visible = false, tag = workspace.name .. " // Workspace" })
  end
  if group.system_prompt then
    local prompt = replace_vars(workspace, group, group.system_prompt)
    chat:add_system_prompt(prompt, { visible = false, tag = group.name .. " // Workspace Group" })
  end

  -- The description goes in as a user message
  if group.description then
    add_group_description(chat, workspace, group)
  end

  -- Finally add every data item the group lists
  if group.data and workspace.data then
    for _, data_item in ipairs(group.data) do
      self:add_to_chat(group, data_item)
    end
  end
end
return SlashCommand | |
--- | |
File: /lua/codecompanion/strategies/chat/variables/buffer.lua | |
--- | |
local buf_utils = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local reserved_params = { | |
"pin", | |
"watch", | |
} | |
---@class CodeCompanion.Variable.Buffer: CodeCompanion.Variable
local Variable = {}

---Construct a buffer variable from the supplied arguments
---@param args CodeCompanion.VariableArgs
function Variable.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    params = args.params,
  }
  return setmetatable(instance, { __index = Variable })
end
---Add the contents of a buffer to the chat as a hidden user message
---@param selected table May carry `bufnr` and `params`; otherwise the chat context buffer and `self.params` are used
---@param opts? table With `pin = true` the "updated content" wording is used and no reference is added
---@return nil
function Variable:output(selected, opts)
  selected = selected or {}
  opts = opts or {}

  local chat = self.Chat
  local bufnr = selected.bufnr or chat.context.bufnr
  local params = selected.params or self.params

  -- Only the reserved parameters are accepted
  if params and not vim.tbl_contains(reserved_params, params) then
    return log:warn("Invalid parameter for buffer variable: %s", params)
  end

  local message = opts.pin and "Here is the updated file content (including line numbers)"
    or "User's current visible code in a file (including line numbers). This should be the main focus"

  local target = { bufnr = bufnr, path = buf_utils.get_info(bufnr).path }
  local ok, content, id, _ = pcall(buf_utils.format_for_llm, target, { message = message })
  if not ok then
    -- On failure `content` holds the error raised by format_for_llm
    return log:warn(content)
  end

  chat:add_message({
    role = config.constants.USER_ROLE,
    content = content,
  }, { reference = id, tag = "variable", visible = false })

  if opts.pin then
    return
  end

  chat.references:add({
    bufnr = bufnr,
    params = params,
    id = id,
    opts = {
      pinned = (params and params == "pin"),
      watched = (params and params == "watch"),
    },
    source = "codecompanion.strategies.chat.variables.buffer",
  })
end
---Replace the variable in the message
---
---Both the parameterised form (`<prefix>buffer{...}`) and the bare form
---(`<prefix>buffer`) are swapped for a description of the buffer.
---@param prefix string The variable prefix
---@param message string
---@param bufnr number
---@return string
function Variable.replace(prefix, message, bufnr)
  local bufname = buf_utils.name_from_bufnr(bufnr)
  local replacement = "file `" .. bufname .. "` (with buffer number: " .. bufnr .. ")"

  -- `%` is special in a gsub replacement string (capture references), so a
  -- buffer name containing it must be escaped to be inserted verbatim
  local literal = replacement:gsub("%%", "%%%%")
  -- Likewise keep any pattern magic in the prefix from being interpreted
  local escaped_prefix = vim.pesc(prefix)

  -- Handle the parameterised form first so its braces are consumed too
  local result = message:gsub(escaped_prefix .. "buffer{[^}]*}", literal)
  result = result:gsub(escaped_prefix .. "buffer", literal)
  return result
end
return Variable | |
--- | |
File: /lua/codecompanion/strategies/chat/variables/init.lua | |
--- | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local regex = require("codecompanion.utils.regex") | |
local CONSTANTS = { | |
PREFIX = "#", | |
} | |
---Look for `{...}` parameters attached to a variable in a message
---@param message table Message whose `content` is searched
---@param var string Name of the variable
---@return string|nil params The text between the braces, or nil when there is none
local function find_params(message, var)
  local found = message.content:match(CONSTANTS.PREFIX .. var .. "{([^}]+)}")
  if not found then
    return nil
  end

  log:trace("Params found for variable: %s", found)
  return found
end
---Instantiate the variable named by the config and run its `output` method
---@param chat CodeCompanion.Chat
---@param var_config table
---@param params? string
---@return table
local function resolve(chat, var_config, params)
  -- Anything but a string callback is handled by the user variable class
  if type(var_config.callback) ~= "string" then
    local user_variable = require("codecompanion.strategies.chat.variables.user").new({
      Chat = chat,
      config = var_config,
      params = params,
    })
    return user_variable:output()
  end

  local segments = vim.split(var_config.callback, ".", { plain = true })
  local parent = table.concat(segments, ".", 1, #segments - 1)
  local module_name = parent .. "." .. segments[#segments]

  -- Try the module inside the plugin's own namespace first
  local found, module = pcall(require, "codecompanion." .. module_name)

  local init = {
    Chat = chat,
    config = var_config,
    params = params or (var_config.opts and var_config.opts.default_params),
  }

  log:trace("Calling variable: %s", module_name)
  if not found then
    -- User is using a custom callback
    module = require(module_name)
  end

  return module.new(init):output()
end
---@class CodeCompanion.Variables
local Variables = {}

---Construct the variables handler; `vars` comes from the chat strategy config
---@return CodeCompanion.Variables
function Variables.new()
  local instance = { vars = config.strategies.chat.variables }
  return setmetatable(instance, { __index = Variables })
end
---Build the regex that matches a variable in a message: the prefixed name
---followed by whitespace, end of input or a `{...}` block
---@param var string The variable name to create a pattern for
---@return string The regex pattern
function Variables:_pattern(var)
  local terminator = "\\(\\s\\|$\\|{[^}]*}\\)"
  return CONSTANTS.PREFIX .. var .. terminator
end
---Collect the names of all configured variables that occur in a message
---@param message table
---@return table|nil found List of variable names, or nil when none match or the message has no content
function Variables:find(message)
  local text = message.content
  if not text then
    return nil
  end

  local matches = {}
  for name in pairs(self.vars) do
    if regex.find(text, self:_pattern(name)) then
      matches[#matches + 1] = name
    end
  end

  if #matches > 0 then
    return matches
  end
  return nil
end
---Resolve every variable referenced in a message
---@param chat CodeCompanion.Chat
---@param message table
---@return boolean found True when the message referenced at least one variable
function Variables:parse(chat, message)
  local vars = self:find(message)
  if not vars then
    return false
  end

  for _, var in ipairs(vars) do
    local var_config = self.vars[var]
    log:debug("Variable found: %s", var)
    var_config["name"] = var

    local var_opts = var_config.opts
    if (var_opts and var_opts.contains_code) and not config.can_send_code() then
      -- Code-bearing variables are skipped while code sending is disabled
      log:warn("Sending of code has been disabled")
    else
      local params = nil
      if var_opts and var_opts.has_params then
        params = find_params(message, var)
      end
      resolve(chat, var_config, params)
    end
  end

  return true
end
---Rewrite a message so that variable references are removed or described
---@param message string
---@param bufnr number
---@return string
function Variables:replace(message, bufnr)
  for name in pairs(self.vars) do
    local is_buffer = name:match("^buffer")
    if not is_buffer then
      -- Ordinary variables are stripped from the message
      message = vim.trim(regex.replace(message, self:_pattern(name), ""))
    else
      -- The buffer variable is unique because it can take parameters which need to be handled
      -- TODO: If more variables have parameters in the future we'll extract this
      local buffer_variable = require("codecompanion.strategies.chat.variables.buffer")
      message = buffer_variable.replace(CONSTANTS.PREFIX, message, bufnr)
    end
  end

  return message
end
return Variables | |
--- | |
File: /lua/codecompanion/strategies/chat/variables/lsp.lua | |
--- | |
local buf_utils = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
---@class CodeCompanion.Variable.LSP: CodeCompanion.Variable
local Variable = {}

---Construct an LSP variable from the supplied arguments
---@param args CodeCompanion.VariableArgs
function Variable.new(args)
  local instance = {}
  instance.Chat = args.Chat
  instance.config = args.config
  instance.params = args.params
  return setmetatable(instance, { __index = Variable })
end
---Add the current buffer's diagnostics, with the code they refer to, to the chat
---@return nil
function Variable:output()
  -- Labels for the numeric severities found on diagnostics
  local severity_names = { "ERROR", "WARNING", "INFORMATION", "HINT" }

  local chat = self.Chat
  local bufnr = chat.context.bufnr
  local diagnostics = vim.diagnostic.get(bufnr, {
    severity = { min = vim.diagnostic.severity.HINT },
  })

  local blocks = {}
  for _, diagnostic in ipairs(diagnostics) do
    -- Attach each source line in the diagnostic's range; the displayed number is the index plus one
    for row = diagnostic.lnum, diagnostic.end_lnum do
      diagnostic.lines = diagnostic.lines or {}
      local code = vim.trim(buf_utils.get_content(bufnr, { row, row + 1 }))
      table.insert(diagnostic.lines, string.format("%d: %s", row + 1, code))
    end

    blocks[#blocks + 1] = string.format(
      [[
Severity: %s
LSP Message: %s
Code:
```%s
%s
```
]],
      severity_names[diagnostic.severity],
      diagnostic.message,
      chat.context.filetype,
      table.concat(diagnostic.lines, "\n")
    )
  end

  chat:add_message({
    role = config.constants.USER_ROLE,
    content = table.concat(blocks, "\n\n"),
  }, { tag = "variable", visible = false })
end
return Variable | |
--- | |
File: /lua/codecompanion/strategies/chat/variables/user.lua | |
--- | |
local buf_utils = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
---@class CodeCompanion.Variable.User: CodeCompanion.Variable
local Variable = {}

---Construct a user-defined variable from the supplied arguments
---@param args CodeCompanion.VariableArgs
function Variable.new(args)
  local instance = {
    Chat = args.Chat,
    config = args.config,
    params = args.params,
  }
  return setmetatable(instance, { __index = Variable })
end
---Call the user's callback and add its result to the chat as a hidden, referenced message
---@return nil
function Variable:output()
  local chat = self.Chat
  local id = "<var>" .. self.config.name .. "</var>"

  -- The callback receives this variable instance
  local content = self.config.callback(self)

  chat:add_message(
    { role = config.constants.USER_ROLE, content = content },
    { reference = id, tag = "variable", visible = false }
  )

  chat.references:add({
    bufnr = chat.bufnr,
    id = id,
    source = "codecompanion.strategies.chat.variables.user",
  })
end
return Variable | |
--- | |
File: /lua/codecompanion/strategies/chat/variables/viewport.lua | |
--- | |
local buf_utils = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
---@class CodeCompanion.Variable.ViewPort: CodeCompanion.Variable
local Variable = {}

---Construct a viewport variable from the supplied arguments
---@param args CodeCompanion.VariableArgs
function Variable.new(args)
  local instance = {}
  instance.Chat = args.Chat
  instance.config = args.config
  instance.params = args.params
  return setmetatable(instance, { __index = Variable })
end
---Add the lines currently visible in the editor to the chat as a hidden user message
---@return nil
function Variable:output()
  local visible = buf_utils.get_visible_lines()
  local message = {
    role = config.constants.USER_ROLE,
    content = buf_utils.format_viewport_for_llm(visible),
  }
  self.Chat:add_message(message, { tag = "variable", visible = false })
end
return Variable | |
--- | |
File: /lua/codecompanion/strategies/chat/debug.lua | |
--- | |
local buf_utils = require("codecompanion.utils.buffers") | |
local config = require("codecompanion.config") | |
local ui = require("codecompanion.utils.ui") | |
local util = require("codecompanion.utils") | |
local api = vim.api | |
---Find the key of the field inside the `settings` assignment that encloses the
---node returned by `vim.treesitter.get_node`
---@param bufnr number Buffer that the text of the Tree-sitter nodes is read from
---@param opts? table Options passed to `vim.treesitter.get_node`; `lang` is always forced to "lua"
---@return string|nil key_name The identifier of the enclosing field
---@return TSNode|nil node The `field` node that owns the key
local function _get_settings_key(bufnr, opts)
  local node_opts = vim.tbl_extend("force", opts or {}, { lang = "lua" })
  local origin = vim.treesitter.get_node(node_opts)

  -- Only nodes that sit below an assignment whose first named child reads
  -- "settings" are of interest
  local function inside_settings(start)
    local ancestor = start
    while ancestor do
      if ancestor:type() == "assignment_statement" then
        local target = ancestor:named_child(0)
        if target and vim.treesitter.get_node_text(target, bufnr) == "settings" then
          return true
        end
      end
      ancestor = ancestor:parent()
    end
    return false
  end

  if not inside_settings(origin) then
    return
  end

  -- Climb from the starting node to the nearest field that is keyed by an identifier
  local candidate = origin
  while candidate do
    if candidate:type() == "field" then
      local key_node = candidate:named_child(0)
      if key_node and key_node:type() == "identifier" then
        local key_name = vim.treesitter.get_node_text(key_node, bufnr)
        return key_name, candidate
      end
    end
    candidate = candidate:parent()
  end
end
---Evaluate the buffer lines as a Lua chunk and extract its `settings` and `messages`
---The chunk is loaded in text mode with an empty table as its environment.
---Raises an error when the lines cannot be compiled
---@param lines string[] The lines of the buffer
---@return table|nil settings The value that the chunk gave to `settings`
---@return table|nil messages The value that the chunk gave to `messages`
local function get_buffer_content(lines)
  -- Pre-declare both names so that the trailing return works even when the
  -- buffer defines only one of them (or neither)
  local source = table.concat({
    "local settings, messages; ",
    table.concat(lines, "\n"),
    " return {settings=settings, messages=messages}",
  })

  local sandbox = {}
  local chunk, err = load(source, "buffer", "t", sandbox)
  if not chunk then
    return error("Failed to parse buffer: " .. (err or "unknown error"))
  end

  local parsed = chunk()
  return parsed.settings, parsed.messages
end
---@class CodeCompanion.Chat.Debug
---@field chat CodeCompanion.Chat
---@field settings table
---@field aug number
local Debug = {}

---Create a debug view for a chat buffer
---@param args { chat: CodeCompanion.Chat, settings: table }
---@return CodeCompanion.Chat.Debug
function Debug.new(args)
  local instance = {
    chat = args.chat,
    settings = args.settings,
  }
  return setmetatable(instance, { __index = Debug })
end
---Render the settings and messages as Lua source in a floating debug window
---The settings table is only written when settings are not shown in the chat
---buffer itself (`config.display.chat.show_settings` is falsy).
---@return CodeCompanion.Chat.Debug
function Debug:render()
  -- Work on a copy so that the debug window keeps its own view of the adapter
  local adapter = vim.deepcopy(self.chat.adapter)
  self.adapter = adapter

  local bufname = buf_utils.name_from_bufnr(self.chat.context.bufnr)

  -- Get the current settings from the chat buffer rather than making new ones.
  -- The empty-table fallback lets a chat without settings still be rendered
  local current_settings = self.settings or {}

  local models = adapter.schema.model.choices
  if type(models) == "function" then
    models = models(adapter)
  end
  -- An adapter without model choices simply has no alternatives to list
  models = models or {}

  local lines = {}

  table.insert(lines, '-- Adapter: "' .. adapter.formatted_name .. '"')
  table.insert(lines, "-- Buffer: " .. self.chat.bufnr)
  table.insert(lines, '-- Context: "' .. bufname .. '" (' .. self.chat.context.bufnr .. ")")

  -- Add settings
  if not config.display.chat.show_settings then
    table.insert(lines, "")
    local keys = {}

    -- Collect all settings keys including those with nil defaults
    for key, _ in pairs(current_settings) do
      table.insert(keys, key)
    end

    -- Add any schema keys that have an explicit nil default
    for key, schema_value in pairs(adapter.schema) do
      if schema_value.default == nil and not vim.tbl_contains(keys, key) then
        table.insert(keys, key)
      end
    end

    table.sort(keys, function(a, b)
      local a_order = adapter.schema[a] and adapter.schema[a].order or 999
      local b_order = adapter.schema[b] and adapter.schema[b].order or 999
      if a_order == b_order then
        return a < b -- alphabetical sort as fallback
      end
      return a_order < b_order
    end)

    table.insert(lines, "local settings = {")
    for _, key in ipairs(keys) do
      local val = current_settings[key]
      local is_nil = adapter.schema[key] and adapter.schema[key].default == nil

      if key == "model" then
        -- Resolve a function-valued model first, otherwise the comparison
        -- below can never exclude the model that is in use
        if type(val) == "function" then
          val = val(self.adapter)
        end

        local other_models = " -- "
        vim.iter(models):each(function(model, model_name)
          if type(model) == "number" then
            model = model_name
          end
          if model ~= val then
            other_models = other_models .. '"' .. model .. '", '
          end
        end)

        if vim.tbl_count(models) > 1 then
          table.insert(lines, " " .. key .. ' = "' .. val .. '", ' .. other_models)
        else
          table.insert(lines, " " .. key .. ' = "' .. val .. '",')
        end
      elseif is_nil and current_settings[key] == nil then
        table.insert(lines, " " .. key .. " = nil,")
      elseif type(val) == "number" or type(val) == "boolean" then
        table.insert(lines, " " .. key .. " = " .. tostring(val) .. ",")
      elseif type(val) == "string" then
        table.insert(lines, " " .. key .. ' = "' .. val .. '",')
      elseif type(val) == "function" then
        -- Call the function once and reuse its result
        local expanded_val = val(self.adapter)
        if type(expanded_val) == "number" or type(expanded_val) == "boolean" then
          table.insert(lines, " " .. key .. " = " .. tostring(expanded_val) .. ",")
        else
          table.insert(lines, " " .. key .. ' = "' .. tostring(expanded_val) .. '",')
        end
      else
        -- Render the value on a single line and end it with a comma so that
        -- the settings table stays valid Lua when further keys follow
        local rendered = vim.inspect(val, { newline = " ", indent = "" })
        table.insert(lines, " " .. key .. " = " .. rendered .. ",")
      end
    end
    table.insert(lines, "}")
  end

  -- Add messages
  if vim.tbl_count(self.chat.messages) > 0 then
    table.insert(lines, "")
    table.insert(lines, "local messages = ")

    local messages = vim.inspect(self.chat.messages)
    for line in messages:gmatch("[^\r\n]+") do
      table.insert(lines, line)
    end
  end

  self.bufnr = api.nvim_create_buf(false, true)

  -- Set the keymaps as per the user's chat buffer config
  local maps = {}
  local config_maps = vim.deepcopy(config.strategies.chat.keymaps)
  maps["save"] = config_maps["send"]
  maps["save"].callback = "save"
  maps["save"].description = "Save debug window content"
  maps["close"] = config_maps["close"]
  maps["close"].callback = "close"
  maps["close"].description = "Close debug window"

  require("codecompanion.utils.keymaps")
    .new({
      bufnr = self.bufnr,
      callbacks = function()
        local M = {}
        M.save = function()
          return self:save()
        end
        M.close = function()
          return self:close()
        end
        return M
      end,
      data = nil,
      keymaps = maps,
    })
    :set()

  -- Start from the chat window config and apply the debug window's size,
  -- each dimension being either a value or a function that produces one
  local window = vim.deepcopy(config.display.chat.window)
  if type(config.display.chat.debug_window.height) == "function" then
    window.height = config.display.chat.debug_window.height()
  else
    window.height = config.display.chat.debug_window.height
  end
  if type(config.display.chat.debug_window.width) == "function" then
    window.width = config.display.chat.debug_window.width()
  else
    window.width = config.display.chat.debug_window.width
  end

  ui.create_float(lines, {
    bufnr = self.bufnr,
    filetype = "lua",
    ignore_keymaps = true,
    relative = "editor",
    title = "Debug Chat",
    window = window,
    opts = {
      wrap = true,
    },
  })

  self:setup_window()

  return self
end
---Create the autocmds that drive the debug window: setting descriptions on
---cursor movement, saving on write and cleaning up when the window goes away
---@return nil
function Debug:setup_window()
  local group = api.nvim_create_augroup("codecompanion.debug" .. ":" .. self.bufnr, {
    clear = true,
  })
  self.aug = group

  ---Show the schema description of the setting that the cursor is on
  local function show_setting_info()
    local key_name, node = _get_settings_key(self.bufnr)
    if not (key_name and node) then
      return vim.diagnostic.set(config.INFO_NS, self.bufnr, {})
    end

    local key_schema = self.adapter.schema[key_name]
    if not (key_schema and key_schema.desc) then
      return
    end

    local lnum, col, end_lnum, end_col = node:range()
    vim.diagnostic.set(config.INFO_NS, self.bufnr, {
      {
        lnum = lnum,
        col = col,
        end_lnum = end_lnum,
        end_col = end_col,
        severity = vim.diagnostic.severity.INFO,
        message = key_schema.desc,
      },
    })
  end

  local function on_write()
    return self:save()
  end

  local function on_leave()
    return self:close()
  end

  api.nvim_create_autocmd("CursorMoved", {
    group = group,
    buffer = self.bufnr,
    desc = "Show settings information in the CodeCompanion chat buffer",
    callback = show_setting_info,
  })

  api.nvim_create_autocmd("BufWrite", {
    group = group,
    buffer = self.bufnr,
    desc = "Save the contents of the debug window to the chat buffer",
    callback = on_write,
  })

  api.nvim_create_autocmd({ "BufUnload", "WinClosed" }, {
    group = group,
    buffer = self.bufnr,
    desc = "Clear the autocmds in the debug window",
    callback = on_leave,
  })
end
---Read the debug window back and apply its settings and messages to the chat
---Nothing happens (and no notification is sent) when the buffer defines neither
---@return nil
function Debug:save()
  local buffer_lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false)
  local settings, messages = get_buffer_content(buffer_lines)

  if settings then
    self.chat:apply_settings(settings)
  end
  if messages then
    self.chat.messages = messages
  end

  if settings or messages then
    util.notify("Updated the settings and messages")
  end
end
---Function to run when the debug chat is closed
---Clears the autocmds of the debug window and force-deletes its buffer.
---@return nil
function Debug:close()
  -- Clear the autocmds first so that deleting the buffer cannot re-enter this function
  if self.aug then
    api.nvim_clear_autocmds({ group = self.aug })
  end

  -- The buffer may never have been created or may already be gone;
  -- deleting an invalid buffer would raise an error
  if self.bufnr and api.nvim_buf_is_valid(self.bufnr) then
    api.nvim_buf_delete(self.bufnr, { force = true })
  end
end
return Debug | |
--- | |
File: /lua/codecompanion/strategies/chat/helpers.lua | |
--- | |
local base64 = require("codecompanion.utils.base64") | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
local path = require("plenary.path") | |
local get_node_text = vim.treesitter.get_node_text | |
local M = {} | |
---Format the given role without any separator
---When header separators are enabled, every occurrence of the configured
---separator is removed from the role and the result is trimmed.
---@param role string
---@return string
function M.format_role(role)
  if config.display.chat.show_header_separator then
    -- The separator is configured text rather than a Lua pattern: escape it
    -- so that characters such as "*", "." or "%" are removed literally
    local separator = vim.pesc(config.display.chat.separator)
    local stripped = role:gsub(separator, "")
    role = vim.trim(stripped)
  end
  return role
end
---Remove the leading entries that start with ">" from the messages
---The table is modified in place and returned; entries after the first
---one that does not start with ">" are left untouched
---@param messages table
---@return table
function M.strip_references(messages)
  -- Count how many entries at the front start with ">"
  local leading = 0
  while messages[leading + 1] and messages[leading + 1]:sub(1, 1) == ">" do
    leading = leading + 1
  end

  -- Drop them from the front; every removal shifts the remainder down
  for _ = 1, leading do
    table.remove(messages, 1)
  end

  return messages
end
---Build the keymap entries for every slash command that defines keymaps
---@param slash_commands table Slash commands keyed by their name
---@return table keymaps Entries with `description`, `callback` and `modes`, keyed by command name
function M.slash_command_keymaps(slash_commands)
  local result = {}
  for name, command in pairs(slash_commands) do
    if command.keymaps then
      result[name] = {
        description = command.description,
        callback = "keymaps." .. name,
        modes = command.keymaps.modes,
      }
    end
  end
  return result
end
---Base64 encode the image at `image.path`
---On success the encoded content is stored on the image as `base64` and a
---missing `mimetype` is filled in; on failure the encoder's error is returned
---@param image table The image object containing the path and other metadata.
---@return {base64: string, mimetype: string}|string The updated image, or the error from the encoder
function M.encode_image(image)
  local encoded, encode_err = base64.encode(image.path)
  if encode_err then
    return encode_err
  end

  image.base64 = encoded
  -- Keep a mimetype that was provided with the image
  image.mimetype = image.mimetype or base64.get_mimetype(image.path)

  return image
end
---Add an image to the chat as a hidden message and register it as a reference
---@param Chat CodeCompanion.Chat The chat instance
---@param image table The image object containing the path and other metadata
---@param opts table Options for adding the image (`role`, `bufnr` and `source` are read)
---@return nil
function M.add_image(Chat, image, opts)
  opts = opts or {}

  -- The image's own id is preferred over its path when naming the reference
  local id = "<image>" .. (image.id or image.path) .. "</image>"

  local message = {
    role = opts.role or config.constants.USER_ROLE,
    content = image.base64,
  }
  local message_opts = {
    reference = id,
    mimetype = image.mimetype,
    tag = "image",
    visible = false,
  }
  Chat:add_message(message, message_opts)

  local reference = {
    bufnr = opts.bufnr or image.bufnr,
    id = id,
    path = image.path,
    source = opts.source or "codecompanion.strategies.chat.slash_commands.image",
  }
  Chat.references:add(reference)
end
---Build a range that runs from the start of one node to the end of another
---Rows are shifted by one (`lnum`/`end_lnum`); columns are passed through as-is
---@param start_node TSNode
---@param end_node TSNode
---@return { lnum: number, end_lnum: number, col: number, end_col: number }
local function range_from_nodes(start_node, end_node)
  local first_row, first_col = start_node:start()
  local last_row, last_col = end_node:end_()

  local range = {
    lnum = first_row + 1,
    end_lnum = last_row + 1,
    col = first_col,
    end_col = last_col,
  }
  return range
end
---Extract symbols from a file using the Tree-sitter "symbols" query of its filetype
---@param filepath string The path to the file
---@param target_kinds? string[] Optional list of symbol kinds to include (default: all)
---@return table[]|nil symbols Array of symbols with name, kind, start_line, end_line and range
---@return string|nil content File content, returned whenever the file could be read
function M.extract_file_symbols(filepath, target_kinds)
  local ft = vim.filetype.match({ filename = filepath })
  if not ft then
    -- Fall back to the extension for ".ts" files
    local name_parts = vim.split(vim.fs.basename(filepath), "%.")
    if #name_parts > 1 and name_parts[#name_parts] == "ts" then
      ft = "typescript"
    end
  end
  if not ft then
    return nil, nil
  end

  local ok, content = pcall(function()
    return path.new(filepath):read()
  end)
  if not ok then
    return nil, nil
  end

  local query = vim.treesitter.query.get(ft, "symbols")
  if not query then
    return nil, content
  end

  local parser = vim.treesitter.get_string_parser(content, ft)
  local tree = parser:parse()[1]
  local symbols = {}

  for _, matches, metadata in query:iter_matches(tree:root(), content) do
    -- Start from the match metadata, then add one entry per capture name
    -- without overriding anything that is already present
    local match = vim.tbl_extend("force", {}, metadata)
    for id, nodes in pairs(matches) do
      local node = type(nodes) == "table" and nodes[1] or nodes
      local capture = query.captures[id]
      if match[capture] == nil then
        match[capture] = {
          metadata = metadata[id],
          node = node,
        }
      end
    end

    local name_match = match.name or {}
    local symbol_node = (match.symbol or match.type or {}).node

    if symbol_node then
      local start_node = (match.start or {}).node or symbol_node
      local end_node = (match["end"] or {}).node or start_node
      local kind = match.kind

      -- Filter by target kinds if specified
      if not target_kinds or vim.tbl_contains(target_kinds, kind) then
        local range = range_from_nodes(start_node, end_node)
        local symbol_name = "<unknown>"
        if name_match.node then
          symbol_name = vim.trim(get_node_text(name_match.node, content))
        end

        table.insert(symbols, {
          name = symbol_name,
          kind = kind,
          start_line = range.lnum,
          end_line = range.end_lnum,
          -- Keep original format for symbols.lua compatibility
          range = range,
        })
      end
    end
  end

  return symbols, content
end
return M | |
--- | |
File: /lua/codecompanion/strategies/chat/init.lua | |
--- | |
--============================================================================= | |
-- The Chat Buffer - Where all of the logic for conversing with an LLM sits | |
--============================================================================= | |
---@class CodeCompanion.Chat | |
---@field adapter CodeCompanion.Adapter The adapter to use for the chat | |
---@field agents CodeCompanion.Agent The agent that calls tools available to the user | |
---@field aug number The ID for the autocmd group | |
---@field bufnr integer The buffer number of the chat | |
---@field context table The context of the buffer that the chat was initiated from | |
---@field current_request table|nil The current request being executed | |
---@field current_tool table The current tool being executed | |
---@field cycle number Records the number of turn-based interactions (User -> LLM) that have taken place | |
---@field header_line number The line number of the user header that any Tree-sitter parsing should start from | |
---@field from_prompt_library? boolean Whether the chat was initiated from the prompt library | |
---@field header_ns integer The namespace for the virtual text that appears in the header | |
---@field id integer The unique identifier for the chat | |
---@field messages? table The messages in the chat buffer | |
---@field opts CodeCompanion.ChatArgs Store all arguments in this table | |
---@field parser vim.treesitter.LanguageTree The Markdown Tree-sitter parser for the chat buffer | |
---@field references CodeCompanion.Chat.References | |
---@field refs? table<CodeCompanion.Chat.Ref> References which are sent to the LLM e.g. buffers, slash command output | |
---@field settings? table The settings that are used in the adapter of the chat buffer | |
---@field subscribers table The subscribers to the chat buffer | |
---@field tokens? nil|number The number of tokens in the chat | |
---@field tools CodeCompanion.Chat.Tools Methods for handling interactions between the chat buffer and tools | |
---@field ui CodeCompanion.Chat.UI The UI of the chat buffer | |
---@field variables? CodeCompanion.Variables The variables available to the user | |
---@field watchers CodeCompanion.Watchers The buffer watcher instance | |
---@field yaml_parser vim.treesitter.LanguageTree The Yaml Tree-sitter parser for the chat buffer | |
---@class CodeCompanion.ChatArgs Arguments that can be injected into the chat | |
---@field adapter? CodeCompanion.Adapter The adapter used in this chat buffer | |
---@field auto_submit? boolean Automatically submit the chat when the chat buffer is created | |
---@field context? table Context of the buffer that the chat was initiated from | |
---@field from_prompt_library? boolean Whether the chat was initiated from the prompt library | |
---@field ignore_system_prompt? boolean Do not send the default system prompt with the request | |
---@field last_role? string The role of the last response in the chat buffer | |
---@field messages? table The messages to display in the chat buffer | |
---@field settings? table The settings that are used in the adapter of the chat buffer | |
---@field status? string The status of any running jobs in the chat buffer
---@field stop_context_insertion? boolean Stop any visual selection from being automatically inserted into the chat buffer | |
---@field tokens? table Total tokens spent in the chat buffer so far | |
local adapters = require("codecompanion.adapters") | |
local client = require("codecompanion.http") | |
local completion = require("codecompanion.providers.completion") | |
local config = require("codecompanion.config") | |
local hash = require("codecompanion.utils.hash") | |
local helpers = require("codecompanion.strategies.chat.helpers") | |
local keymaps = require("codecompanion.utils.keymaps") | |
local log = require("codecompanion.utils.log") | |
local schema = require("codecompanion.schema") | |
local util = require("codecompanion.utils") | |
local yaml = require("codecompanion.utils.yaml") | |
local api = vim.api | |
local get_node_text = vim.treesitter.get_node_text --[[@type function]] | |
local get_query = vim.treesitter.query.get --[[@type function]] | |
local CONSTANTS = { | |
AUTOCMD_GROUP = "codecompanion.chat", | |
STATUS_CANCELLING = "cancelling", | |
STATUS_ERROR = "error", | |
STATUS_SUCCESS = "success", | |
BLANK_DESC = "[No messages]", | |
} | |
local llm_role = config.strategies.chat.roles.llm | |
local user_role = config.strategies.chat.roles.user | |
--============================================================================= | |
-- Private methods | |
--============================================================================= | |
---Re-run the source of every pinned reference so that its output is added to
---the chat, skipping pins that already have a message in the current cycle
---@param chat CodeCompanion.Chat
---@return nil
local function add_pins(chat)
  local function is_pinned(ref)
    return ref.opts.pinned
  end

  local pinned = vim.iter(chat.refs):filter(is_pinned):totable()
  if vim.tbl_isempty(pinned) then
    return
  end

  for _, pin in ipairs(pinned) do
    -- Don't add the pin twice in the same cycle
    local added_this_cycle = false
    vim.iter(chat.messages):each(function(message)
      local from_pin = message.opts and message.opts.reference == pin.id
      if from_pin and message.cycle == chat.cycle then
        added_this_cycle = true
      end
    end)

    if not added_this_cycle then
      util.fire("ChatPin", { bufnr = chat.bufnr, id = chat.id, pin_id = pin.id })

      local source = require(pin.source)
      local pin_args = { path = pin.path, bufnr = pin.bufnr, params = pin.params }
      source.new({ Chat = chat }):output(pin_args, { pin = true })
    end
  end
end
---Find the first message in the table that carries the given tool call id
---@param id string The tool call id to look for
---@param messages table
---@return table|nil
local function find_tool_call(id, messages)
  local found = nil
  for _, message in ipairs(messages) do
    local call_id = message.tool_call_id
    if call_id and call_id == id then
      found = message
      break
    end
  end
  return found
end
---Get the YAML settings key that encloses the node returned by `vim.treesitter.get_node`
---@param chat CodeCompanion.Chat
---@param opts? table Options passed to `vim.treesitter.get_node`; `lang` and `ignore_injections` are always overridden
---@return string|nil key_name The text of the pair's key
---@return TSNode|nil node The enclosing `block_mapping_pair` node
local function get_settings_key(chat, opts)
  local node_opts = vim.tbl_extend("force", opts or {}, {
    lang = "yaml",
    ignore_injections = false,
  })

  -- Climb until a key/value pair is reached
  local pair = vim.treesitter.get_node(node_opts)
  while pair and pair:type() ~= "block_mapping_pair" do
    pair = pair:parent()
  end

  if pair then
    local key_node = pair:named_child(0)
    local key_name = get_node_text(key_node, chat.bufnr)
    return key_name, pair
  end
end
---Check whether any message in the table has been given a specific tag
---@param tag string
---@param messages table
---@return boolean
local function has_tag(tag, messages)
  local function tag_of(msg)
    return msg.opts and msg.opts.tag
  end

  local tags = vim.tbl_map(tag_of, messages)
  return vim.tbl_contains(tags, tag)
end
---Check whether the chat holds at least one message with the user role
---@param chat CodeCompanion.Chat
---@return boolean
local function has_user_messages(chat)
  local function from_user(msg)
    return msg.role == config.constants.USER_ROLE
  end

  local user_messages = vim.iter(chat.messages):filter(from_user):totable()
  return #user_messages > 0
end
---Advance the chat's cycle counter by one
---@param chat CodeCompanion.Chat
---@return nil
local function increment_cycle(chat)
  local next_cycle = chat.cycle + 1
  chat.cycle = next_cycle
end
---Make an id by hashing a string or table
---@param val string|table
---@return number
local function make_id(val)
  local hasher = hash.hash
  return hasher(val)
end
---Set the editable text area. This allows us to scope the Tree-sitter queries to a specific area
---The header line becomes the buffer's line count plus the modifier
---@param chat CodeCompanion.Chat
---@param modifier? number Offset applied to the line count (defaults to 0)
---@return nil
local function set_text_editing_area(chat, modifier)
  local offset = modifier or 0
  local line_count = api.nvim_buf_line_count(chat.bufnr)
  chat.header_line = line_count + offset
end
---Prepare the chat buffer for the next round of conversation and reset the chat
---@param chat CodeCompanion.Chat
---@param opts? table Only `auto_submit` is read
---@return nil
local function ready_chat_buffer(chat, opts)
  opts = opts or {}

  if opts.auto_submit then
    -- If we're automatically responding to a tool output, we need to leave some
    -- space for the LLM's response so we can then display the user prompt again
    chat:add_buf_message({
      role = config.constants.LLM_ROLE,
      content = "\n\n",
      opts = { visible = true },
    })
  elseif chat.last_role ~= config.constants.USER_ROLE then
    -- Start a new cycle with an empty user message
    increment_cycle(chat)
    chat:add_buf_message({ role = config.constants.USER_ROLE, content = "" })

    set_text_editing_area(chat, -2)
    chat.ui:display_tokens(chat.parser, chat.header_line)
    chat.references:render()

    chat.subscribers:process(chat)
  end

  log:info("Chat request finished")
  chat:reset()
end
local _cached_settings = {}

---Parse the chat buffer for settings
---When settings are hidden from the chat buffer and an adapter is given, the
---adapter's schema defaults are used instead and cached per buffer
---@param bufnr integer
---@param parser vim.treesitter.LanguageTree
---@param adapter? CodeCompanion.Adapter
---@return table
local function ts_parse_settings(bufnr, parser, adapter)
  local cached = _cached_settings[bufnr]
  if cached then
    return cached
  end

  -- If the user has disabled settings in the chat buffer, use the default settings
  if adapter and not config.display.chat.show_settings then
    local defaults = adapter:make_from_schema()
    _cached_settings[bufnr] = defaults
    return defaults
  end

  local query = get_query("yaml", "chat")
  local root = parser:parse()[1]:root()

  -- Without an adapter the query runs to the end of the buffer
  local end_line = -1
  if adapter then
    -- One line per schema entry plus the two YAML lines, minus one because
    -- Tree-sitter is 0-indexed
    end_line = vim.tbl_count(adapter.schema) + 1
  end

  -- Only the first match is decoded
  local settings = {}
  for _, matches, _ in query:iter_matches(root, bufnr, 0, end_line) do
    local nodes = matches[1]
    local node = type(nodes) == "table" and nodes[1] or nodes
    local yaml_text = get_node_text(node, bufnr)
    settings = yaml.decode(yaml_text)
    break
  end

  if settings then
    return settings
  end

  log:error("Failed to parse settings in chat buffer")
  return {}
end
---Collect the user's content from the chat buffer, starting at a given line
---@param chat CodeCompanion.Chat
---@param start_range number The line to start from; one is subtracted before it is handed to Tree-sitter
---@return { content: string }|nil The trimmed content, or nil when there is none
local function ts_parse_messages(chat, start_range)
  local query = get_query("markdown", "chat")

  local first_row = start_range - 1
  local root = chat.parser:parse({ first_row, -1 })[1]:root()

  local parts = {}
  local current_role = nil

  for id, node in query:iter_captures(root, chat.bufnr, first_row, -1) do
    local capture = query.captures[id]
    if capture == "role" then
      current_role = helpers.format_role(get_node_text(node, chat.bufnr))
    elseif capture == "content" and current_role == user_role then
      -- Only content that follows a user header is collected
      table.insert(parts, get_node_text(node, chat.bufnr))
    end
  end

  -- If users send a blank message to the LLM, sometimes references are included
  parts = helpers.strip_references(parts)

  if vim.tbl_isempty(parts) then
    return nil
  end
  return { content = vim.trim(table.concat(parts, "\n\n")) }
end
---Locate the last user header in the chat buffer
---@param chat CodeCompanion.Chat
---@return number|nil The first value of the header node's `range()`; the remaining range values are returned as well
local function ts_parse_headers(chat)
  local query = get_query("markdown", "chat")
  local root = chat.parser:parse({ 0, -1 })[1]:root()

  local last_user_header = nil
  for id, node in query:iter_captures(root, chat.bufnr) do
    if
      query.captures[id] == "role_only"
      and helpers.format_role(get_node_text(node, chat.bufnr)) == user_role
    then
      last_user_header = node
    end
  end

  if last_user_header then
    return last_user_header:range()
  end
end
---Parse a section of the buffer for Markdown inline links.
---@param chat CodeCompanion.Chat The chat instance.
---@param start_range number The 1-indexed line number from where to start parsing.
---@return { text: string, path: string }[]|nil links The link text and destination of every inline link, or nil when there are none
local function ts_parse_images(chat, start_range)
  local ts_query = vim.treesitter.query.parse(
    "markdown_inline",
    [[
((inline_link) @link)
]]
  )

  local parser = vim.treesitter.get_parser(chat.bufnr, "markdown_inline")
  local root = parser:parse({ start_range, -1 })[1]:root()

  local links = {}
  for id, node in ts_query:iter_captures(root, chat.bufnr, start_range - 1, -1) do
    if ts_query.captures[id] == "link" then
      local label, destination

      for child in node:iter_children() do
        local child_type = child:type()
        if child_type == "link_text" then
          label = vim.treesitter.get_node_text(child, chat.bufnr)
        elseif child_type == "link_destination" then
          destination = vim.treesitter.get_node_text(child, chat.bufnr)
        end
      end

      -- A link is only kept when both of its parts were found
      if label and destination then
        table.insert(links, { text = label, path = destination })
      end
    end
  end

  if vim.tbl_isempty(links) then
    return nil
  end
  return links
end
---Parse the chat buffer for a code block
---returns the code block that the cursor is in or the last code block
---@param chat CodeCompanion.Chat
---@param cursor? table A `{ row, col }` position that is compared directly with `TSNode:range()` values
---@return TSNode | nil
local function ts_parse_codeblock(chat, cursor)
  local root = chat.parser:parse()[1]:root()
  local query = get_query("markdown", "chat")
  if query == nil then
    return nil
  end

  -- NOTE(review): `TSNode:range()` is 0-indexed, so callers presumably pass a
  -- 0-indexed row - confirm
  local row, col
  if cursor then
    row, col = cursor[1], cursor[2]
  end

  local last_match = nil
  for id, node in query:iter_captures(root, chat.bufnr, 0, -1) do
    if query.captures[id] == "code" then
      if cursor then
        local start_row, start_col, end_row, end_col = node:range()
        -- Columns only matter on the first and last row of the block; on the
        -- rows in between every column lies inside the block
        local after_start = row > start_row or (row == start_row and col >= start_col)
        local before_end = row < end_row or (row == end_row and col <= end_col)
        if after_start and before_end then
          return node
        end
      end
      last_match = node
    end
  end

  return last_match
end
---Used to record the last chat buffer that was opened | |
---@type CodeCompanion.Chat|nil | |
---@diagnostic disable-next-line: missing-fields | |
local last_chat = {} | |
---Set the autocmds for the chat buffer
---The settings-related autocmds are only created when settings are shown in the chat buffer
---@param chat CodeCompanion.Chat
---@return nil
local function set_autocmds(chat)
  local bufnr = chat.bufnr

  ---Remember this chat as the most recently entered one
  local function on_buf_enter()
    last_chat = chat
  end

  ---Execute a slash command once its completion item has been accepted
  local function on_complete_done()
    local item = vim.v.completed_item
    local data = item.user_data
    if not (data and data.type == "slash_command") then
      return
    end

    -- Clear the word from the buffer
    local row, col = unpack(api.nvim_win_get_cursor(0))
    api.nvim_buf_set_text(bufnr, row - 1, col - #item.word, row - 1, col, { "" })
    completion.slash_commands_execute(data, chat)
  end

  ---Show the schema description of the setting that the cursor is on
  local function on_cursor_moved()
    local key_name, node = get_settings_key(chat)
    if not (key_name and node) then
      vim.diagnostic.set(config.INFO_NS, chat.bufnr, {})
      return
    end

    local key_schema = chat.adapter.schema[key_name]
    if not (key_schema and key_schema.desc) then
      return
    end

    local lnum, col, end_lnum, end_col = node:range()
    vim.diagnostic.set(config.INFO_NS, chat.bufnr, {
      {
        lnum = lnum,
        col = col,
        end_lnum = end_lnum,
        end_col = end_col,
        severity = vim.diagnostic.severity.INFO,
        message = key_schema.desc,
      },
    })
  end

  ---Validate the settings and report every error on the line of its key
  local function on_insert_leave()
    local settings = ts_parse_settings(bufnr, chat.yaml_parser, chat.adapter)
    local errors = schema.validate(chat.adapter.schema, settings, chat.adapter)
    local node = settings.__ts_node

    local items = {}
    if errors and node then
      for child in node:iter_children() do
        assert(child:type() == "block_mapping_pair")
        local key = get_node_text(child:named_child(0), chat.bufnr)
        local message = errors[key]
        if message then
          local lnum, col, end_lnum, end_col = child:range()
          table.insert(items, {
            lnum = lnum,
            col = col,
            end_lnum = end_lnum,
            end_col = end_col,
            severity = vim.diagnostic.severity.ERROR,
            message = message,
          })
        end
      end
    end

    -- Setting an empty list clears any errors that were shown before
    vim.diagnostic.set(config.ERROR_NS, chat.bufnr, items)
  end

  api.nvim_create_autocmd("BufEnter", {
    group = chat.aug,
    buffer = bufnr,
    desc = "Log the most recent chat buffer",
    callback = on_buf_enter,
  })

  api.nvim_create_autocmd("CompleteDone", {
    group = chat.aug,
    buffer = bufnr,
    callback = on_complete_done,
  })

  if not config.display.chat.show_settings then
    return
  end

  api.nvim_create_autocmd("CursorMoved", {
    group = chat.aug,
    buffer = bufnr,
    desc = "Show settings information in the CodeCompanion chat buffer",
    callback = on_cursor_moved,
  })

  -- Validate the settings
  api.nvim_create_autocmd("InsertLeave", {
    group = chat.aug,
    buffer = bufnr,
    desc = "Parse the settings in the CodeCompanion chat buffer for any errors",
    callback = on_insert_leave,
  })
end
--============================================================================= | |
-- Public methods | |
--============================================================================= | |
---Methods that are available outside of CodeCompanion | |
---@type table<CodeCompanion.Chat> | |
local chatmap = {} | |
---@type table | |
_G.codecompanion_buffers = {} | |
---@class CodeCompanion.Chat | |
local Chat = {} | |
---Create a new chat: allocate its buffer, resolve the adapter, open the UI, | |
---set keymaps and autocmds, add the system prompt and attach default tools. | |
---If a Tree-sitter parser or the adapter cannot be resolved, the function | |
---returns early with whatever `log:error` returns instead of a chat object. | |
---@param args CodeCompanion.ChatArgs | |
---@return CodeCompanion.Chat | |
function Chat.new(args) | |
-- NOTE(review): the random ID is not checked against IDs of existing chats | |
local id = math.random(10000000) | |
log:trace("Chat created with ID %d", id) | |
local self = setmetatable({ | |
context = args.context, | |
cycle = 1, | |
header_line = 1, | |
from_prompt_library = args.from_prompt_library or false, | |
id = id, | |
last_role = args.last_role or config.constants.USER_ROLE, | |
messages = args.messages or {}, | |
opts = args, | |
refs = {}, | |
status = "", | |
create_buf = function() | |
local bufnr = api.nvim_create_buf(false, true) | |
api.nvim_buf_set_name(bufnr, string.format("[CodeCompanion] %d", id)) | |
vim.bo[bufnr].filetype = "codecompanion" | |
return bufnr | |
end, | |
_chat_has_reasoning = false, | |
_tool_output_header_printed = false, | |
_tool_output_has_llm_response = false, | |
}, { __index = Chat }) | |
self.bufnr = self.create_buf() | |
-- One augroup per chat buffer; `clear = false` leaves autocmds already in the group untouched | |
self.aug = api.nvim_create_augroup(CONSTANTS.AUTOCMD_GROUP .. ":" .. self.bufnr, { | |
clear = false, | |
}) | |
-- Assign the parsers to the chat object for performance | |
local ok, parser, yaml_parser | |
ok, parser = pcall(vim.treesitter.get_parser, self.bufnr, "markdown") | |
if not ok then | |
return log:error("Could not find the Markdown Tree-sitter parser") | |
end | |
self.parser = parser | |
-- The YAML parser is only needed when the settings are shown in the buffer | |
if config.display.chat.show_settings then | |
ok, yaml_parser = pcall(vim.treesitter.get_parser, self.bufnr, "yaml", { ignore_injections = false }) | |
if not ok then | |
return log:error("Could not find the Yaml Tree-sitter parser") | |
end | |
self.yaml_parser = yaml_parser | |
end | |
self.references = require("codecompanion.strategies.chat.references").new({ chat = self }) | |
self.subscribers = require("codecompanion.strategies.chat.subscribers").new() | |
self.agents = require("codecompanion.strategies.chat.agents").new({ bufnr = self.bufnr, messages = self.messages }) | |
self.tools = require("codecompanion.strategies.chat.tools").new({ chat = self }) | |
self.watchers = require("codecompanion.strategies.chat.watchers").new() | |
self.variables = require("codecompanion.strategies.chat.variables").new() | |
-- Register the chat in the global buffer list and in the module-level chat map | |
table.insert(_G.codecompanion_buffers, self.bufnr) | |
chatmap[self.bufnr] = { | |
name = "Chat " .. vim.tbl_count(chatmap) + 1, | |
description = CONSTANTS.BLANK_DESC, | |
strategy = "chat", | |
chat = self, | |
} | |
-- Use the passed adapter as-is when it is already resolved; otherwise resolve | |
-- it, falling back to the adapter configured for the chat strategy | |
if args.adapter and adapters.resolved(args.adapter) then | |
self.adapter = args.adapter | |
else | |
self.adapter = adapters.resolve(args.adapter or config.strategies.chat.adapter) | |
end | |
if not self.adapter then | |
return log:error("No adapter found") | |
end | |
util.fire("ChatAdapter", { | |
adapter = adapters.make_safe(self.adapter), | |
bufnr = self.bufnr, | |
id = self.id, | |
}) | |
util.fire("ChatModel", { bufnr = self.bufnr, id = self.id, model = self.adapter.schema.model.default }) | |
self:apply_settings(schema.get_default(self.adapter, args.settings)) | |
self.ui = require("codecompanion.strategies.chat.ui").new({ | |
adapter = self.adapter, | |
id = self.id, | |
bufnr = self.bufnr, | |
roles = { user = user_role, llm = llm_role }, | |
settings = self.settings, | |
}) | |
if args.messages then | |
self.messages = args.messages | |
end | |
-- Hide the previously visible chat (if any) before opening this one | |
self.close_last_chat() | |
self.ui:open():render(self.context, self.messages, args) | |
-- Set the header line for the chat buffer | |
if args.messages and vim.tbl_count(args.messages) > 0 then | |
---@cast self CodeCompanion.Chat | |
local header_line = ts_parse_headers(self) | |
self.header_line = header_line and (header_line + 1) or 1 | |
end | |
if vim.tbl_isempty(self.messages) then | |
self.ui:set_intro_msg() | |
end | |
if config.strategies.chat.keymaps then | |
keymaps | |
.new({ | |
bufnr = self.bufnr, | |
callbacks = require("codecompanion.strategies.chat.keymaps"), | |
data = self, | |
keymaps = config.strategies.chat.keymaps, | |
}) | |
:set() | |
end | |
local slash_command_keymaps = helpers.slash_command_keymaps(config.strategies.chat.slash_commands) | |
if vim.tbl_count(slash_command_keymaps) > 0 then | |
keymaps | |
.new({ | |
bufnr = self.bufnr, | |
callbacks = require("codecompanion.strategies.chat.slash_commands.keymaps"), | |
data = self, | |
keymaps = slash_command_keymaps, | |
}) | |
:set() | |
end | |
---@cast self CodeCompanion.Chat | |
self:add_system_prompt() | |
set_autocmds(self) | |
last_chat = self | |
-- Attach every entry of `default_tools`, which may name a single tool or a tool group | |
for _, tool_name in pairs(config.strategies.chat.tools.opts.default_tools or {}) do | |
local tool_config = config.strategies.chat.tools[tool_name] | |
if tool_config ~= nil then | |
self.tools:add(tool_name, tool_config) | |
elseif config.strategies.chat.tools.groups[tool_name] ~= nil then | |
self.tools:add_group(tool_name, config.strategies.chat.tools) | |
end | |
end | |
util.fire("ChatCreated", { bufnr = self.bufnr, from_prompt_library = self.from_prompt_library, id = self.id }) | |
-- Optionally send the chat to the LLM straight away | |
if args.auto_submit then | |
self:submit() | |
end | |
return self ---@type CodeCompanion.Chat | |
end | |
---Store the settings on the chat object, using the adapter's defaults when
---none are given
---@param settings? table
---@return nil
function Chat:apply_settings(settings)
  if not settings then
    settings = schema.get_default(self.adapter)
  end
  self.settings = settings

  -- When the settings are shown in the buffer there is nothing to cache
  if config.display.chat.show_settings then
    return
  end

  -- Otherwise remember them against the buffer number
  _cached_settings[self.bufnr] = self.settings
end
---Switch the chat to a different model
---@param model string
---@return self
function Chat:apply_model(model)
  -- Keep the cached settings (if any) in step with the new model
  local cached = _cached_settings[self.bufnr]
  if cached then
    cached.model = model
  end

  -- Make the model the adapter's default, then let the adapters module rebuild it
  local adapter = self.adapter
  adapter.schema.model.default = model
  self.adapter = adapters.set_model(adapter)

  return self
end
---The source to provide the model entries for completion (cmp only).
---The callback is always invoked exactly once; `items` is empty when the
---cursor is not on a settings key, the key is not part of the adapter's
---schema, or the key is not an enum with a list of choices.
---@param callback fun(request: table)
---@return nil
function Chat:complete_models(callback)
  local items = {}
  local cursor = api.nvim_win_get_cursor(0)
  local key_name, node = get_settings_key(self, { pos = { cursor[1] - 1, 1 } })

  -- The key under the cursor may not exist in the schema, so guard the lookup
  local key_schema = nil
  if key_name and node then
    key_schema = self.adapter.schema[key_name]
  end

  if key_schema and key_schema.type == "enum" then
    local choices = key_schema.choices
    if type(choices) == "function" then
      choices = choices(self.adapter)
    end

    -- Only touch cmp when there is at least one choice to offer
    if type(choices) == "table" and choices[1] ~= nil then
      local keyword_kind = require("cmp").lsp.CompletionItemKind.Keyword
      for _, choice in ipairs(choices) do
        table.insert(items, {
          label = choice,
          kind = keyword_kind,
        })
      end
    end
  end

  callback({ items = items, isIncomplete = false })
end
---Set the system prompt in the chat buffer.
---Nothing is added when the chat ignores system prompts, when a message with
---the same tag already exists, or when the resolved prompt is nil or empty.
---@param prompt? string|function A prompt, or a function receiving `{ adapter, language }` that returns one. Defaults to `config.opts.system_prompt`
---@param opts? table Message options. `tag` is used for de-duplication; `index` (optional) is the position to insert at
---@return CodeCompanion.Chat
function Chat:add_system_prompt(prompt, opts)
  if self.opts and self.opts.ignore_system_prompt then
    return self
  end

  opts = opts or { visible = false, tag = "from_config" }

  -- Don't add the same system prompt twice
  if has_tag(opts.tag, self.messages) then
    return self
  end

  -- Resolve the prompt first so that a function returning nil/"" is skipped too
  prompt = prompt or config.opts.system_prompt
  if type(prompt) == "function" then
    prompt = prompt({
      adapter = self.adapter,
      language = config.opts.language,
    })
  end
  if prompt == nil or prompt == "" then
    return self
  end

  -- Honour an explicit index; otherwise insert after the last system prompt
  local index
  if type(opts.index) == "number" then
    -- Clamp to a valid insertion position
    index = math.max(1, math.min(opts.index, #self.messages + 1))
  else
    for i = #self.messages, 1, -1 do
      if self.messages[i].role == config.constants.SYSTEM_ROLE then
        index = i + 1
        break
      end
    end
  end

  local system_prompt = {
    role = config.constants.SYSTEM_ROLE,
    content = prompt,
  }
  system_prompt.id = make_id(system_prompt)
  system_prompt.cycle = self.cycle
  system_prompt.opts = opts

  -- With no explicit index and no existing system prompt, go to the front
  table.insert(self.messages, index or 1, system_prompt)

  return self
end
---Toggle the system prompt in the chat buffer: remove the message tagged
---"from_config" if one exists, otherwise add the configured system prompt
---@return nil
function Chat:toggle_system_prompt()
  local has_system_prompt = false
  for _, msg in ipairs(self.messages) do
    -- Messages are not guaranteed to carry an `opts` table
    if msg.opts and msg.opts.tag == "from_config" then
      has_system_prompt = true
      break
    end
  end

  if has_system_prompt then
    self:remove_tagged_message("from_config")
    util.notify("Removed system prompt")
  else
    self:add_system_prompt()
    util.notify("Added system prompt")
  end
end
---Drop every message whose `opts.tag` equals the given tag
---@param tag string
---@return nil
function Chat:remove_tagged_message(tag)
  local kept = {}
  for _, msg in ipairs(self.messages) do
    local is_tagged = msg.opts and msg.opts.tag == tag
    if not is_tagged then
      kept[#kept + 1] = msg
    end
  end
  self.messages = kept
end
---Append (or insert) a message into the chat's message table
---@param data { role: string, content: string, tool_calls?: table }
---@param opts? table Options for the message; `visible` defaults to true and `index` selects the insert position
---@return CodeCompanion.Chat
function Chat:add_message(data, opts)
  if not opts then
    opts = { visible = true }
  elseif opts.visible == nil then
    opts.visible = true
  end

  -- The ID is derived from role/content/tool_calls only, so build those first
  local message = {
    role = data.role,
    content = data.content,
    tool_calls = data.tool_calls,
  }
  message.id = make_id(message)
  message.cycle = self.cycle
  message.opts = opts

  local position = opts.index
  if not position then
    table.insert(self.messages, message)
  else
    table.insert(self.messages, position, message)
  end

  return self
end
---Resolve the tools and variables a user has tagged in their message,
---rewriting `message.content` in place
---@param message table
---@return nil
function Chat:replace_vars_and_tools(message)
  local agents, variables = self.agents, self.variables

  -- Tools first ...
  local found_tools = agents:parse(self, message)
  if found_tools then
    message.content = agents:replace(message.content)
  end

  -- ... then variables, which are resolved against the context buffer
  local found_variables = variables:parse(self, message)
  if found_variables then
    message.content = variables:replace(message.content, self.context.bufnr)
  end
end
---Submit the chat buffer's contents to the LLM | |
---Recognised `opts` fields, as read by this function: | |
---  `callback`    function called before anything is submitted | |
---  `auto_submit` skip parsing the buffer and add a hidden user message instead | |
---  `regenerate`  do not add a new user message from the buffer | |
---@param opts? table | |
---@return nil | |
function Chat:submit(opts) | |
-- Only one request per chat may be in flight | |
if self.current_request then | |
return log:debug("Chat request already in progress") | |
end | |
opts = opts or {} | |
if opts.callback then | |
opts.callback() | |
end | |
local bufnr = self.bufnr | |
if opts.auto_submit then | |
-- Auto-submission: nothing is parsed from the buffer; a hidden user message | |
-- tells the LLM that tool output has been shared | |
self.watchers:check_for_changes(self) | |
self:add_message({ | |
role = config.constants.USER_ROLE, | |
content = "I've shared the output from the tool/function call with you.", | |
}, { visible = false }) | |
else | |
local message = ts_parse_messages(self, self.header_line) | |
if not message and not has_user_messages(self) then | |
return log:warn("No messages to submit") | |
end | |
self.watchers:check_for_changes(self) | |
-- Allow users to send a blank message to the LLM | |
if not opts.regenerate then | |
local chat_opts = config.strategies.chat.opts | |
if message and message.content and chat_opts and chat_opts.prompt_decorator then | |
message.content = chat_opts.prompt_decorator(message.content, adapters.make_safe(self.adapter), self.context) | |
end | |
self:add_message({ | |
role = config.constants.USER_ROLE, | |
content = (message and message.content or config.strategies.chat.opts.blank_prompt), | |
}) | |
end | |
-- NOTE: There are instances when submit is called with no user message. Such | |
-- as in the case of tools auto-submitting responses. References should be | |
-- excluded and we can do this by checking for user messages. | |
if message then | |
-- From here on `message` is the last entry of self.messages, not the parsed buffer text | |
message = self.references:clear(self.messages[#self.messages]) | |
self:replace_vars_and_tools(message) | |
self:check_images(message) | |
self:check_references() | |
add_pins(self) | |
end | |
-- Check if the user has manually overridden the adapter | |
if vim.g.codecompanion_adapter and self.adapter.name ~= vim.g.codecompanion_adapter then | |
self.adapter = adapters.resolve(config.adapters[vim.g.codecompanion_adapter]) | |
end | |
if not config.display.chat.auto_scroll then | |
vim.cmd("stopinsert") | |
end | |
self.ui:lock_buf() | |
set_text_editing_area(self, 2) -- this accounts for the LLM header | |
end | |
-- Re-read the settings, store them on the chat and map them onto the adapter | |
local settings = ts_parse_settings(bufnr, self.yaml_parser, self.adapter) | |
self:apply_settings(settings) | |
local mapped_settings = self.adapter:map_schema_to_params(settings) | |
local payload = { | |
messages = self.adapter:map_roles(vim.deepcopy(self.messages)), | |
tools = (not vim.tbl_isempty(self.tools.schemas) and { self.tools.schemas } or {}), | |
} | |
log:trace("Settings:\n%s", mapped_settings) | |
log:trace("Messages:\n%s", self.messages) | |
log:trace("Tools:\n%s", payload.tools) | |
log:info("Chat request started") | |
-- `output` collects the content chunks; `tools` is passed to the adapter's | |
-- chat_output handler and handed to Chat:done when the request finishes | |
local output = {} | |
local tools = {} | |
self.current_request = client.new({ adapter = mapped_settings }):request(payload, { | |
---@param err { message: string, stderr: string } | |
---@param data table | |
---@param adapter CodeCompanion.Adapter The modified adapter from the http client | |
callback = function(err, data, adapter) | |
-- An stderr of "{}" is not treated as an error | |
if err and err.stderr ~= "{}" then | |
self.status = CONSTANTS.STATUS_ERROR | |
log:error("Error: %s", err.stderr) | |
return self:done(output) | |
end | |
if data then | |
if adapter.features.tokens then | |
local tokens = self.adapter.handlers.tokens(adapter, data) | |
if tokens then | |
self.ui.tokens = tokens | |
end | |
end | |
local result = self.adapter.handlers.chat_output(adapter, data, tools) | |
if result and result.status then | |
self.status = result.status | |
if self.status == CONSTANTS.STATUS_SUCCESS then | |
-- Any role reported by the adapter is replaced with the LLM role constant | |
if result.output.role then | |
result.output.role = config.constants.LLM_ROLE | |
end | |
table.insert(output, result.output.content) | |
self:add_buf_message(result.output) | |
if result.output.content ~= "" and not self._tool_output_has_llm_response then | |
self._tool_output_has_llm_response = true | |
end | |
elseif self.status == CONSTANTS.STATUS_ERROR then | |
log:error("Error: %s", result.output) | |
return self:done(output) | |
end | |
end | |
end | |
end, | |
done = function() | |
self:done(output, tools) | |
end, | |
}, { bufnr = bufnr, strategy = "chat" }) | |
util.fire("ChatSubmitted", { bufnr = self.bufnr, id = self.id }) | |
end | |
---Called once every tool has finished; hands control back to the chat buffer
---@param opts? table
---@return nil
function Chat:tools_done(opts)
  return ready_chat_buffer(self, opts or {})
end
---Method to call after the response from the LLM is received.
---Stores the LLM's text output as a message, then either executes the
---requested tools or readies the chat buffer for the next user turn.
---@param output? table The output from the LLM (a list of content chunks)
---@param tools? table The tools from the LLM
---@return nil
function Chat:done(output, tools)
  self.current_request = nil

  -- Commonly, a status may not be set if the message exceeds a token limit
  if not self.status or self.status == "" then
    return self:reset()
  end

  local has_tools = tools ~= nil and not vim.tbl_isempty(tools)
  local has_output = output ~= nil and not vim.tbl_isempty(output)

  -- Handle LLM output text
  if has_output then
    local content = vim.trim(table.concat(output, ""))
    if content ~= "" then
      self:add_message({
        role = config.constants.LLM_ROLE,
        content = content,
      })
    end
  end

  if has_tools then
    tools = self.adapter.handlers.tools.format_tool_calls(self.adapter, tools)
    -- Message options are add_message's second argument; an `opts` key inside
    -- the data table is ignored, which left this message marked as visible
    self:add_message({
      role = config.constants.LLM_ROLE,
      tool_calls = tools,
    }, { visible = false })
    return self.agents:execute(self, tools)
  end

  ready_chat_buffer(self)
end
---Register a reference with the chat and store its content as a message
---(useful for users adding custom Slash Commands)
---@param data { role: string, content: string }
---@param source string
---@param id string
---@param opts? table Options for the message; defaults to a hidden message linked to `id`
function Chat:add_reference(data, source, id, opts)
  if not opts then
    opts = { reference = id, visible = false }
  end

  local reference = { source = source, id = id }
  self.references:add(reference)
  self:add_message(data, opts)
end
---Look for images in the chat buffer, attach the ones that can be encoded and
---swap their markdown links in the message for the word "image"
---@param message table
---@return nil
function Chat:check_images(message)
  local images = ts_parse_images(self, self.header_line)
  if not images then
    return
  end

  for _, image in ipairs(images) do
    local encoded = helpers.encode_image(image)

    -- A string result is the error text from the encoder
    if type(encoded) ~= "string" then
      helpers.add_image(self, encoded)

      -- Replace the image link in the message with "image"
      local link = string.format("[Image](%s)", image.path)
      local replaced = message.content:gsub(vim.pesc(link), "image")
      message.content = vim.trim(replaced)
    else
      log:warn("Could not encode image: %s", encoded)
    end
  end
end
---Reconcile the references table to the references in the chat buffer.
---References held on the chat object that no longer appear in the buffer are
---removed, together with their messages and tool schemas. Removing a group
---reference also removes the group's tools, unless a tool is still referenced
---in the buffer (directly or through another group).
---@return nil
function Chat:check_references()
  local refs_in_chat = self.references:get_from_chat()
  if vim.tbl_isempty(refs_in_chat) and vim.tbl_isempty(self.refs) then
    return
  end

  local group_pattern = "<group>(.*)</group>"

  ---Extract the group name from a reference ID, or nil if it isn't a group
  local function group_name_of(id)
    local name = id:match(group_pattern)
    if name and vim.trim(name) ~= "" then
      return name
    end
    return nil
  end

  ---List the tool reference IDs that belong to a group (by bare group name)
  local function expand_group_ref(group_name)
    local group_config = self.agents.tools_config.groups[group_name] or {}
    return vim.tbl_map(function(tool)
      return "<tool>" .. tool .. "</tool>"
    end, group_config.tools or {})
  end

  local groups_in_chat = {}
  for _, id in ipairs(refs_in_chat) do
    local group_name = group_name_of(id)
    if group_name then
      table.insert(groups_in_chat, group_name)
    end
  end

  -- Populate the refs_in_chat with tool refs from groups
  for _, group_name in ipairs(groups_in_chat) do
    vim.list_extend(refs_in_chat, expand_group_ref(group_name))
  end

  -- Fetch references that exist on the chat object but not in the buffer
  local to_remove = vim
    .iter(self.refs)
    :filter(function(ref)
      return not vim.tbl_contains(refs_in_chat, ref.id)
    end)
    :map(function(ref)
      return ref.id
    end)
    :totable()

  if vim.tbl_isempty(to_remove) then
    return
  end

  -- Extend to_remove with the tools of removed groups. The IDs in to_remove
  -- are of the form "<group>name</group>", so the bare name must be extracted
  -- before the group can be looked up.
  local group_tools = {}
  for _, id in ipairs(to_remove) do
    local group_name = group_name_of(id)
    if group_name then
      for _, tool_ref in ipairs(expand_group_ref(group_name)) do
        -- Keep tools that are still referenced in the buffer
        if not vim.tbl_contains(refs_in_chat, tool_ref) then
          table.insert(group_tools, tool_ref)
        end
      end
    end
  end
  vim.list_extend(to_remove, group_tools)

  -- Remove them from the messages table
  self.messages = vim
    .iter(self.messages)
    :filter(function(msg)
      if msg.opts and msg.opts.reference and vim.tbl_contains(to_remove, msg.opts.reference) then
        return false
      end
      return true
    end)
    :totable()

  -- And from the refs table
  self.refs = vim
    .iter(self.refs)
    :filter(function(ref)
      return not vim.tbl_contains(to_remove, ref.id)
    end)
    :totable()

  -- Clear any tool's schemas
  local schemas_to_keep = {}
  local tools_in_use_to_keep = {}
  for id, tool_schema in pairs(self.tools.schemas) do
    if not vim.tbl_contains(to_remove, id) then
      schemas_to_keep[id] = tool_schema
      local tool_name = id:match("<tool>(.*)</tool>")
      if tool_name and self.tools.in_use[tool_name] then
        tools_in_use_to_keep[tool_name] = true
      end
    else
      log:debug("Removing tool schema and usage flag for ID: %s", id) -- Optional logging
    end
  end
  self.tools.schemas = schemas_to_keep
  self.tools.in_use = tools_in_use_to_keep
end
---Regenerate the response from the LLM.
---Does nothing unless the most recent message came from the LLM (this
---includes the case where there are no messages at all).
---@return nil
function Chat:regenerate()
  local last_message = self.messages[#self.messages]
  if not last_message or last_message.role ~= config.constants.LLM_ROLE then
    return
  end

  -- Drop the LLM's last reply and ask again without adding a new user message
  table.remove(self.messages)
  self:add_buf_message({ role = config.constants.USER_ROLE, content = "_Regenerating response..._" })
  self:submit({ regenerate = true })
end
---Stop streaming the response from the LLM.
---Shuts down the running tool and/or request (errors from `shutdown` are
---ignored), stops the subscribers and schedules Chat:done.
---@return nil
function Chat:stop()
  self.status = CONSTANTS.STATUS_CANCELLING
  util.fire("ChatStopped", { bufnr = self.bufnr, id = self.id })

  if self.current_tool then
    local tool_job = self.current_tool
    self.current_tool = nil
    _G.codecompanion_cancel_tool = true
    pcall(function()
      tool_job:shutdown()
    end)
  end

  if self.current_request then
    local request_job = self.current_request
    self.current_request = nil
    pcall(function()
      request_job:shutdown()
    end)

    -- Guard the call so an adapter without an on_exit handler cannot raise here
    local handlers = self.adapter and self.adapter.handlers
    if handlers and handlers.on_exit then
      handlers.on_exit(self.adapter)
    end
  end

  self.subscribers:stop()

  vim.schedule(function()
    log:debug("Chat request cancelled")
    self:done()
  end)
end
---Close the current chat buffer.
---Stops any running request, unregisters the chat from the module's
---bookkeeping, deletes the buffer, clears its autocmds and fires the
---ChatClosed / ChatAdapter / ChatModel events.
---@return nil
function Chat:close()
  if self.current_request then
    self:stop()
  end

  if last_chat and last_chat.bufnr == self.bufnr then
    last_chat = nil
  end

  -- Remove this buffer from the global list. Only remove when it is actually
  -- found: table.remove with a nil position would drop the LAST entry instead.
  for i, bufnr in ipairs(_G.codecompanion_buffers) do
    if bufnr == self.bufnr then
      table.remove(_G.codecompanion_buffers, i)
      break
    end
  end
  chatmap[self.bufnr] = nil

  -- The buffer may already have been wiped by the user
  if api.nvim_buf_is_valid(self.bufnr) then
    api.nvim_buf_delete(self.bufnr, { force = true })
  end

  if self.aug then
    api.nvim_clear_autocmds({ group = self.aug })
  end
  if self.ui and self.ui.aug then
    api.nvim_clear_autocmds({ group = self.ui.aug })
  end

  util.fire("ChatClosed", { bufnr = self.bufnr, id = self.id })
  util.fire("ChatAdapter", { bufnr = self.bufnr, id = self.id, adapter = nil })
  util.fire("ChatModel", { bufnr = self.bufnr, id = self.id, model = nil })
end
---Add a message directly to the chat buffer. This will be visible to the user | |
---Fields read from `data`: `role`, `content`, `reasoning`. | |
---Fields read from `opts`: `tag` ("tool_output" selects the tool output layout), | |
---`insert_at` (line to insert at instead of the end) and `force_role` (always | |
---write a new role header). | |
---@param data table | |
---@param opts? table | |
function Chat:add_buf_message(data, opts) | |
assert(type(data) == "table", "data must be a table") | |
-- Lines accumulated by the helpers below and written in one go by update_buffer | |
local lines = {} | |
local bufnr = self.bufnr | |
local new_response = false | |
-- Split text on newlines (keeping empty lines) and queue each line | |
local function write(text) | |
for _, t in ipairs(vim.split(text, "\n", { plain = true, trimempty = false })) do | |
table.insert(lines, t) | |
end | |
end | |
-- Add a new header to the chat buffer | |
local function new_role() | |
new_response = true | |
self.last_role = data.role | |
table.insert(lines, "") | |
table.insert(lines, "") | |
self.ui:set_header(lines, config.strategies.chat.roles[data.role]) | |
end | |
-- Add data to the chat buffer | |
local function append_data() | |
-- Tool output | |
if opts and opts.tag == "tool_output" then | |
-- The "### Tool Output" heading is written once, until Chat:reset clears the flag | |
if not self._tool_output_header_printed then | |
self._tool_output_header_printed = true | |
if self._tool_output_has_llm_response then | |
table.insert(lines, "") | |
table.insert(lines, "") | |
end | |
table.insert(lines, "### Tool Output") | |
end | |
table.insert(lines, "") | |
return write(data.content or "") | |
end | |
-- Reasoning output | |
if data.reasoning then | |
if not self._chat_has_reasoning then | |
table.insert(lines, "### Reasoning") | |
table.insert(lines, "") | |
end | |
self._chat_has_reasoning = true | |
write(data.reasoning) | |
end | |
-- Regular output | |
if data.content then | |
if self._chat_has_reasoning then | |
self._chat_has_reasoning = false -- LLMs *should* do reasoning first then output after | |
table.insert(lines, "") | |
table.insert(lines, "") | |
table.insert(lines, "### Response") | |
table.insert(lines, "") | |
end | |
write(data.content) | |
end | |
end | |
-- Write the queued lines into the buffer and restore lock/scroll state | |
local function update_buffer() | |
self.ui:unlock_buf() | |
local last_line, last_column, line_count = self.ui:last() | |
if opts and opts.insert_at then | |
last_line = opts.insert_at | |
last_column = 0 | |
end | |
-- NOTE(review): despite its name, this is true when the cursor row equals `line_count` | |
local cursor_moved = api.nvim_win_get_cursor(0)[1] == line_count | |
api.nvim_buf_set_text(bufnr, last_line, last_column, last_line, last_column, lines) | |
if new_response then | |
self.ui:render_headers() | |
end | |
-- The buffer stays unlocked only while the last role written is the user's | |
if self.last_role ~= config.constants.USER_ROLE then | |
self.ui:lock_buf() | |
end | |
if config.display.chat.auto_scroll then | |
if cursor_moved and self.ui:is_active() then | |
self.ui:follow() | |
elseif not self.ui:is_active() then | |
self.ui:follow() | |
end | |
end | |
end | |
-- Handle a new role | |
if (data.role and data.role ~= self.last_role) or (opts and opts.force_role) then | |
new_role() | |
end | |
-- Append the output from the LLM | |
if data.content or data.reasoning then | |
append_data() | |
update_buffer() | |
end | |
end | |
---Record a tool's output in the message history and show it in the chat buffer
---@param tool table The Tool that was executed
---@param for_llm string The output to share with the LLM
---@param for_user? string The output to share with the user. If empty will use the LLM's output
---@return nil
function Chat:add_tool_output(tool, for_llm, for_user)
  local tool_call = tool.function_call
  log:debug("Tool output: %s", tool_call)

  -- Let the adapter shape the message, then stamp it with our own metadata
  local output = self.adapter.handlers.tools.output_response(self.adapter, tool_call, for_llm)
  output.cycle = self.cycle
  output.id = make_id({ role = output.role, content = output.content })
  output.opts = vim.tbl_extend("force", output.opts or {}, { tag = "tool_output", visible = true })

  -- Output for a tool call that already has a message is appended to it
  local existing = find_tool_call(tool_call.id, self.messages)
  if not existing then
    table.insert(self.messages, output)
  else
    existing.content = existing.content .. "\n\n" .. output.content
  end

  -- Allow tools to pass in an empty string to end the processing
  if for_user == "" then
    return
  end

  -- Update the contents of the chat buffer
  self:add_buf_message({
    role = config.constants.LLM_ROLE,
    content = for_user or for_llm,
  }, { tag = "tool_output" })
end
---Return the chat to its idle state once a request has finished
---@return nil
function Chat:reset()
  self.status = ""

  -- Per-response rendering flags
  local flags = {
    "_chat_has_reasoning",
    "_tool_output_header_printed",
    "_tool_output_has_llm_response",
  }
  for _, flag in ipairs(flags) do
    self[flag] = false
  end

  self.ui:unlock_buf()
end
---Get the code block under the cursor, or the last one in the chat buffer
---@return TSNode | nil
function Chat:get_codeblock()
  return ts_parse_codeblock(self, api.nvim_win_get_cursor(0))
end
---Wipe the chat's messages, references and tools and re-render an empty buffer
---@return nil
function Chat:clear()
  self.cycle, self.header_line = 1, 1
  self.messages, self.refs = {}, {}
  self.tools:clear()

  log:trace("Clearing chat buffer")

  local ui = self.ui:render(self.context, self.messages, self.opts)
  ui:set_intro_msg()

  self:add_system_prompt()
  util.fire("ChatCleared", { bufnr = self.bufnr, id = self.id })
end
---Return the chat buffer's parsed settings and its messages.
---Returns nothing when the chat has no messages.
function Chat:debug()
  if not vim.tbl_isempty(self.messages) then
    local settings = ts_parse_settings(self.bufnr, self.yaml_parser, self.adapter)
    return settings, self.messages
  end
end
---Returns the chat object(s) based on the buffer number.
---With no buffer number, returns a list of every registered chat entry
---(tables with `name`, `description`, `strategy` and `chat` fields).
---A buffer number of 0 means the current buffer.
---@param bufnr? integer
---@return CodeCompanion.Chat|table|nil chat nil when the buffer has no chat
function Chat.buf_get_chat(bufnr)
  if not bufnr then
    return vim
      .iter(pairs(chatmap))
      :map(function(_, v)
        return v
      end)
      :totable()
  end

  if bufnr == 0 then
    bufnr = api.nvim_get_current_buf()
  end

  -- Not every buffer is a chat buffer, so don't index a missing entry
  local entry = chatmap[bufnr]
  if not entry then
    return nil
  end
  return entry.chat
end
---Return the most recently created chat that hasn't been closed, if any
---@return CodeCompanion.Chat|nil
function Chat.last_chat()
  if last_chat and not vim.tbl_isempty(last_chat) then
    return last_chat
  end
  return nil
end
---Hide the last chat's window if it is currently visible
---@return nil
function Chat.close_last_chat()
  if not last_chat or vim.tbl_isempty(last_chat) then
    return
  end

  if last_chat.ui:is_visible() then
    last_chat.ui:hide()
  end
end
return Chat | |
--- | |
File: /lua/codecompanion/strategies/chat/references.lua | |
--- | |
--[[ | |
Handle the references that are shared with the chat buffer from sources such as | |
Slash Commands or variables. References are displayed back to the user via | |
the chat buffer, using block quotes and lists. | |
--]] | |
local config = require("codecompanion.config") | |
local helpers = require("codecompanion.strategies.chat.helpers") | |
local api = vim.api | |
-- The configured user role; ts_parse_buffer compares the parsed heading against it | |
local user_role = config.strategies.chat.roles.user | |
local icons_path = config.display.chat.icons | |
-- Icons placed in front of pinned / watched reference IDs when they are written to the buffer | |
local icons = { | |
pinned = icons_path.pinned_buffer, | |
watched = icons_path.watched_buffer, | |
} | |
-- NOTE(review): not used in the part of the file visible here; presumably the | |
-- reference ID prefixes that may be pinned - confirm against the rest of the module | |
local allowed_pins = { | |
"<buf>", | |
"<file>", | |
} | |
-- NOTE(review): likewise, presumably the reference ID prefixes that may be watched - confirm | |
local allowed_watchers = { | |
"<buf>", | |
} | |
---Work out where in the chat buffer a reference should be written.
---Returns the rows of the existing reference block when the chat has refs and
---a block is present; otherwise the rows of the last heading if it belongs to
---the user; otherwise nil.
---@param chat CodeCompanion.Chat
---@return table|nil
local function ts_parse_buffer(chat)
  local query = vim.treesitter.query.get("markdown", "reference")
  local first_row = chat.header_line - 1
  local root = chat.parser:parse({ first_row, -1 })[1]:root()

  ---Find the last node captured under the given name
  local function last_capture(name)
    local found
    for id, node in query:iter_captures(root, chat.bufnr, first_row, -1) do
      if query.captures[id] == name then
        found = node
      end
    end
    return found
  end

  -- Check if there are any references already in the chat buffer
  local refs = last_capture("refs")
  if refs and not vim.tbl_isempty(chat.refs) then
    local start_row, _, end_row = refs:range()
    return {
      capture = "refs",
      start_row = start_row + 2,
      end_row = end_row + 1,
    }
  end

  -- If not, check if there is a heading to add the references below
  local role_node = last_capture("role")
  local role
  if role_node then
    role = vim.treesitter.get_node_text(role_node, chat.bufnr)
  end
  role = helpers.format_role(role)

  if not (role_node and role == user_role) then
    return nil
  end

  local start_row, _, end_row = role_node:range()
  return {
    capture = "role",
    start_row = start_row + 1,
    end_row = end_row + 1,
  }
end
---Write a single visible reference into the chat buffer at the given row
---@param chat CodeCompanion.Chat
---@param ref CodeCompanion.Chat.Ref
---@param row integer
local function add(chat, ref, row)
  local opts = ref.opts
  if not opts.visible then
    return
  end

  -- Pinned takes precedence over watched; plain references get no icon
  local icon = ""
  if opts.pinned then
    icon = icons.pinned
  elseif opts.watched then
    icon = icons.watched
  end
  local entry = string.format("> - %s%s", icon, ref.id)

  -- The first reference also opens the block with a title and a trailing blank line
  local lines
  if vim.tbl_count(chat.refs) == 1 then
    lines = { "> Context:", entry, "" }
  else
    lines = { entry }
  end

  -- Temporarily unlock the buffer if needed, restoring the lock afterwards
  local was_locked = not vim.bo[chat.bufnr].modifiable
  if was_locked then
    chat.ui:unlock_buf()
  end
  api.nvim_buf_set_lines(chat.bufnr, row, row, false, lines)
  if was_locked then
    chat.ui:lock_buf()
  end
end
---@class CodeCompanion.Chat.References
---@field Chat CodeCompanion.Chat
local References = {}

---@class CodeCompanion.Chat.RefsArgs
---@field chat CodeCompanion.Chat

---Create a References object bound to the given chat
---@param args CodeCompanion.Chat.RefsArgs
function References.new(args)
  return setmetatable({ Chat = args.chat }, { __index = References })
end
---Track a reference on the chat and write it into the chat buffer.
---Does nothing (and returns self) when no reference is given or references
---are not shown.
---@param ref CodeCompanion.Chat.Ref
---@return nil
function References:add(ref)
  if not ref or not config.display.chat.show_references then
    return self
  end

  local chat = self.Chat

  -- Ensure the option fields exist with defaults
  ref.opts = ref.opts or {}
  local opts = ref.opts
  opts.pinned = opts.pinned or false
  opts.watched = opts.watched or false
  if opts.visible == nil then
    opts.visible = config.display.chat.show_references
  end

  table.insert(chat.refs, ref)

  -- If it's a buffer reference and it's being watched, start watching
  if ref.bufnr and opts.watched then
    chat.watchers:watch(ref.bufnr)
  end

  local parsed_buffer = ts_parse_buffer(chat)
  if not parsed_buffer then
    return
  end

  -- An existing reference block is appended to; otherwise a new block is
  -- started below the heading
  local row_offsets = { refs = -1, role = 1 }
  local offset = row_offsets[parsed_buffer.capture]
  if offset then
    add(chat, ref, parsed_buffer.end_row + offset)
  end
end
---Clear any references from a message in the chat buffer to remove unnecessary
---context before it's sent to the LLM.
---The message is returned untouched when it is nil, has no string content,
---the chat has no references, or references are not shown.
---@param message? table
---@return table|nil
function References:clear(message)
  -- Nothing to parse without string content
  if not message or type(message.content) ~= "string" then
    return message or nil
  end
  if vim.tbl_isempty(self.Chat.refs) or not config.display.chat.show_references then
    return message
  end

  local parser = vim.treesitter.get_string_parser(message.content, "markdown")
  local query = vim.treesitter.query.get("markdown", "reference")
  local root = parser:parse()[1]:root()

  -- Keep the last node captured as the reference block
  local refs = nil
  for id, node in query:iter_captures(root, message.content) do
    if query.captures[id] == "refs" then
      refs = node
    end
  end

  if refs then
    local start_row, _, end_row, _ = refs:range()
    local content_lines = vim.split(message.content, "\n")
    -- NOTE(review): rows from TSNode:range() are 0-based while content_lines is
    -- 1-based; the original index range is kept as-is - confirm it is intended
    for i = start_row, end_row do
      content_lines[i] = ""
    end
    message.content = vim.trim(table.concat(content_lines, "\n"))
  end

  return message
end
---Write the "> Context:" block listing the chat's visible references into the
---chat buffer, one row below the chat's header line
---@return CodeCompanion.Chat.References|nil self Only returned when the chat has no references
function References:render()
  local chat = self.Chat
  if vim.tbl_isempty(chat.refs) then
    return self
  end

  local start_row = chat.header_line + 1
  local lines = { "> Context:" }

  for _, ref in pairs(chat.refs) do
    local opts = ref and ref.opts
    local hidden = not ref or (opts and opts.visible == false)
    if not hidden then
      -- Pinned takes precedence over watched when choosing the icon
      local prefix = ""
      if opts and opts.pinned then
        prefix = icons.pinned
      elseif opts and opts.watched then
        prefix = icons.watched
      end
      table.insert(lines, string.format("> - %s%s", prefix, ref.id))
    end
  end

  -- Only the heading is present, so no reference was visible
  if #lines == 1 then
    return
  end

  table.insert(lines, "")
  return api.nvim_buf_set_lines(chat.bufnr, start_row, start_row, false, lines)
end
---Build a reference ID from a buffer's name
---@param bufnr number
---@return string
function References:make_id_from_buf(bufnr)
  -- The ":." modifier reduces the name to a path relative to the current directory when possible
  return vim.fn.fnamemodify(api.nvim_buf_get_name(bufnr), ":.")
end
---Check whether a reference matches any of the allowed pin patterns
---@param ref string
---@return boolean
function References:can_be_pinned(ref)
  local pinnable = false
  for _, pattern in ipairs(allowed_pins) do
    if ref:find(pattern) then
      pinnable = true
      break
    end
  end
  return pinnable
end
---Check whether a reference matches any of the allowed watcher patterns
---@param ref string
---@return boolean
function References:can_be_watched(ref)
  local watchable = false
  for _, pattern in ipairs(allowed_watchers) do
    if ref:find(pattern) then
      watchable = true
      break
    end
  end
  return watchable
end
---Get the references from the chat buffer
---Only "ref" captures that follow a user role heading are collected. The
---list prefix and any pinned/watched icon are stripped from each entry.
---@return table refs List of trimmed reference ids
function References:get_from_chat()
  local chat = self.Chat
  local query = vim.treesitter.query.get("markdown", "reference")
  local root = chat.parser:parse()[1]:root()

  -- Icons are written into the buffer as literal text, so escape them before
  -- using them as Lua patterns (an icon may contain magic characters)
  local icon_patterns = {}
  for _, icon in pairs(icons) do
    icon_patterns[#icon_patterns + 1] = vim.pesc(icon)
  end

  local refs = {}
  local role = nil

  for id, node in query:iter_captures(root, chat.bufnr, chat.header_line - 1, -1) do
    local capture = query.captures[id]
    if capture == "role" then
      role = helpers.format_role(vim.treesitter.get_node_text(node, chat.bufnr))
    elseif role == user_role and capture == "ref" then
      local ref = vim.treesitter.get_node_text(node, chat.bufnr)
      -- Drop the blockquote list prefix, then both pinned and watched icons
      ref = ref:gsub("^> %- ", "")
      for _, pattern in ipairs(icon_patterns) do
        ref = ref:gsub(pattern, "")
      end
      table.insert(refs, vim.trim(ref))
    end
  end

  return refs
end
return References | |
--- | |
File: /lua/codecompanion/strategies/chat/subscribers.lua | |
--- | |
--[[ | |
Subscribers are functions that are set from outside the chat buffer that can be | |
executed at the end of every response. This is used in workflows, allowing | |
for consecutive prompts to be sent and even automatically submitted. | |
]] | |
local config = require("codecompanion.config") | |
local log = require("codecompanion.utils.log") | |
---@class CodeCompanion.Subscribers
---@field queue table Events that have been subscribed
---@field stopped boolean Set once automatic submission should no longer happen
local Subscribers = {}

---Create a subscriber list with an empty queue
---@param args CodeCompanion.SubscribersArgs Not read by the constructor
---@return CodeCompanion.Subscribers
function Subscribers.new(args)
  local instance = { queue = {}, stopped = false }
  return setmetatable(instance, { __index = Subscribers })
end
---Link a subscriber to the chat buffer
---The event is given a random numeric `id` that is unique within this queue,
---which `unsubscribe` later uses to find it again.
---@param event CodeCompanion.Chat.Event
---@return nil
function Subscribers:subscribe(event)
  local id
  repeat
    id = math.random(10000000)
    -- Redraw on the unlikely chance that the id is already taken
    local taken = false
    for _, queued in ipairs(self.queue) do
      if queued.id == id then
        taken = true
        break
      end
    end
  until not taken

  event.id = id
  table.insert(self.queue, event)
end
---Unsubscribe an object from a chat buffer
---Every queued entry whose `id` equals `event.id` is removed.
---@param event CodeCompanion.Chat.Event
---@return nil
function Subscribers:unsubscribe(event)
  local name = event.data and event.data.name or ""
  -- Walk backwards so table.remove never shifts an entry that is still to be visited
  for i = #self.queue, 1, -1 do
    if self.queue[i].id == event.id then
      log:debug("[Subscription] Unsubscribing %s (%s)", name, event.id)
      table.remove(self.queue, i)
    end
  end
end
---Report whether at least one subscriber is queued
---@return boolean
function Subscribers:has_subscribers()
  local pending = #self.queue
  return pending > 0
end
---Run a subscriber's callback against the chat
---Events with a `reuse` function are re-run while it returns a truthy value and
---are unsubscribed otherwise. Events whose `data.type` is "once" are
---unsubscribed right after their callback has run.
---@param chat CodeCompanion.Chat
---@param event CodeCompanion.Chat.Event
---@return nil
function Subscribers:action(chat, event)
  local name = event.data and event.data.name or ""

  if type(event.reuse) ~= "function" then
    log:debug("[Subscription] Actioning: %s (%s)", name, event.id)
    event.callback(chat)
    if event.data and event.data.type == "once" then
      return self:unsubscribe(event)
    end
    return
  end

  if event.reuse(chat) then
    log:debug("[Subscription] Reusing %s (%s)", name, event.id)
    return event.callback(chat)
  end

  self:unsubscribe(event)
  -- Don't auto submit a reuse prompt if it's the last one
  if vim.tbl_isempty(self.queue) then
    self.stopped = true
  end
end
---Process the subscribers in the queue
---A subscriber is actioned and then offered for auto-submission when it has no
---`order` or its `order` is below `chat.cycle`, and its optional
---`data.condition(chat)` returns a truthy value.
---@param chat CodeCompanion.Chat
---@return nil
function Subscribers:process(chat)
  if not self:has_subscribers() then
    return
  end

  vim.iter(self.queue):each(function(subscriber)
    -- `data` is optional, so guard every access to it
    local data = subscriber.data
    local name = data and data.name or ""

    if not subscriber.order or subscriber.order < chat.cycle then
      if data and type(data.condition) == "function" and not data.condition(chat) then
        return log:debug("[Subscription] Condition not met: %s (%s)", name, subscriber.id)
      end
      self:action(chat, subscriber)
      self:submit(chat, subscriber)
    end
  end)
end
---Submit the chat buffer on behalf of a subscriber that opted in to auto-submit
---Nothing happens when the subscriber has no `data.opts.auto_submit`, when the
---subscribers have been stopped, or when the chat already has a request in flight.
---@param chat CodeCompanion.Chat
---@param subscriber CodeCompanion.Chat.Event
function Subscribers:submit(chat, subscriber)
  local data = subscriber.data
  local auto_submit = data and data.opts and data.opts.auto_submit
  if not auto_submit or self.stopped then
    return
  end

  if chat.current_request ~= nil then
    return log:debug("[Subscription] Prevent double submit")
  end

  -- Defer the call to prevent rate limit bans
  local function auto_submit_chat()
    log:debug("[Subscription] Auto-submitting")
    chat:submit()
  end
  vim.defer_fn(auto_submit_chat, config.opts.submit_delay)
end
---Flag the subscribers as stopped so that `submit` no longer auto-submits
---@return CodeCompanion.Subscribers self
function Subscribers:stop()
  log:debug("[Subscription] Stopping")

  self.stopped = true

  return self
end
return Subscribers | |
--- | |
File: /lua/codecompanion/strategies/chat/tools.lua | |
--- | |
--[[ | |
Methods for handling interactions between the chat buffer and tools | |
--]] | |
---@class CodeCompanion.Chat.Tools | |
---@field chat CodeCompanion.Chat | |
---@field flags table Flags that external functions can update and subscribers can interact with | |
---@field in_use table<string, boolean> Tools that are in use on the chat buffer | |
---@field schemas table<string, table> The config for the tools in use | |
---@class CodeCompanion.Chat.Tools | |
local Tools = {} | |
local config = require("codecompanion.config") | |
local util = require("codecompanion.utils") | |
---@class CodeCompanion.Chat.ToolsArgs
---@field chat CodeCompanion.Chat

---Create a tools manager for a chat with no flags, tools or schemas registered
---@param args CodeCompanion.Chat.ToolsArgs
---@return CodeCompanion.Chat.Tools
function Tools.new(args)
  local instance = {
    chat = args.chat,
    flags = {},
    in_use = {},
    schemas = {},
  }
  return setmetatable(instance, { __index = Tools })
end
---Register a tool reference with the chat's references
---@param chat CodeCompanion.Chat The chat buffer
---@param id string The id of the tool
---@param opts? table Optional parameters for the reference
---@return nil
local function add_reference(chat, id, opts)
  local ref = { source = "tool", name = "tool", id = id, opts = opts }
  chat.references:add(ref)
end
---Add the tool's system prompt to the chat buffer
---The prompt may be a string or a function that receives the tool's schema.
---No message is added when the tool has no prompt, the prompt is of another
---type, or the function returns nil.
---@param chat CodeCompanion.Chat The chat buffer
---@param tool table the resolved tool
---@param id string The id of the tool
---@return nil
local function add_system_prompt(chat, tool, id)
  if not (tool and tool.system_prompt) then
    return
  end

  local system_prompt
  if type(tool.system_prompt) == "function" then
    system_prompt = tool.system_prompt(tool.schema)
  elseif type(tool.system_prompt) == "string" then
    system_prompt = tool.system_prompt
  end

  -- Never add a system message that has no content
  if system_prompt == nil then
    return
  end

  chat:add_message(
    { role = config.constants.SYSTEM_ROLE, content = system_prompt },
    { visible = false, tag = "tool", reference = id }
  )
end
---Store the tool's schema on the tools object, keyed by the tool's id
---@param self CodeCompanion.Chat.Tools The tools object
---@param tool table The resolved tool
---@param id string The id of the tool
---@return nil
local function add_schema(self, tool, id)
  local schemas = self.schemas
  schemas[id] = tool.schema
end
---Add a tool to the chat buffer: its reference, system prompt and schema
---Does nothing when the tool cannot be resolved or is already in use.
---@param tool string The name of the tool
---@param tool_config table The tool from the config
---@param opts? table Optional parameters (defaults to `{ visible = true }`)
---@return CodeCompanion.Chat.Tools|nil self Returned only when the tool was added
function Tools:add(tool, tool_config, opts)
  if not opts then
    opts = { visible = true }
  end

  local resolved = self.chat.agents.resolve(tool_config)
  if self.in_use[tool] or not resolved then
    return
  end

  local id = "<tool>" .. tool .. "</tool>"
  local chat = self.chat

  add_reference(chat, id, opts)
  add_system_prompt(chat, resolved, id)
  add_schema(self, resolved, id)

  util.fire("ChatToolAdded", { bufnr = chat.bufnr, id = chat.id, tool = tool })
  self.in_use[tool] = true

  return self
end
---Add every tool of a configured group to the chat buffer
---The group's system prompt, when present, is added as a hidden system
---message. With `collapse_tools` (the default) a single group reference is
---added and the individual tools are added as not visible.
---@param group string The name of the group
---@param tools_config table The tools configuration
---@return nil
function Tools:add_group(group, tools_config)
  local group_config = tools_config.groups[group]
  if not (group_config and group_config.tools) then
    return
  end

  local defaults = { collapse_tools = true }
  local opts = vim.tbl_deep_extend("force", defaults, group_config.opts or {})
  local collapse = opts.collapse_tools
  local group_id = "<group>" .. group .. "</group>"

  -- The prompt may be given directly or produced from the group's config
  local prompt = group_config.system_prompt
  if type(prompt) == "function" then
    prompt = prompt(group_config)
  end

  if prompt then
    local message = { role = config.constants.SYSTEM_ROLE, content = prompt }
    self.chat:add_message(message, { tag = "tool", visible = false, reference = group_id })
  end

  if collapse then
    add_reference(self.chat, group_id)
  end

  local tool_opts_visible = not collapse
  for _, tool in ipairs(group_config.tools) do
    self:add(tool, tools_config[tool], { visible = tool_opts_visible })
  end
end
---Report whether at least one tool is marked as in use
---@return boolean
function Tools:loaded()
  local none_in_use = vim.tbl_isempty(self.in_use)
  return not none_in_use
end
---Reset the flags, in-use tools and schemas to fresh empty tables
---@return nil
function Tools:clear()
  for _, field in ipairs({ "flags", "in_use", "schemas" }) do
    self[field] = {}
  end
end
return Tools | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment