forked from run-llama/llama-lab
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial baby-agi clone w/ llama index
- Loading branch information
1 parent
99995d2
commit 2ce0192
Showing
13 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from langchain.chains import LLMChain | ||
from langchain.llms import OpenAI | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.prompts import PromptTemplate | ||
|
||
from agi.task_prompts import LC_EXECUTION_PROMPT | ||
|
||
|
||
class SimpleExecutionAgent: | ||
def __init__(self, llm=None, model_name="text-davinci-003"): | ||
if llm: | ||
self._llm = llm | ||
elif model_name == "text-davinci-003": | ||
self._llm = OpenAI(temperature=0, model_name=model_name, max_tokens=512) | ||
else: | ||
self._llm = ChatOpenAI(temperature=0, model_name=model_name, max_tokens=512 ) | ||
|
||
self._prompt_template = PromptTemplate( | ||
template=LC_EXECUTION_PROMPT, | ||
input_variables=["task", "objective", "completed_tasks_summary"] | ||
) | ||
self._execution_chain = LLMChain(llm=self._llm, prompt=self._prompt_template) | ||
|
||
def execute_task(self, objective, task, completed_tasks_summary): | ||
result = self._execution_chain.predict( | ||
objective=objective, | ||
task=task, | ||
completed_tasks_summary=completed_tasks_summary | ||
) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import re | ||
import json | ||
from typing import List, Tuple | ||
from llama_index import Document | ||
from llama_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt | ||
|
||
from agi.utils import initialize_task_list_index | ||
from agi.task_prompts import ( | ||
DEFAULT_TASK_PRIORITIZE_TMPL, | ||
DEFAULT_REFINE_TASK_PRIORITIZE_TMPL, | ||
DEFAULT_TASK_CREATE_TMPL, | ||
DEFAULT_REFINE_TASK_CREATE_TMPL, | ||
NO_COMPLETED_TASKS_SUMMARY | ||
) | ||
|
||
|
||
class TaskManager: | ||
def __init__(self, tasks: List[str]) -> None: | ||
self.current_tasks = [Document(x) for x in tasks] | ||
self.completed_tasks = [] | ||
self.current_tasks_index = initialize_task_list_index(self.current_tasks, index_path="current_tasks_index.json") | ||
self.completed_tasks_index = initialize_task_list_index(self.completed_tasks, index_path="completed_tasks_index.json") | ||
|
||
def _get_task_create_templates(self, prev_task: str, prev_result: str) -> Tuple[str, str]: | ||
text_qa_template = DEFAULT_TASK_CREATE_TMPL.format( | ||
prev_result=prev_result, prev_task=prev_task, query_str='{query_str}', context_str="{context_str}" | ||
) | ||
text_qa_template = QuestionAnswerPrompt(text_qa_template) | ||
|
||
refine_template = DEFAULT_REFINE_TASK_CREATE_TMPL.format( | ||
prev_result=prev_result, prev_task=prev_task, query_str='{query_str}', context_msg="{context_msg}", existing_answer='{existing_answer}' | ||
) | ||
refine_template = RefinePrompt(refine_template) | ||
|
||
return (text_qa_template, refine_template) | ||
|
||
def _get_task_prioritize_templates(self) -> Tuple[str, str]: | ||
return ( | ||
QuestionAnswerPrompt(DEFAULT_TASK_PRIORITIZE_TMPL), | ||
RefinePrompt(DEFAULT_REFINE_TASK_PRIORITIZE_TMPL) | ||
) | ||
|
||
def get_completed_tasks_summary(self) -> str: | ||
if len(self.completed_tasks) == 0: | ||
return NO_COMPLETED_TASKS_SUMMARY | ||
summary = self.completed_tasks_index.query("Summarize the current completed tasks", response_mode="tree_summarize") | ||
return str(summary) | ||
|
||
def prioritize_tasks(self, objective: str) -> None: | ||
(text_qa_template, refine_template) = self._get_task_prioritize_templates() | ||
prioritized_tasks = self.current_tasks_index.query( | ||
objective, | ||
text_qa_template=text_qa_template, | ||
refine_template=refine_template | ||
) | ||
|
||
new_tasks = [] | ||
for task in str(prioritized_tasks).split("\n"): | ||
task = re.sub(r"^[0-9]+\.", "", task).strip() | ||
if len(task) > 10: | ||
new_tasks.append(task) | ||
self.current_tasks = [Document(x) for x in new_tasks] | ||
self.current_tasks_index = initialize_task_list_index(self.current_tasks) | ||
|
||
def generate_new_tasks(self, objective: str, prev_task: str, prev_result: str) -> None: | ||
(text_qa_template, refine_template) = self._get_task_create_templates(prev_task, prev_result) | ||
new_tasks = self.completed_tasks_index.query(objective, text_qa_template=text_qa_template, refine_template=refine_template) | ||
try: | ||
new_tasks = json.loads(str(new_tasks)) | ||
new_tasks = [x.strip() for x in new_tasks if len(x.strip()) > 10] | ||
except: | ||
new_tasks = str(new_tasks).split('\n') | ||
new_tasks = [re.sub(r"^[0-9]+\.", "", x).strip() for x in str(new_tasks) if len(x.strip()) > 10 and x[0].isnumeric()] | ||
self.add_new_tasks(new_tasks) | ||
|
||
def get_next_task(self) -> str: | ||
next_task = self.current_tasks.pop().get_text() | ||
self.current_tasks_index = initialize_task_list_index(self.current_tasks) | ||
return next_task | ||
|
||
def add_new_tasks(self, tasks: List[str]) -> None: | ||
for task in tasks: | ||
if task not in self.current_tasks and task not in self.completed_tasks: | ||
self.current_tasks.append(Document(task)) | ||
self.current_tasks_index = initialize_task_list_index(self.current_tasks) | ||
|
||
def add_completed_task(self, task: str, result: str) -> None: | ||
document = Document(f"Task: {task}\nResult: {result}\n") | ||
self.completed_tasks.append(document) | ||
self.completed_tasks_index = initialize_task_list_index(self.completed_tasks) |
File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model | ||
from langchain.prompts.chat import ( | ||
AIMessagePromptTemplate, | ||
ChatPromptTemplate, | ||
HumanMessagePromptTemplate, | ||
SystemMessagePromptTemplate | ||
) | ||
|
||
############################################# | ||
##### AGI Prefix ##### | ||
############################################# | ||
PREFIX = ( | ||
"You are an autonomous artificial intelligence, capable of planning and executing tasks to achieve an objective.\n" | ||
"When given an objective, you can plan and execute any number tasks that will help achieve your original objective.\n" | ||
) | ||
|
||
|
||
############################################# | ||
##### Initial Completed Tasks Summary ##### | ||
############################################# | ||
NO_COMPLETED_TASKS_SUMMARY = "You haven't completed any tasks yet." | ||
|
||
|
||
############################################# | ||
##### Langchain - Execution Agent (Unused Currently) ##### | ||
############################################# | ||
LC_PREFIX = PREFIX + "You have access to the following tools:" | ||
|
||
LC_FORMAT_INSTRUCTIONS = """Use the following format: | ||
Task: the current task you must complete | ||
Thought: you should always think about what to do | ||
Action: the action to take, should be one of [{tool_names}] | ||
Action Input: the input to the action | ||
Observation: the result of the action | ||
... (this Thought/Action/Action Input/Observation can repeat N times) | ||
Thought: I have now completed the task | ||
Final Answer: the final answer to the original input task""" | ||
|
||
LC_SUFFIX = ( | ||
"This is your current objective: {objective}\n" | ||
"Take into account what you have already achieved: {completed_tasks_summary}\n" | ||
"Using your current objective, your previously completed tasks, and your available tools," | ||
"Complete the current task.\n" | ||
"Begin!\n" | ||
"Task: {input}\n" | ||
"Thought: {agent_scratchpad}" | ||
) | ||
|
||
|
||
############################################# | ||
##### Langchain - Execution Chain ##### | ||
############################################# | ||
LC_EXECUTION_PROMPT = ( | ||
"You are an AI who performs one task based on the following objective: {objective}\n." | ||
"Take into account this summary of previously completed tasks: {completed_tasks_summary}\n." | ||
"Your task: {task}\n" | ||
"Response: " | ||
) | ||
|
||
|
||
############################################# | ||
##### LlamaIndex -- Task Creation ##### | ||
############################################# | ||
DEFAULT_TASK_CREATE_TMPL = ( | ||
f"{PREFIX}" | ||
"Your current objective is as follows: {query_str}\n" | ||
"Most recently, you completed the task '{prev_task}', which had the result of '{prev_result}'. " | ||
"A description of your current incomplete tasks are below: \n" | ||
"---------------------\n" | ||
"{context_str}" | ||
"\n---------------------\n" | ||
"Given the current objective, the current incomplete tasks, and the latest completed task, " | ||
"create new tasks to be completed that do not overlap with incomplete tasks. " | ||
"Return the tasks as an array." | ||
) | ||
#TASK_CREATE_PROMPT = QuestionAnswerPrompt(DEFAULT_TASK_CREATE_TMPL) | ||
|
||
DEFAULT_REFINE_TASK_CREATE_TMPL = ( | ||
f"{PREFIX}" | ||
"Your current objective is as follows: {query_str}\n" | ||
"Most recently, you completed the task '{prev_task}', which had the result of '{prev_result}'. " | ||
"A description of your current incomplete tasks are below: \n" | ||
"---------------------\n" | ||
"{context_msg}" | ||
"\n---------------------\n" | ||
"Currently, you have created the following new tasks: {existing_answer}" | ||
"Given the current objective, the current incomplete tasks, list of newly created tasks, and the latest completed task, " | ||
"add new tasks to be completed that do not overlap with incomplete tasks. " | ||
"Return the tasks as an array. If you have no more tasks to add, repeat the existing list of new tasks." | ||
) | ||
#REFINE_TASK_CREATE_PROMPT = RefinePrompt(DEFAULT_REFINE_TASK_CREATE_TMPL) | ||
|
||
|
||
############################################# | ||
##### LlamaIndex -- Task Prioritization ##### | ||
############################################# | ||
DEFAULT_TASK_PRIORITIZE_TMPL = ( | ||
f"{PREFIX}" | ||
"Your current objective is as follows: {query_str}\n" | ||
"A list of your current incomplete tasks are below: \n" | ||
"---------------------\n" | ||
"{context_str}" | ||
"\n---------------------\n" | ||
"Given the current objective, prioritize the current list of tasks. " | ||
"Do not remove or add any tasks. Return the results as a numbered list, like:\n" | ||
"#. First task\n" | ||
"#. Second task\n" | ||
"... continue until all tasks are prioritized. " | ||
"Start the task list with number 1." | ||
) | ||
|
||
DEFAULT_REFINE_TASK_PRIORITIZE_TMPL = ( | ||
f"{PREFIX}" | ||
"Your current objective is as follows: {query_str}\n" | ||
"A list of additional incomplete tasks are below: \n" | ||
"---------------------\n" | ||
"{context_msg}" | ||
"\n---------------------\n" | ||
"Currently, you also have the following list of prioritized tasks: {existing_answer}" | ||
"Given the current objective and existing list, prioritize the current list of tasks. " | ||
"Do not remove or add any tasks. Return the results as a numbered list, like:\n" | ||
"#. First task\n" | ||
"#. Second task\n" | ||
"... continue until all tasks are prioritized. " | ||
"Start the task list with number 1." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os | ||
from llama_index import ( | ||
GPTSimpleVectorIndex, | ||
GPTListIndex, | ||
Document, | ||
LLMPredictor, | ||
ServiceContext | ||
) | ||
|
||
|
||
def initialize_task_list_index( | ||
documents, | ||
llm_predictor=None, | ||
embed_model=None, | ||
prompt_helper=None, | ||
index_path="./index.json", | ||
chunk_size_limit=2000 | ||
): | ||
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, | ||
embed_model=embed_model, | ||
prompt_helper=prompt_helper, | ||
chunk_size_limit=chunk_size_limit) | ||
if os.path.exists(index_path): | ||
return GPTListIndex.load_from_disk(index_path, service_context=service_context) | ||
else: | ||
return GPTListIndex.from_documents(documents, service_context=service_context) | ||
|
||
|
||
def log_current_status(cur_task, result, completed_tasks_summary, task_list): | ||
status_string = f""" | ||
================================== | ||
Completed Tasks Summary: {completed_tasks_summary.strip()} | ||
Current Task: {cur_task.strip()} | ||
Result: {result.strip()} | ||
Task List: {", ".join([x.get_text().strip() for x in task_list])} | ||
================================== | ||
""" | ||
print(status_string, flush=True) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
import json | ||
import time | ||
|
||
from agi.ExecutionAgent import SimpleExecutionAgent | ||
from agi.TaskManager import TaskManager | ||
from agi.utils import log_current_status | ||
|
||
|
||
def run_llama_agi(objective, initial_task, sleep_time): | ||
task_manager = TaskManager([initial_task]) | ||
execution_agent = SimpleExecutionAgent() | ||
|
||
# get initial list of tasks | ||
initial_completed_tasks_summary = task_manager.get_completed_tasks_summary() | ||
initial_task_prompt = initial_task + "\nReturn the list as an array." | ||
initial_task_list = execution_agent.execute_task( | ||
objective, initial_task_prompt, initial_completed_tasks_summary | ||
) | ||
initial_task_list = json.loads(initial_task_list) | ||
|
||
# add tasks to the task manager | ||
task_manager.add_new_tasks(initial_task_list) | ||
|
||
# prioritize initial tasks | ||
task_manager.prioritize_tasks(objective) | ||
|
||
while True: | ||
# Get the next task | ||
cur_task = task_manager.get_next_task() | ||
|
||
# Summarize completed tasks | ||
completed_tasks_summary = task_manager.get_completed_tasks_summary() | ||
|
||
# Execute current task | ||
result = execution_agent.execute_task(objective, cur_task, completed_tasks_summary) | ||
|
||
# store the task and result as completed | ||
task_manager.add_completed_task(cur_task, result) | ||
|
||
# generate new task(s), if needed | ||
task_manager.generate_new_tasks(objective, cur_task, result) | ||
|
||
# log state of AGI to terminal | ||
log_current_status(cur_task, result, completed_tasks_summary, task_manager.current_tasks) | ||
|
||
# Quit the loop? | ||
if len(task_manager.current_tasks) == 0: | ||
print("Out of tasks! Objective Accomplished?") | ||
break | ||
|
||
# wait a bit to let you read what's happening | ||
time.sleep(sleep_time) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(prog="Llama AGI", description="A baby-agi/auto-gpt inspired application, powered by Llama Index!") | ||
parser.add_argument("-it", "--initial-task", default="Create a list of tasks", help="The initial task for the system to carry out.") | ||
parser.add_argument("-o", "--objective", default="Solve world hunger", help="The overall objective for the system.") | ||
parser.add_argument('--sleep', default=2, help="Sleep time (in seconds) between each task loop.") | ||
args = parser.parse_args() | ||
|
||
run_llama_agi(args.objective, args.initial_task, args.sleep) |