Skip to content

Commit

Permalink
QA ADDED
Browse files Browse the repository at this point in the history
  • Loading branch information
n.semenov committed Aug 17, 2024
1 parent 92f2c6e commit 2d55476
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 44 deletions.
36 changes: 1 addition & 35 deletions graphs/parent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,7 @@ def __init__(self, model, system_prompt, api_key):
api_key=api_key,
)

def generate(self, prompt, jsn = False, t = 0.7):
if jsn:
chat_completion = self.client.chat.completions.create(
messages=[
{
"role": "system",
"content": self.system_prompt,
},
{
"role": "user",
"content": prompt,
}
],
model=self.model,
response_format={"type": "json_object"},
temperature=t
)
else:
chat_completion = self.client.chat.completions.create(
messages=[
{
"role": "system",
"content": self.system_prompt,
},
{
"role": "user",
"content": prompt,
}
],
model=self.model,
temperature=t
)
response = chat_completion.choices[0].message.content
prompt_tokens = chat_completion.usage.prompt_tokens
completion_tokens = chat_completion.usage.completion_tokens

def generate(self, prompt, jsn = False, t = 0.7):
if jsn:
chat_completion = self.client.chat.completions.create(
Expand Down
24 changes: 15 additions & 9 deletions musique_test_big.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@


log_path = "MusiqueTestGPTmini"
data_path = 'musique_ans_v1.0_dev.jsonl'
# musique | hotpotqa
task_name = "musique"
topk_episodic = 2
graph_model, qa_model = "gpt-4o-mini", "gpt-4o-mini"
log = Logger(log_path)

def run():
tasks = get_data(data_path)
tasks = get_data(task_name)
agent_items, agent_qa, graph = load_setup(graph_model, qa_model)
trueP, pred_len, true_len, EM = [], [], [], []

Expand Down Expand Up @@ -50,14 +51,19 @@ def run():



def get_data(fiename):
with open(fiename, 'r') as json_file:
json_list = list(json_file)
def get_data(task_name):
if task_name == "musique":
with open('qa_data/musique_ans_v1.0_dev.jsonl', 'r') as json_file:
json_list = list(json_file)

tasks = []
for json_str in json_list:
result = json.loads(json_str)
tasks.append(result)
tasks = []
for json_str in json_list:
result = json.loads(json_str)
tasks.append(result)
if task_name == "hotpotqa":
with open('qa_data/hotpot_dev_distractor_v1.json', 'r') as inp:
data = json.load(inp)
tasks = [" ".join(task["context"][-1]) for task in data]
ids = np.random.RandomState(seed=42).permutation(len(tasks))[:200]
tasks = [tasks[i] for i in ids]

Expand Down
1 change: 1 addition & 0 deletions qa_data/hotpot_dev_distractor_v1.json

Large diffs are not rendered by default.

0 comments on commit 2d55476

Please sign in to comment.