Skip to content

Commit

Permalink
Export function from Campaigns to utils_prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
egonrian committed Sep 21, 2023
1 parent 6ce5633 commit 4d2f02e
Showing 1 changed file with 44 additions and 52 deletions.
96 changes: 44 additions & 52 deletions app/pages/0_Campaigns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,35 @@
"""

import asyncio
import functools
import streamlit as st
import random
import time
import tomllib
import utils_workspace

from google.oauth2 import service_account
from vertexai.preview.language_models import TextGenerationModel
from utils_campaign import add_new_campaign
from utils_prompt import async_predict_text_llm


# Load configuration file
with open("./app_config.toml", "rb") as f:
data = tomllib.load(f)

page_cfg = data["pages"]["campaigns"]
st.set_page_config(
page_title=data["pages"]["campaigns"]["page_title"],
page_icon=data["pages"]["campaigns"]["page_icon"]
page_title=page_cfg["page_title"],
page_icon=page_cfg["page_icon"]
)

import utils_styles
utils_styles.sidebar_apply_style(
style=utils_styles.style_sidebar,
image_path=data["pages"]["campaigns"]["sidebar_image_path"]
image_path=page_cfg["sidebar_image_path"]
)

# Campaigns unique key
CAMPAIGNS_KEY = data["pages"]["campaigns"]["campaigns_key"]
CAMPAIGNS_KEY = page_cfg["campaigns_key"]

# Set project parameters
PROJECT_ID = data["global"]["project_id"]
Expand All @@ -63,16 +63,16 @@
THEMES_FOR_PROMPTS_KEY = f'{PAGE_KEY_PREFIX}_theme'

# Prompt templates
BRAND_OVERVIEW = data["pages"]["campaigns"].get("prompt_brand_overview", "")
BRAND_OVERVIEW = page_cfg.get("prompt_brand_overview", "")

BRAND_STATEMENT_PROMPT_TEMPLATE = data["pages"]["campaigns"]["prompt_brand_statement_template"]
PRIMARY_MSG_PROMPT_TEMPLATE = data["pages"]["campaigns"]["prompt_primary_msg_template"]
COMMS_CHANNEL_PROMPT_TEMPLATE = data["pages"]["campaigns"]["prompt_comms_channel_template"]
BUSINESS_NAME = data["pages"]["campaigns"]["prompt_business_name"]
GENDER_FOR_PROMPTS = data["pages"]["campaigns"]["prompt_genders"]
AGEGROUP_FOR_PROMPTS = data["pages"]["campaigns"]["prompt_age_groups"]
OBJECTIVES_FOR_PROMPTS = data["pages"]["campaigns"]["prompt_objectives"]
COMPETITORS_FOR_PROMPTS = data["pages"]["campaigns"]["prompt_competitors"]
BRAND_STATEMENT_PROMPT_TEMPLATE = page_cfg["prompt_brand_statement_template"]
PRIMARY_MSG_PROMPT_TEMPLATE = page_cfg["prompt_primary_msg_template"]
COMMS_CHANNEL_PROMPT_TEMPLATE = page_cfg["prompt_comms_channel_template"]
BUSINESS_NAME = page_cfg["prompt_business_name"]
GENDER_FOR_PROMPTS = page_cfg["prompt_genders"]
AGEGROUP_FOR_PROMPTS = page_cfg["prompt_age_groups"]
OBJECTIVES_FOR_PROMPTS = page_cfg["prompt_objectives"]
COMPETITORS_FOR_PROMPTS = page_cfg["prompt_competitors"]

# Variables for Workspace integration
SCOPES = data["pages"]["12_review_activate"]["workspace_scopes"]
Expand All @@ -87,9 +87,9 @@

cols = st.columns([13, 87])
with cols[0]:
st.image(data["pages"]["campaigns"]["page_title_icon"])
st.image(page_cfg["page_title_icon"])
with cols[1]:
st.title(data["pages"]["campaigns"]["page_title"])
st.title(page_cfg["page_title"])

tab1, tab2 = st.tabs(["Create Campaign", "Existing Campaigns"])

Expand Down Expand Up @@ -134,57 +134,48 @@

campaigns = st.session_state[CAMPAIGNS_KEY].values()
campaign_names = {
campaign.name : str(campaign.unique_uuid) for campaign in campaigns
campaign.name : str(
campaign.unique_uuid) for campaign in campaigns
}

if campaign_name in campaign_names:
st.info(f"Campaign with name '{campaign_name}' already created. "
st.info(f"Campaign with name '{campaign_name}' "
"already created. "
"Provide a unique name.")
else:
is_allowed_to_create_campaign = True
else:
is_allowed_to_create_campaign = True

if is_allowed_to_create_campaign:
async def async_predict(prompt: str, name: str)-> str:
loop = asyncio.get_running_loop()
with st.spinner(f"Generating {name}"):
generated_response = await loop.run_in_executor(
None,
functools.partial(
llm.predict,
prompt=prompt,
temperature=0.2,
max_output_tokens=1024,
top_k=40, top_p=0.8))
if generated_response and generated_response.text:
return generated_response.text
return ""
async def generate_campaign() -> tuple:
return await asyncio.gather(
async_predict(
async_predict_text_llm(
BRAND_STATEMENT_PROMPT_TEMPLATE.format(
gender_select_theme,
age_select_theme,
objective_select_theme,
competitor_select_theme,
BRAND_OVERVIEW),
"Brand Statement"),
async_predict(
"Brand Statement",
TEXT_MODEL_NAME),
async_predict_text_llm(
PRIMARY_MSG_PROMPT_TEMPLATE.format(
gender_select_theme,
age_select_theme,
objective_select_theme,
competitor_select_theme,
BRAND_OVERVIEW),
"Brand Strategy"),
async_predict(
"Brand Strategy",
TEXT_MODEL_NAME),
async_predict_text_llm(
COMMS_CHANNEL_PROMPT_TEMPLATE.format(
gender_select_theme,
age_select_theme,
objective_select_theme,
competitor_select_theme),
"Communication channels"))
"Communication channels",
TEXT_MODEL_NAME))
try:
generated_tuple = asyncio.run(generate_campaign())
st.session_state[BRAND_STATEMENT_KEY] = generated_tuple[0]
Expand All @@ -210,6 +201,8 @@ async def generate_campaign() -> tuple:
'primary_message': st.session_state[PRIMARY_MSG_KEY],
'comm_channels': st.session_state[COMMS_CHANNEL_KEY]
}
# local reference
brief = st.session_state[CAMPAIGNS_KEY][campaign_uuid].brief

with st.spinner("Creating Google Drive folder..."):
new_folder_id = utils_workspace.create_folder_in_folder(
Expand All @@ -227,25 +220,24 @@ async def generate_campaign() -> tuple:
copy_title=f"GenAI Marketing Brief",
credentials=CREDENTIALS)
st.session_state[DOC_ID_KEY] = doc_id
clean = lambda x: x.replace("*", "").strip()
utils_workspace.update_doc(
document_id=doc_id,
campaign_name= campaign_name,
business_name=st.session_state[CAMPAIGNS_KEY][campaign_uuid].brief["business_name"].replace(
"*", "").strip(),
scenario=st.session_state[CAMPAIGNS_KEY][campaign_uuid].brief["brief_scenario"].replace(
"*", "").strip(),
brand_statement=st.session_state[CAMPAIGNS_KEY][campaign_uuid].brief["brand_statement"].replace(
"*", "").strip(),
primary_msg=st.session_state[CAMPAIGNS_KEY][campaign_uuid].brief["primary_message"].replace(
"*", "").strip(),
comms_channel=st.session_state[CAMPAIGNS_KEY][campaign_uuid].brief["comm_channels"].replace(
"*", "").strip(),
business_name=clean(brief["business_name"]),
scenario=clean(brief["brief_scenario"]),
brand_statement=clean(brief["brand_statement"]),
primary_msg=clean(brief["primary_message"]),
comms_channel=clean(brief["comm_channels"]),
credentials=CREDENTIALS)
st.success("Brief document uploaded to Google Docs.")

st.session_state[CAMPAIGNS_KEY][campaign_uuid].workspace_assets = {}
st.session_state[CAMPAIGNS_KEY][campaign_uuid].workspace_assets['brief_docs_id']=doc_id
st.session_state[CAMPAIGNS_KEY][campaign_uuid].workspace_assets['folder_id']=new_folder_id
st.session_state[CAMPAIGNS_KEY][
campaign_uuid].workspace_assets = {}
st.session_state[CAMPAIGNS_KEY][
campaign_uuid].workspace_assets['brief_docs_id']=doc_id
st.session_state[CAMPAIGNS_KEY][
campaign_uuid].workspace_assets['folder_id']=new_folder_id


with tab2:
Expand Down

0 comments on commit 4d2f02e

Please sign in to comment.