Skip to content

Commit

Permalink
[llama_agi] Lint + Typing (run-llama#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Apr 17, 2023
1 parent d234c86 commit 2cd8225
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 80 deletions.
30 changes: 19 additions & 11 deletions llama_agi/agi/ExecutionAgent.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,38 @@
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.llms import OpenAI, BaseLLM
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from typing import Optional, Union

from agi.task_prompts import LC_EXECUTION_PROMPT


class SimpleExecutionAgent:
def __init__(self, llm=None, model_name="text-davinci-003"):
if llm:
def __init__(
self,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
model_name: str = "text-davinci-003",
) -> None:
if llm is not None:
self._llm = llm
elif model_name == "text-davinci-003":
self._llm = OpenAI(temperature=0, model_name=model_name, max_tokens=512)
else:
self._llm = ChatOpenAI(temperature=0, model_name=model_name, max_tokens=512 )
self._llm = ChatOpenAI(temperature=0, model_name=model_name, max_tokens=512)

self._prompt_template = PromptTemplate(
template=LC_EXECUTION_PROMPT,
input_variables=["task", "objective", "completed_tasks_summary"]
template=LC_EXECUTION_PROMPT,
input_variables=["task", "objective", "completed_tasks_summary"],
)
self._execution_chain = LLMChain(llm=self._llm, prompt=self._prompt_template)

def execute_task(self, objective, task, completed_tasks_summary):

def execute_task(
self, objective: str, task: str, completed_tasks_summary: str
) -> str:
result = self._execution_chain.predict(
objective=objective,
task=task,
completed_tasks_summary=completed_tasks_summary
objective=objective,
task=task,
completed_tasks_summary=completed_tasks_summary,
)
return result
85 changes: 57 additions & 28 deletions llama_agi/agi/TaskManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,69 @@

from agi.utils import initialize_task_list_index
from agi.task_prompts import (
DEFAULT_TASK_PRIORITIZE_TMPL,
DEFAULT_REFINE_TASK_PRIORITIZE_TMPL,
DEFAULT_TASK_CREATE_TMPL,
DEFAULT_TASK_PRIORITIZE_TMPL,
DEFAULT_REFINE_TASK_PRIORITIZE_TMPL,
DEFAULT_TASK_CREATE_TMPL,
DEFAULT_REFINE_TASK_CREATE_TMPL,
NO_COMPLETED_TASKS_SUMMARY
NO_COMPLETED_TASKS_SUMMARY,
)


class TaskManager:
def __init__(self, tasks: List[str]) -> None:
self.current_tasks = [Document(x) for x in tasks]
self.completed_tasks = []
self.current_tasks_index = initialize_task_list_index(self.current_tasks, index_path="current_tasks_index.json")
self.completed_tasks_index = initialize_task_list_index(self.completed_tasks, index_path="completed_tasks_index.json")

def _get_task_create_templates(self, prev_task: str, prev_result: str) -> Tuple[str, str]:
self.completed_tasks: List[Document] = []
self.current_tasks_index = initialize_task_list_index(
self.current_tasks, index_path="current_tasks_index.json"
)
self.completed_tasks_index = initialize_task_list_index(
self.completed_tasks, index_path="completed_tasks_index.json"
)

def _get_task_create_templates(
self, prev_task: str, prev_result: str
) -> Tuple[QuestionAnswerPrompt, RefinePrompt]:
text_qa_template = DEFAULT_TASK_CREATE_TMPL.format(
prev_result=prev_result, prev_task=prev_task, query_str='{query_str}', context_str="{context_str}"
prev_result=prev_result,
prev_task=prev_task,
query_str="{query_str}",
context_str="{context_str}",
)
text_qa_template = QuestionAnswerPrompt(text_qa_template)
llama_text_qa_template = QuestionAnswerPrompt(text_qa_template)

refine_template = DEFAULT_REFINE_TASK_CREATE_TMPL.format(
prev_result=prev_result, prev_task=prev_task, query_str='{query_str}', context_msg="{context_msg}", existing_answer='{existing_answer}'
prev_result=prev_result,
prev_task=prev_task,
query_str="{query_str}",
context_msg="{context_msg}",
existing_answer="{existing_answer}",
)
refine_template = RefinePrompt(refine_template)
llama_refine_template = RefinePrompt(refine_template)

return (text_qa_template, refine_template)
return (llama_text_qa_template, llama_refine_template)

def _get_task_prioritize_templates(self) -> Tuple[str, str]:
def _get_task_prioritize_templates(
self,
) -> Tuple[QuestionAnswerPrompt, RefinePrompt]:
return (
QuestionAnswerPrompt(DEFAULT_TASK_PRIORITIZE_TMPL),
RefinePrompt(DEFAULT_REFINE_TASK_PRIORITIZE_TMPL)
QuestionAnswerPrompt(DEFAULT_TASK_PRIORITIZE_TMPL),
RefinePrompt(DEFAULT_REFINE_TASK_PRIORITIZE_TMPL),
)

def get_completed_tasks_summary(self) -> str:
if len(self.completed_tasks) == 0:
return NO_COMPLETED_TASKS_SUMMARY
summary = self.completed_tasks_index.query("Summarize the current completed tasks", response_mode="tree_summarize")
summary = self.completed_tasks_index.query(
"Summarize the current completed tasks", response_mode="tree_summarize"
)
return str(summary)

def prioritize_tasks(self, objective: str) -> None:
(text_qa_template, refine_template) = self._get_task_prioritize_templates()
prioritized_tasks = self.current_tasks_index.query(
objective,
text_qa_template=text_qa_template,
refine_template=refine_template
objective,
text_qa_template=text_qa_template,
refine_template=refine_template,
)

new_tasks = []
Expand All @@ -62,15 +79,27 @@ def prioritize_tasks(self, objective: str) -> None:
self.current_tasks = [Document(x) for x in new_tasks]
self.current_tasks_index = initialize_task_list_index(self.current_tasks)

def generate_new_tasks(self, objective: str, prev_task: str, prev_result: str) -> None:
(text_qa_template, refine_template) = self._get_task_create_templates(prev_task, prev_result)
new_tasks = self.completed_tasks_index.query(objective, text_qa_template=text_qa_template, refine_template=refine_template)
def generate_new_tasks(
self, objective: str, prev_task: str, prev_result: str
) -> None:
(text_qa_template, refine_template) = self._get_task_create_templates(
prev_task, prev_result
)
new_tasks = self.completed_tasks_index.query(
objective,
text_qa_template=text_qa_template,
refine_template=refine_template,
)
try:
new_tasks = json.loads(str(new_tasks))
new_tasks = [x.strip() for x in new_tasks if len(x.strip()) > 10]
except:
new_tasks = str(new_tasks).split('\n')
new_tasks = [re.sub(r"^[0-9]+\.", "", x).strip() for x in str(new_tasks) if len(x.strip()) > 10 and x[0].isnumeric()]
except Exception:
new_tasks = str(new_tasks).split("\n")
new_tasks = [
re.sub(r"^[0-9]+\.", "", x).strip()
for x in str(new_tasks)
if len(x.strip()) > 10 and x[0].isnumeric()
]
self.add_new_tasks(new_tasks)

def get_next_task(self) -> str:
Expand All @@ -83,7 +112,7 @@ def add_new_tasks(self, tasks: List[str]) -> None:
if task not in self.current_tasks and task not in self.completed_tasks:
self.current_tasks.append(Document(task))
self.current_tasks_index = initialize_task_list_index(self.current_tasks)

def add_completed_task(self, task: str, result: str) -> None:
document = Document(f"Task: {task}\nResult: {result}\n")
self.completed_tasks.append(document)
Expand Down
12 changes: 2 additions & 10 deletions llama_agi/agi/task_prompts.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate
)

#############################################
##### AGI Prefix #####
#############################################
Expand Down Expand Up @@ -73,7 +65,7 @@
"create new tasks to be completed that do not overlap with incomplete tasks. "
"Return the tasks as an array."
)
#TASK_CREATE_PROMPT = QuestionAnswerPrompt(DEFAULT_TASK_CREATE_TMPL)
# TASK_CREATE_PROMPT = QuestionAnswerPrompt(DEFAULT_TASK_CREATE_TMPL)

DEFAULT_REFINE_TASK_CREATE_TMPL = (
f"{PREFIX}"
Expand All @@ -88,7 +80,7 @@
"add new tasks to be completed that do not overlap with incomplete tasks. "
"Return the tasks as an array. If you have no more tasks to add, repeat the existing list of new tasks."
)
#REFINE_TASK_CREATE_PROMPT = RefinePrompt(DEFAULT_REFINE_TASK_CREATE_TMPL)
# REFINE_TASK_CREATE_PROMPT = RefinePrompt(DEFAULT_REFINE_TASK_CREATE_TMPL)


#############################################
Expand Down
33 changes: 16 additions & 17 deletions llama_agi/agi/utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import os
import os
from llama_index import (
GPTSimpleVectorIndex,
GPTListIndex,
Document,
LLMPredictor,
ServiceContext
GPTListIndex,
ServiceContext,
)


def initialize_task_list_index(
documents,
llm_predictor=None,
embed_model=None,
prompt_helper=None,
index_path="./index.json",
chunk_size_limit=2000
):
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor,
embed_model=embed_model,
prompt_helper=prompt_helper,
chunk_size_limit=chunk_size_limit)
documents,
llm_predictor=None,
embed_model=None,
prompt_helper=None,
index_path="./index.json",
chunk_size_limit=2000,
):
service_context = ServiceContext.from_defaults(
llm_predictor=llm_predictor,
embed_model=embed_model,
prompt_helper=prompt_helper,
chunk_size_limit=chunk_size_limit,
)
if os.path.exists(index_path):
return GPTListIndex.load_from_disk(index_path, service_context=service_context)
else:
Expand Down
2 changes: 2 additions & 0 deletions llama_agi/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
langchain==0.0.141
llama-index==0.5.16
49 changes: 35 additions & 14 deletions llama_agi/run_llama_agi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,46 @@
from agi.utils import log_current_status


def run_llama_agi(objective, initial_task, sleep_time):
def run_llama_agi(objective: str, initial_task: str, sleep_time: int) -> None:
task_manager = TaskManager([initial_task])
execution_agent = SimpleExecutionAgent()

# get initial list of tasks
initial_completed_tasks_summary = task_manager.get_completed_tasks_summary()
initial_task_prompt = initial_task + "\nReturn the list as an array."
initial_task_list = execution_agent.execute_task(
initial_task_list_str = execution_agent.execute_task(
objective, initial_task_prompt, initial_completed_tasks_summary
)
initial_task_list = json.loads(initial_task_list)
initial_task_list = json.loads(initial_task_list_str)

# add tasks to the task manager
task_manager.add_new_tasks(initial_task_list)

# prioritize initial tasks
task_manager.prioritize_tasks(objective)

while True:
# Get the next task
cur_task = task_manager.get_next_task()

# Summarize completed tasks
completed_tasks_summary = task_manager.get_completed_tasks_summary()

# Execute current task
result = execution_agent.execute_task(objective, cur_task, completed_tasks_summary)

result = execution_agent.execute_task(
objective, cur_task, completed_tasks_summary
)

# store the task and result as completed
task_manager.add_completed_task(cur_task, result)

# generate new task(s), if needed
task_manager.generate_new_tasks(objective, cur_task, result)

# log state of AGI to terminal
log_current_status(cur_task, result, completed_tasks_summary, task_manager.current_tasks)
log_current_status(
cur_task, result, completed_tasks_summary, task_manager.current_tasks
)

# Quit the loop?
if len(task_manager.current_tasks) == 0:
Expand All @@ -54,12 +58,29 @@ def run_llama_agi(objective, initial_task, sleep_time):


if __name__ == "__main__":
parser = argparse.ArgumentParser(prog="Llama AGI", description="A baby-agi/auto-gpt inspired application, powered by Llama Index!")
parser.add_argument("-it", "--initial-task", default="Create a list of tasks", help="The initial task for the system to carry out. Default='Create a list of tasks'")
parser.add_argument("-o", "--objective", default="Solve world hunger", help="The overall objective for the system. Default='Solve world hunger'")
parser.add_argument('--sleep', default=2, help="Sleep time (in seconds) between each task loop. Default=2", type=int)
parser = argparse.ArgumentParser(
prog="Llama AGI",
description="A baby-agi/auto-gpt inspired application, powered by Llama Index!",
)
parser.add_argument(
"-it",
"--initial-task",
default="Create a list of tasks",
help="The initial task for the system to carry out. Default='Create a list of tasks'",
)
parser.add_argument(
"-o",
"--objective",
default="Solve world hunger",
help="The overall objective for the system. Default='Solve world hunger'",
)
parser.add_argument(
"--sleep",
default=2,
help="Sleep time (in seconds) between each task loop. Default=2",
type=int,
)

args = parser.parse_args()

run_llama_agi(args.objective, args.initial_task, args.sleep)

0 comments on commit 2cd8225

Please sign in to comment.