Skip to content

Commit

Permalink
Enable record_processor return multiple plans to let user choose
Browse files Browse the repository at this point in the history
  • Loading branch information
yunhao0204 committed Apr 1, 2024
1 parent 73e91cf commit 1c7ad6f
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 51 deletions.
35 changes: 21 additions & 14 deletions record_processor/record_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .parser.psr_record_parser import PSRRecordParser
from .utils import create_folder, save_to_json, unzip_and_read_file
from ufo.utils import print_with_color
from typing import Tuple


configs = load_config()

Expand All @@ -33,15 +35,17 @@ def main():
record.set_request(parsed_args.request[0])

summarizer = DemonstrationSummarizer(
configs["ACTION_AGENT"]["VISUAL_MODE"], configs["DEMONSTRATION_PROMPT"], configs["ACTION_SELECTION_EXAMPLE_PROMPT"], configs["API_PROMPT"])
configs["ACTION_AGENT"]["VISUAL_MODE"], configs["DEMONSTRATION_PROMPT"], configs["ACTION_SELECTION_EXAMPLE_PROMPT"], configs["API_PROMPT"], configs["RAG_DEMONSTRATION_COMPLETION_N"])

summaries, total_cost = summarizer.get_summary_list([record])
if asker(summaries):
summaries, total_cost = summarizer.get_summary_list(record)

is_save, index = asker(summaries)
if is_save and index >= 0:
demonstration_path = configs["DEMONSTRATION_SAVED_PATH"]
create_folder(demonstration_path)

save_to_json(record.__dict__, os.path.join(demonstration_path, "demonstration_log", parsed_args.request[0].replace(' ', '_')) + ".json")
summarizer.create_or_update_yaml(summaries, os.path.join(demonstration_path, "demonstration.yaml"))
summarizer.create_or_update_yaml([summaries[index]], os.path.join(demonstration_path, "demonstration.yaml"))
summarizer.create_or_update_vector_db(summaries, os.path.join(demonstration_path, "demonstration_db"))

formatted_cost = '${:.2f}'.format(total_cost)
Expand All @@ -51,16 +55,19 @@ def main():
print_with_color(str(e), "red")


def asker(summaries) -> bool:
plan = summaries[0]["example"]["Plan"]
print_with_color("""Here's the plan summarized from your demonstration: """, "cyan")
print_with_color(plan, "green")
print_with_color("""Would you like to save the plan future reference by the agent?
[Y] for yes, any other key for no.""", "cyan")
def asker(summaries) -> Tuple[bool, int]:
print_with_color("""Here are the plans summarized from your demonstration: """, "cyan")
for index, summary in enumerate(summaries):
print_with_color(f"Plan [{index + 1}]", "green")
print_with_color(f"{summary['example']['Plan']}", "yellow")

print_with_color(f"Would you like to save any one of them as future reference by the agent? press ", color="cyan" , end="")
for index in range(1, len(summaries) + 1):
print_with_color(f"[{index}]", color="cyan", end=" ")
print_with_color("to save the corresponding plan, or press any other key to skip.", color="cyan")
response = input()

if response.upper() == "Y":
return True
if response.isnumeric() and int(response) in range(1, len(summaries) + 1):
return True, int(response) - 1
else:
return False
return False, -1
43 changes: 18 additions & 25 deletions record_processor/summarizer/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import yaml
from record_processor.parser.demonstration_record import DemonstrationRecord
from record_processor.utils import json_parser
from ufo.llm.llm_call import get_completion
from ufo.llm.llm_call import get_completions
from ufo.prompter.demonstration_prompter import DemonstrationPrompter
from typing import Tuple
from langchain.docstore.document import Document
Expand All @@ -18,7 +18,7 @@ class DemonstrationSummarizer:
The DemonstrationSummarizer class is the summarizer for the demonstration learning.
"""

def __init__(self, is_visual: bool, prompt_template: str, demonstration_prompt_template: str, api_prompt_template: str):
def __init__(self, is_visual: bool, prompt_template: str, demonstration_prompt_template: str, api_prompt_template: str, completion_num: int = 1):
"""
Initialize the DemonstrationSummarizer.
:param is_visual: Whether the request is for visual model.
Expand All @@ -30,6 +30,7 @@ def __init__(self, is_visual: bool, prompt_template: str, demonstration_prompt_t
self.prompt_template = prompt_template
self.demonstration_prompt_template = demonstration_prompt_template
self.api_prompt_template = api_prompt_template
self.completion_num = completion_num

def build_prompt(self, demo_record: DemonstrationRecord) -> list:
"""
Expand All @@ -47,16 +48,8 @@ def build_prompt(self, demo_record: DemonstrationRecord) -> list:

return demonstration_prompt

def get_summary(self, prompt_message: list) -> Tuple[dict, float]:
"""
Get the summary.
:param prompt_message: A list of prompt messages.
return: The summary and the cost.
"""

# Get the completion for the prompt message
response_string, cost = get_completion(
prompt_message, "ACTION", use_backup_engine=True)

def restructure_response(self, response_string: str) -> dict:
try:
response_json = json_parser(response_string)
except:
Expand All @@ -69,26 +62,27 @@ def get_summary(self, prompt_message: list) -> Tuple[dict, float]:
for key in ["Observation", "Thought", "ControlLabel", "ControlText", "Function", "Args", "Status", "Plan", "Comment"]:
summary["example"][key] = response_json.get(key, "")
summary["Tips"] = response_json.get("Tips", "")

return summary

return summary, cost

def get_summary_list(self, records: list) -> Tuple[list, float]:
def get_summary_list(self, record: DemonstrationRecord) -> Tuple[list, float]:
"""
Get the summary list for a list of records.
:param records: The list of records.
return: The summary list and the total cost.
Get the summary list for a record
:param record: The demonstration record.
return: The summary list for the user defined completion number and the cost
"""

prompt = self.build_prompt(record)
response_string_list, cost = get_completions(prompt, "ACTION", use_backup_engine=True, n=self.completion_num)
summaries = []
total_cost = 0
for record in records:
prompt = self.build_prompt(record)
summary, cost = self.get_summary(prompt)
for response_string in response_string_list:
summary = self.restructure_response(response_string)
summary["request"] = record.get_request()
summary["app_list"] = record.get_applications()
summaries.append(summary)
total_cost += cost

return summaries, total_cost
return summaries, cost

@staticmethod
def create_or_update_yaml(summaries: list, yaml_path: str):
Expand Down Expand Up @@ -145,8 +139,7 @@ def create_or_update_vector_db(summaries: list, db_path: str):

# Check if the db exists, if not, create a new one.
if os.path.exists(db_path):
prev_db = FAISS.load_local(
db_path, embeddings)
prev_db = FAISS.load_local(db_path, embeddings)
db.merge_from(prev_db)

db.save_local(db_path)
Expand Down
3 changes: 1 addition & 2 deletions ufo/config/config.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,5 @@ RAG_EXPERIENCE_RETRIEVED_TOPK: 5 # The topk for the offline retrieved documents
## RAG Configuration for demonstration
RAG_DEMONSTRATION: True # Whether to use the RAG from its user demonstration.
RAG_DEMONSTRATION_RETRIEVED_TOPK: 5 # The topk for the offline retrieved documents


RAG_DEMONSTRATION_COMPLETION_N: 3 # The number of completion choices for the demonstration result

2 changes: 1 addition & 1 deletion ufo/config/config_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CONTROL_BACKEND: "uia" # The backend for control action
MAX_STEP: 30 # The max step limit for completing the user request
SLEEP_TIME: 5 # The sleep time between each step to wait for the window to be ready
SAFE_GUARD: True # Whether to use the safe guard to prevent the model from doing sensitive operations.
CONTROL_TYPE_LIST: ["Button", "Edit", "TabItem", "Document", "ListItem", "MenuItem", "ScrollBar", "TreeItem", "Hyperlink", "ComboBox", "RadioButton"] # The list of control types that are allowed to be selected
CONTROL_TYPE_LIST: ["Button", "Edit", "TabItem", "Document", "ListItem", "MenuItem", "ScrollBar", "TreeItem", "Hyperlink", "ComboBox", "RadioButton", "DataItem"] # The list of control types that are allowed to be selected
HISTORY_KEYS: ["Step", "Thought", "ControlText", "Action", "Comment", "Results"] # The keys of the action history for the next step.
ANNOTATION_COLORS: {
"Button": "#FFF68F",
Expand Down
29 changes: 24 additions & 5 deletions ufo/llm/llm_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,42 @@

from ufo.utils import print_with_color
from ..config.config import load_config
from typing import Tuple


configs = load_config()


def get_completion(messages, agent: str='APP', use_backup_engine: bool=True):
def get_completion(messages, agent: str='APP', use_backup_engine: bool=True) -> Tuple[str, float]:
"""
Get completion for the given messages.
Args:
messages (list): List of messages to be used for completion.
agent (str, optional): Type of agent. Possible values are 'APP', 'ACTION' or 'BACKUP'.
use_backup_engine (bool, optional): Flag indicating whether to use the backup engine or not.
Returns:
tuple: A tuple containing the completion response (str) and the cost (float).
"""

responses, cost = get_completions(messages, agent=agent, use_backup_engine=use_backup_engine, n=1)
return responses[0], cost


def get_completions(messages, agent: str='APP', use_backup_engine: bool=True, n: int=1) -> Tuple[list, float]:
"""
Get completions for the given messages.
Args:
messages (list): List of messages to be used for completion.
agent (str, optional): Type of agent. Possible values are 'APP', 'ACTION' or 'BACKUP'.
use_backup_engine (bool, optional): Flag indicating whether to use the backup engine or not.
n (int, optional): Number of completions to generate.
Returns:
tuple: A tuple containing the completion responses (list of str) and the cost (float).
"""
if agent.lower() == "app":
agent_type = "APP_AGENT"
Expand All @@ -34,14 +53,14 @@ def get_completion(messages, agent: str='APP', use_backup_engine: bool=True):
try:
if api_type.lower() in ['openai', 'aoai', 'azure_ad']:
from .openai import OpenAIService
response, cost = OpenAIService(configs, agent_type=agent_type).chat_completion(messages)
response, cost = OpenAIService(configs, agent_type=agent_type).chat_completion(messages, n)
return response, cost
else:
raise ValueError(f'API_TYPE {api_type} not supported')
except Exception as e:
if use_backup_engine:
print_with_color(f"The API request of {agent_type} failed: {e}.", "red")
print_with_color(f"Switching to use the backup engine...", "yellow")
return get_completion(messages, agent='backup', use_backup_engine=False)
return get_completion(messages, agent='backup', use_backup_engine=False, n=n)
else:
raise e
8 changes: 6 additions & 2 deletions ufo/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ def __init__(self, config, agent_type: str):
def chat_completion(
self,
messages,
n,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs: Any,
):
) :

model = self.config_llm["API_MODEL"]

temperature = temperature if temperature is not None else self.config["TEMPERATURE"]
Expand All @@ -54,6 +56,7 @@ def chat_completion(
max_tokens=max_tokens,
top_p=top_p,
stream=stream,
n=n,
**kwargs
)

Expand All @@ -63,7 +66,8 @@ def chat_completion(

cost = prompt_tokens / 1000 * 0.01 + completion_tokens / 1000 * 0.03

return response.choices[0].message.content, cost
return [response.choices[i].message.content for i in range(n)], cost


except openai.APITimeoutError as e:
# Handle timeout error, e.g. retry or log
Expand Down
4 changes: 2 additions & 2 deletions ufo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# init colorama
init()

def print_with_color(text: str, color: str = ""):
def print_with_color(text: str, color: str = "", end: str = "\n"):
"""
Print text with specified color using ANSI escape codes from Colorama library.
Expand All @@ -34,7 +34,7 @@ def print_with_color(text: str, color: str = ""):
selected_color = color_mapping.get(color.lower(), "")
colored_text = selected_color + text + Style.RESET_ALL

print(colored_text)
print(colored_text, end=end)



Expand Down

0 comments on commit 1c7ad6f

Please sign in to comment.