Skip to content

Commit

Permalink
Add parallel action execution for ToolkitTask (griptape-ai#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov authored Mar 14, 2024
1 parent 661bf45 commit f80df06
Show file tree
Hide file tree
Showing 38 changed files with 772 additions and 646 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added
- Every subtask in `ToolkitTask` can now execute multiple actions in parallel.
- Added `BaseActionSubtaskEvent.subtask_actions`.
- Support for `text-embedding-3-small` and `text-embedding-3-large` models.

### Fixed
- Improved system prompt in `ToolTask` to support more use cases.

### Changed
- **BREAKING**: `ActionSubtask` was renamed to `ActionsSubtask`.
- **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`.
- Default embedding model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`.
- Default embedding model of `OpenAiStructureConfig` to `text-embedding-3-small`.
- `BaseTextLoader` to accept a `BaseChunker`.
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ And here is the output:
Thought: First, I need to use the WebScraper action to
load the content of the webpage.
Action: {"name": "WebScraper", "path":
Actions: {"actions": [{"name": "WebScraper", "path":
"get_content", "input": {"values": {"url":
"https://griptape.ai"}}}
"https://griptape.ai"}}}]}
INFO Subtask f2cd3cfecaeb4001a0d3eccad32c2d07
Response: Output of "WebScraper.get_content" was
stored in memory with memory_name "TaskMemory" and
Expand All @@ -76,10 +76,10 @@ And here is the output:
Thought: Now that the webpage content is stored in
memory, I need to use the TaskMemoryClient action
to summarize the content.
Action: {"name": "TaskMemoryClient", "path":
Actions: {"actions": [{"name": "TaskMemoryClient", "path":
"summarize", "input": {"values": {"memory_name":
"TaskMemory", "artifact_namespace":
"c497d83c1d134db694b9994596016320"}}}
"c497d83c1d134db694b9994596016320"}}}]}
[11/02/23 15:29:06] INFO Subtask 0096dac0f0524636be197e06a37f8aa0
Response: Output of "TaskMemoryClient.summarize"
was stored in memory with memory_name "TaskMemory"
Expand All @@ -89,12 +89,12 @@ And here is the output:
Thought: Now that the summary is stored in memory,
I need to use the FileManager action to save the
summary to a file named griptape.txt.
Action: {"name": "FileManager", "path":
Actions: {"actions": [{"name": "FileManager", "path":
"save_memory_artifacts_to_disk", "input":
{"values": {"dir_name": ".", "file_name":
"griptape.txt", "memory_name": "TaskMemory",
"artifact_namespace":
"77584322d33d40e992da9767d02a9018"}}}
"77584322d33d40e992da9767d02a9018"}}}]}
INFO Subtask 7cc3d96500ce4efdac085c07c7370822
Response: saved successfully
[11/02/23 15:29:30] INFO ToolkitTask 72b89a905be84245a0563b206795ac73
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
self._extract_ratelimit_metadata(result)

parsed_result = result.parse()

if len(parsed_result.choices) == 1:
return TextArtifact(value=parsed_result.choices[0].message.content.strip())
else:
Expand Down
12 changes: 6 additions & 6 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from .base_task_event import BaseTaskEvent
from .start_task_event import StartTaskEvent
from .finish_task_event import FinishTaskEvent
from .base_action_subtask_event import BaseActionSubtaskEvent
from .start_action_subtask_event import StartActionSubtaskEvent
from .finish_action_subtask_event import FinishActionSubtaskEvent
from .base_actions_subtask_event import BaseActionsSubtaskEvent
from .start_actions_subtask_event import StartActionsSubtaskEvent
from .finish_actions_subtask_event import FinishActionsSubtaskEvent
from .base_prompt_event import BasePromptEvent
from .start_prompt_event import StartPromptEvent
from .finish_prompt_event import FinishPromptEvent
Expand All @@ -23,9 +23,9 @@
"BaseTaskEvent",
"StartTaskEvent",
"FinishTaskEvent",
"BaseActionSubtaskEvent",
"StartActionSubtaskEvent",
"FinishActionSubtaskEvent",
"BaseActionsSubtaskEvent",
"StartActionsSubtaskEvent",
"FinishActionsSubtaskEvent",
"BasePromptEvent",
"StartPromptEvent",
"FinishPromptEvent",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,19 @@
from typing import TYPE_CHECKING, Optional
from .base_task_event import BaseTaskEvent


if TYPE_CHECKING:
from griptape.tasks import BaseTask, ActionSubtask
from griptape.tasks import BaseTask, ActionsSubtask


@define
class BaseActionSubtaskEvent(BaseTaskEvent, ABC):
class BaseActionsSubtaskEvent(BaseTaskEvent, ABC):
subtask_parent_task_id: Optional[str] = field(kw_only=True, metadata={"serializable": True})
subtask_thought: Optional[str] = field(kw_only=True, metadata={"serializable": True})
subtask_action_name: Optional[str] = field(kw_only=True, metadata={"serializable": True})
subtask_action_path: Optional[str] = field(kw_only=True, metadata={"serializable": True})
subtask_action_input: Optional[dict] = field(kw_only=True, metadata={"serializable": True})
subtask_actions: Optional[list[dict]] = field(kw_only=True, metadata={"serializable": True})

@classmethod
def from_task(cls, task: BaseTask) -> BaseActionSubtaskEvent:
if not isinstance(task, ActionSubtask):
def from_task(cls, task: BaseTask) -> BaseActionsSubtaskEvent:
if not isinstance(task, ActionsSubtask):
raise ValueError("Event must be of instance ActionSubtask.")
return cls(
task_id=task.id,
Expand All @@ -29,7 +26,5 @@ def from_task(cls, task: BaseTask) -> BaseActionSubtaskEvent:
task_output=task.output,
subtask_parent_task_id=task.parent_task_id,
subtask_thought=task.thought,
subtask_action_name=task.action_name,
subtask_action_path=task.action_path,
subtask_action_input=task.action_input,
subtask_actions=task.actions_to_dicts(),
)
8 changes: 0 additions & 8 deletions griptape/events/finish_action_subtask_event.py

This file was deleted.

8 changes: 8 additions & 0 deletions griptape/events/finish_actions_subtask_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations
from attrs import define
from .base_actions_subtask_event import BaseActionsSubtaskEvent


@define
class FinishActionsSubtaskEvent(BaseActionsSubtaskEvent):
...
8 changes: 0 additions & 8 deletions griptape/events/start_action_subtask_event.py

This file was deleted.

8 changes: 8 additions & 0 deletions griptape/events/start_actions_subtask_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations
from attrs import define
from .base_actions_subtask_event import BaseActionsSubtaskEvent


@define
class StartActionsSubtaskEvent(BaseActionsSubtaskEvent):
...
4 changes: 2 additions & 2 deletions griptape/memory/meta/action_subtask_meta_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class ActionSubtaskMetaEntry(BaseMetaEntry):
Attributes:
thought: CoT thought string from the LLM.
action: ReAct action JSON string from the LLM.
actions: ReAct actions JSON string from the LLM.
answer: tool-generated and memory-processed response from Griptape.
"""

type: str = field(default=BaseMetaEntry.__name__, kw_only=True, metadata={"serializable": False})
thought: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
action: str = field(kw_only=True, metadata={"serializable": True})
actions: str = field(kw_only=True, metadata={"serializable": True})
answer: str = field(kw_only=True, metadata={"serializable": True})
10 changes: 6 additions & 4 deletions griptape/memory/task/task_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if TYPE_CHECKING:
from griptape.memory.task.storage import BaseArtifactStorage
from griptape.tasks import ActionSubtask
from griptape.tasks import ActionsSubtask


@define
Expand Down Expand Up @@ -39,7 +39,7 @@ def get_storage_for(self, artifact: BaseArtifact) -> Optional[BaseArtifactStorag
return find_storage(artifact)

def process_output(
self, tool_activity: Callable, subtask: ActionSubtask, output_artifact: BaseArtifact
self, tool_activity: Callable, subtask: ActionsSubtask, output_artifact: BaseArtifact
) -> BaseArtifact:
from griptape.utils import J2

Expand All @@ -53,7 +53,7 @@ def process_output(
if result:
return result
else:
self.namespace_metadata[namespace] = subtask.action_to_json()
self.namespace_metadata[namespace] = subtask.actions_to_json()

output = J2("memory/tool.j2").render(
memory_name=self.name,
Expand All @@ -64,7 +64,9 @@ def process_output(

if subtask.structure and subtask.structure.meta_memory:
subtask.structure.meta_memory.add_entry(
ActionSubtaskMetaEntry(thought=subtask.thought, action=subtask.action_to_json(), answer=output)
ActionSubtaskMetaEntry(
thought=subtask.thought, actions=subtask.actions_to_json(), answer=output
)
)

return InfoArtifact(output, name=namespace)
Expand Down
4 changes: 2 additions & 2 deletions griptape/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from .activity_mixin import ActivityMixin
from .exponential_backoff_mixin import ExponentialBackoffMixin
from .action_subtask_origin_mixin import ActionSubtaskOriginMixin
from .actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
from .rule_mixin import RuleMixin
from .serializable_mixin import SerializableMixin
from .image_artifact_file_output_mixin import ImageArtifactFileOutputMixin

__all__ = [
"ActivityMixin",
"ExponentialBackoffMixin",
"ActionSubtaskOriginMixin",
"ActionsSubtaskOriginMixin",
"RuleMixin",
"ImageArtifactFileOutputMixin",
"SerializableMixin",
Expand Down
28 changes: 0 additions & 28 deletions griptape/mixins/action_subtask_origin_mixin.py

This file was deleted.

53 changes: 53 additions & 0 deletions griptape/mixins/actions_subtask_origin_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import abstractmethod
from attr import define
from schema import Schema, Literal

if TYPE_CHECKING:
from griptape.memory import TaskMemory
from griptape.tools import BaseTool
from griptape.tasks import ActionsSubtask


@define(slots=False)
class ActionsSubtaskOriginMixin:
@abstractmethod
def find_tool(self, tool_name: str) -> BaseTool:
...

@abstractmethod
def find_memory(self, memory_name: str) -> TaskMemory:
...

@abstractmethod
def find_subtask(self, subtask_id: str) -> ActionsSubtask:
...

@abstractmethod
def add_subtask(self, subtask: ActionsSubtask) -> ActionsSubtask:
...

@abstractmethod
def actions_schema(self) -> dict:
...

def _actions_schema_for_tools(self, tools: list[BaseTool]) -> dict:
action_schemas = []

for tool in tools:
for activity_schema in tool.activity_schemas():
action_schema = activity_schema.schema
output_label_key = Literal(
"output_label", description="Action label that can later be used to identify action output"
)

action_schema[output_label_key] = str

action_schemas.append(action_schema)

actions_schema = Schema(
description="JSON schema for an array of actions to be executed in parallel.", schema=action_schemas
)

return actions_schema.json_schema("Actions Schema")
4 changes: 2 additions & 2 deletions griptape/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .base_task import BaseTask
from .base_text_input_task import BaseTextInputTask
from .prompt_task import PromptTask
from .action_subtask import ActionSubtask
from .actions_subtask import ActionsSubtask
from .toolkit_task import ToolkitTask
from .text_summary_task import TextSummaryTask
from .tool_task import ToolTask
Expand All @@ -21,7 +21,7 @@
"BaseTask",
"BaseTextInputTask",
"PromptTask",
"ActionSubtask",
"ActionsSubtask",
"ToolkitTask",
"TextSummaryTask",
"ToolTask",
Expand Down
Loading

0 comments on commit f80df06

Please sign in to comment.