forked from shycatj5/Auto-GPT
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'command_registry' of https://github.com/kreneskyp/Auto-GPT
- Loading branch information
Showing
18 changed files
with
439 additions
and
5 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import os | ||
import sys | ||
import importlib | ||
import inspect | ||
from typing import Callable, Any, List | ||
|
||
# Unique identifier for auto-gpt commands | ||
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command" | ||
|
||
class Command: | ||
"""A class representing a command. | ||
Attributes: | ||
name (str): The name of the command. | ||
description (str): A brief description of what the command does. | ||
signature (str): The signature of the function that the command executes. Defaults to None. | ||
""" | ||
|
||
def __init__(self, name: str, description: str, method: Callable[..., Any], signature: str = None): | ||
self.name = name | ||
self.description = description | ||
self.method = method | ||
self.signature = signature if signature else str(inspect.signature(self.method)) | ||
|
||
def __call__(self, *args, **kwargs) -> Any: | ||
return self.method(*args, **kwargs) | ||
|
||
def __str__(self) -> str: | ||
return f"{self.name}: {self.description}, args: {self.signature}" | ||
|
||
class CommandRegistry: | ||
""" | ||
The CommandRegistry class is a manager for a collection of Command objects. | ||
It allows the registration, modification, and retrieval of Command objects, | ||
as well as the scanning and loading of command plugins from a specified | ||
directory. | ||
""" | ||
|
||
def __init__(self): | ||
self.commands = {} | ||
|
||
def _import_module(self, module_name: str) -> Any: | ||
return importlib.import_module(module_name) | ||
|
||
def _reload_module(self, module: Any) -> Any: | ||
return importlib.reload(module) | ||
|
||
def register(self, cmd: Command) -> None: | ||
self.commands[cmd.name] = cmd | ||
|
||
def unregister(self, command_name: str): | ||
if command_name in self.commands: | ||
del self.commands[command_name] | ||
else: | ||
raise KeyError(f"Command '{command_name}' not found in registry.") | ||
|
||
def reload_commands(self) -> None: | ||
"""Reloads all loaded command plugins.""" | ||
for cmd_name in self.commands: | ||
cmd = self.commands[cmd_name] | ||
module = self._import_module(cmd.__module__) | ||
reloaded_module = self._reload_module(module) | ||
if hasattr(reloaded_module, "register"): | ||
reloaded_module.register(self) | ||
|
||
def get_command(self, name: str) -> Callable[..., Any]: | ||
return self.commands[name] | ||
|
||
def call(self, command_name: str, **kwargs) -> Any: | ||
if command_name not in self.commands: | ||
raise KeyError(f"Command '{command_name}' not found in registry.") | ||
command = self.commands[command_name] | ||
return command(**kwargs) | ||
|
||
def command_prompt(self) -> str: | ||
""" | ||
Returns a string representation of all registered `Command` objects for use in a prompt | ||
""" | ||
commands_list = [f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())] | ||
return "\n".join(commands_list) | ||
|
||
def import_commands(self, module_name: str) -> None: | ||
""" | ||
Imports the specified Python module containing command plugins. | ||
This method imports the associated module and registers any functions or | ||
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute | ||
as `Command` objects. The registered `Command` objects are then added to the | ||
`commands` dictionary of the `CommandRegistry` object. | ||
Args: | ||
module_name (str): The name of the module to import for command plugins. | ||
""" | ||
|
||
module = importlib.import_module(module_name) | ||
|
||
for attr_name in dir(module): | ||
attr = getattr(module, attr_name) | ||
# Register decorated functions | ||
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(attr, AUTO_GPT_COMMAND_IDENTIFIER): | ||
self.register(attr.command) | ||
# Register command classes | ||
elif inspect.isclass(attr) and issubclass(attr, Command) and attr != Command: | ||
cmd_instance = attr() | ||
self.register(cmd_instance) | ||
|
||
def command(name: str, description: str, signature: str = None) -> Callable[..., Any]: | ||
"""The command decorator is used to create Command objects from ordinary functions.""" | ||
def decorator(func: Callable[..., Any]) -> Command: | ||
cmd = Command(name=name, description=description, method=func, signature=signature) | ||
|
||
def wrapper(*args, **kwargs) -> Any: | ||
return func(*args, **kwargs) | ||
|
||
wrapper.command = cmd | ||
|
||
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True) | ||
return wrapper | ||
|
||
return decorator | ||
|
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from auto_gpt.commands import Command, command | ||
|
||
|
||
@command('function_based', 'Function-based test command') | ||
def function_based(arg1: int, arg2: str) -> str: | ||
return f'{arg1} - {arg2}' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import shutil | ||
import sys | ||
from pathlib import Path | ||
|
||
import pytest | ||
from auto_gpt.commands import Command, CommandRegistry | ||
|
||
|
||
class TestCommand: | ||
@staticmethod | ||
def example_function(arg1: int, arg2: str) -> str: | ||
return f"{arg1} - {arg2}" | ||
|
||
def test_command_creation(self): | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
assert cmd.name == "example" | ||
assert cmd.description == "Example command" | ||
assert cmd.method == self.example_function | ||
assert cmd.signature == "(arg1: int, arg2: str) -> str" | ||
|
||
def test_command_call(self): | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
result = cmd(arg1=1, arg2="test") | ||
assert result == "1 - test" | ||
|
||
def test_command_call_with_invalid_arguments(self): | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
with pytest.raises(TypeError): | ||
cmd(arg1="invalid", does_not_exist="test") | ||
|
||
def test_command_default_signature(self): | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
assert cmd.signature == "(arg1: int, arg2: str) -> str" | ||
|
||
def test_command_custom_signature(self): | ||
custom_signature = "custom_arg1: int, custom_arg2: str" | ||
cmd = Command(name="example", description="Example command", method=self.example_function, signature=custom_signature) | ||
|
||
assert cmd.signature == custom_signature | ||
|
||
|
||
|
||
class TestCommandRegistry: | ||
@staticmethod | ||
def example_function(arg1: int, arg2: str) -> str: | ||
return f"{arg1} - {arg2}" | ||
|
||
def test_register_command(self): | ||
"""Test that a command can be registered to the registry.""" | ||
registry = CommandRegistry() | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
registry.register(cmd) | ||
|
||
assert cmd.name in registry.commands | ||
assert registry.commands[cmd.name] == cmd | ||
|
||
def test_unregister_command(self): | ||
"""Test that a command can be unregistered from the registry.""" | ||
registry = CommandRegistry() | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
registry.register(cmd) | ||
registry.unregister(cmd.name) | ||
|
||
assert cmd.name not in registry.commands | ||
|
||
def test_get_command(self): | ||
"""Test that a command can be retrieved from the registry.""" | ||
registry = CommandRegistry() | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
registry.register(cmd) | ||
retrieved_cmd = registry.get_command(cmd.name) | ||
|
||
assert retrieved_cmd == cmd | ||
|
||
def test_get_nonexistent_command(self): | ||
"""Test that attempting to get a nonexistent command raises a KeyError.""" | ||
registry = CommandRegistry() | ||
|
||
with pytest.raises(KeyError): | ||
registry.get_command("nonexistent_command") | ||
|
||
def test_call_command(self): | ||
"""Test that a command can be called through the registry.""" | ||
registry = CommandRegistry() | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
registry.register(cmd) | ||
result = registry.call("example", arg1=1, arg2="test") | ||
|
||
assert result == "1 - test" | ||
|
||
def test_call_nonexistent_command(self): | ||
"""Test that attempting to call a nonexistent command raises a KeyError.""" | ||
registry = CommandRegistry() | ||
|
||
with pytest.raises(KeyError): | ||
registry.call("nonexistent_command", arg1=1, arg2="test") | ||
|
||
def test_get_command_prompt(self): | ||
"""Test that the command prompt is correctly formatted.""" | ||
registry = CommandRegistry() | ||
cmd = Command(name="example", description="Example command", method=self.example_function) | ||
|
||
registry.register(cmd) | ||
command_prompt = registry.command_prompt() | ||
|
||
assert f"(arg1: int, arg2: str)" in command_prompt | ||
|
||
def test_import_mock_commands_module(self): | ||
"""Test that the registry can import a module with mock command plugins.""" | ||
registry = CommandRegistry() | ||
mock_commands_module = "auto_gpt.tests.mocks.mock_commands" | ||
|
||
registry.import_commands(mock_commands_module) | ||
|
||
assert "function_based" in registry.commands | ||
assert registry.commands["function_based"].name == "function_based" | ||
assert registry.commands["function_based"].description == "Function-based test command" | ||
|
||
def test_import_temp_command_file_module(self, tmp_path): | ||
"""Test that the registry can import a command plugins module from a temp file.""" | ||
registry = CommandRegistry() | ||
|
||
# Create a temp command file | ||
src = Path("/app/auto_gpt/tests/mocks/mock_commands.py") | ||
temp_commands_file = tmp_path / "mock_commands.py" | ||
shutil.copyfile(src, temp_commands_file) | ||
|
||
# Add the temp directory to sys.path to make the module importable | ||
sys.path.append(str(tmp_path)) | ||
|
||
temp_commands_module = "mock_commands" | ||
registry.import_commands(temp_commands_module) | ||
|
||
# Remove the temp directory from sys.path | ||
sys.path.remove(str(tmp_path)) | ||
|
||
assert "function_based" in registry.commands | ||
assert registry.commands["function_based"].name == "function_based" | ||
assert registry.commands["function_based"].description == "Function-based test command" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.