
Commit

factor out model
conglu1997 committed Oct 19, 2024
1 parent b1f789b commit f7c62fa
Showing 6 changed files with 168 additions and 369 deletions.
111 changes: 24 additions & 87 deletions ai_scientist/generate_ideas.py
@@ -1,12 +1,12 @@
import backoff
import json
import os
import os.path as osp
import requests
import time
from typing import List, Dict, Union
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS

S2_API_KEY = os.getenv("S2_API_KEY")

@@ -73,12 +73,12 @@

# GENERATE IDEAS
def generate_ideas(
base_dir,
client,
model,
skip_generation=False,
max_num_generations=20,
num_reflections=5,
):
if skip_generation:
# Load existing ideas from file
@@ -149,7 +149,7 @@ def generate_ideas(
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)

@@ -175,12 +175,12 @@ def generate_ideas(

# GENERATE IDEAS OPEN-ENDED
def generate_next_idea(
base_dir,
client,
model,
prev_idea_archive=[],
num_reflections=5,
max_attempts=10,
):
idea_archive = prev_idea_archive
original_archive_size = len(idea_archive)
@@ -248,7 +248,7 @@ def generate_next_idea(
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)

@@ -358,11 +358,11 @@ def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:


def check_idea_novelty(
ideas,
base_dir,
client,
model,
max_num_iterations=10,
):
with open(osp.join(base_dir, "experiment.py"), "r") as f:
code = f.read()
@@ -463,24 +463,7 @@ def check_idea_novelty(
"--model",
type=str,
default="gpt-4o-2024-05-13",
choices=[
"claude-3-5-sonnet-20240620",
"gpt-4o-mini-2024-07-18",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"deepseek-coder-v2-0724",
"llama3.1-405b",
# Anthropic Claude models via Amazon Bedrock
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
# Anthropic Claude models Vertex AI
"vertex_ai/claude-3-opus@20240229",
"vertex_ai/claude-3-5-sonnet@20240620",
"vertex_ai/claude-3-sonnet@20240229",
"vertex_ai/claude-3-haiku@20240307",
],
choices=AVAILABLE_LLMS,
help="Model to use for AI Scientist.",
)
parser.add_argument(
@@ -496,53 +479,7 @@ def check_idea_novelty(
args = parser.parse_args()

# Create client
if args.model == "claude-3-5-sonnet-20240620":
import anthropic

print(f"Using Anthropic API with model {args.model}.")
client_model = "claude-3-5-sonnet-20240620"
client = anthropic.Anthropic()
elif args.model.startswith("bedrock") and "claude" in args.model:
import anthropic

# Expects: bedrock/<MODEL_ID>
client_model = args.model.split("/")[-1]

print(f"Using Amazon Bedrock with model {client_model}.")
client = anthropic.AnthropicBedrock()
elif args.model.startswith("vertex_ai") and "claude" in args.model:
import anthropic

# Expects: vertex_ai/<MODEL_ID>
client_model = args.model.split("/")[-1]

print(f"Using Vertex AI with model {client_model}.")
client = anthropic.AnthropicVertex()
elif 'gpt' in args.model:
import openai

print(f"Using OpenAI API with model {args.model}.")
client_model = args.model
client = openai.OpenAI()
elif args.model == "deepseek-coder-v2-0724":
import openai

print(f"Using OpenAI API with {args.model}.")
client_model = "deepseek-coder-v2-0724"
client = openai.OpenAI(
api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com"
)
elif args.model == "llama3.1-405b":
import openai

print(f"Using OpenAI API with {args.model}.")
client_model = "meta-llama/llama-3.1-405b-instruct"
client = openai.OpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
)
else:
raise ValueError(f"Model {args.model} not supported.")
client, client_model = create_client(args.model)

base_dir = osp.join("templates", args.experiment)
results_dir = osp.join("results", args.experiment)
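For orientation, here is a minimal sketch of how the refactored entry point fits together: the CLI only accepts the names in AVAILABLE_LLMS and delegates all provider-specific client construction to create_client. Flag names other than --model and --experiment, the default template name, and the assumption that generate_ideas returns the list of ideas are illustrative rather than taken from this diff.

import argparse
import os.path as osp

from ai_scientist.generate_ideas import check_idea_novelty, generate_ideas
from ai_scientist.llm import AVAILABLE_LLMS, create_client

parser = argparse.ArgumentParser(description="Generate AI Scientist ideas.")
parser.add_argument("--experiment", type=str, default="nanoGPT")  # template folder name; default is illustrative
parser.add_argument("--model", type=str, default="gpt-4o-2024-05-13", choices=AVAILABLE_LLMS)
parser.add_argument("--skip-idea-generation", action="store_true")  # illustrative flag name
args = parser.parse_args()

# One call replaces the old per-provider if/elif block. It returns both the SDK
# client object and the provider-specific model id to pass to the LLM helpers.
client, client_model = create_client(args.model)

base_dir = osp.join("templates", args.experiment)
ideas = generate_ideas(
    base_dir,
    client,
    client_model,
    skip_generation=args.skip_idea_generation,
    max_num_generations=20,
    num_reflections=5,
)
ideas = check_idea_novelty(ideas, base_dir=base_dir, client=client, model=client_model)

Keeping the model list and the client construction in ai_scientist/llm.py means every script that adds a --model flag validates against the same AVAILABLE_LLMS constant and cannot drift from the providers create_client actually knows how to build.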
84 changes: 67 additions & 17 deletions ai_scientist/llm.py
@@ -1,22 +1,41 @@
import anthropic
import backoff
import json
import openai
import os
import re


MAX_NUM_TOKENS = 4096

AVAILABLE_LLMS = [
"claude-3-5-sonnet-20240620",
"gpt-4o-mini-2024-07-18",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"deepseek-coder-v2-0724",
"llama3.1-405b",
# Anthropic Claude models via Amazon Bedrock
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
    # Anthropic Claude models via Vertex AI
"vertex_ai/claude-3-opus@20240229",
"vertex_ai/claude-3-5-sonnet@20240620",
"vertex_ai/claude-3-sonnet@20240229",
"vertex_ai/claude-3-haiku@20240307",
]


# Get N responses from a single message, used for ensembling.
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
def get_batch_responses_from_llm(
msg,
client,
model,
system_message,
print_debug=False,
msg_history=None,
temperature=0.75,
n_responses=1,
):
if msg_history is None:
msg_history = []
@@ -107,13 +126,13 @@ def get_batch_responses_from_llm(

@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
def get_response_from_llm(
msg,
client,
model,
system_message,
print_debug=False,
msg_history=None,
temperature=0.75,
):
if msg_history is None:
msg_history = []
@@ -240,3 +259,34 @@ def extract_json_between_markers(llm_output):
continue # Try next match

return None # No valid JSON found
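The body of extract_json_between_markers is collapsed in this hunk; only its tail is visible above. As a rough sketch of the shape such a helper can take (consistent with that tail, but not copied from the repository), it scans fenced blocks in the reply and returns the first one that parses:

import json
import re


def extract_json_between_markers_sketch(llm_output):
    # Prefer fenced ```json ... ``` blocks; fall back to plain ``` ... ``` fences.
    matches = re.findall(r"```json(.*?)```", llm_output, re.DOTALL)
    if not matches:
        matches = re.findall(r"```(.*?)```", llm_output, re.DOTALL)
    for candidate in matches:
        try:
            return json.loads(candidate.strip())
        except json.JSONDecodeError:
            continue  # Try next match
    return None  # No valid JSON found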


def create_client(model):
if model == "claude-3-5-sonnet-20240620":
print(f"Using Anthropic API with model {model}.")
return anthropic.Anthropic(), model
elif model.startswith("bedrock") and "claude" in model:
client_model = model.split("/")[-1]
print(f"Using Amazon Bedrock with model {client_model}.")
return anthropic.AnthropicBedrock(), client_model
elif model.startswith("vertex_ai") and "claude" in model:
client_model = model.split("/")[-1]
print(f"Using Vertex AI with model {client_model}.")
return anthropic.AnthropicVertex(), client_model
elif 'gpt' in model:
print(f"Using OpenAI API with model {model}.")
return openai.OpenAI(), model
elif model == "deepseek-coder-v2-0724":
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url="https://api.deepseek.com"
), model
elif model == "llama3.1-405b":
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1"
), "meta-llama/llama-3.1-405b-instruct"
else:
raise ValueError(f"Model {model} not supported.")
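A short usage sketch of the new helper, assuming get_response_from_llm returns the reply text plus the updated message history (its return statement falls outside this hunk) and that OPENAI_API_KEY is set; the prompt strings are illustrative. Both response helpers retry with exponential backoff on openai.RateLimitError and openai.APITimeoutError via the decorators above, and for the bedrock/ and vertex_ai/ entries create_client strips the provider prefix, so the second return value is the bare model id the Anthropic SDK expects.

from ai_scientist.llm import create_client, extract_json_between_markers, get_response_from_llm

client, client_model = create_client("gpt-4o-2024-05-13")

text, msg_history = get_response_from_llm(
    "Propose one follow-up experiment and return it as a JSON object.",
    client=client,
    model=client_model,
    system_message="You are an ambitious AI researcher.",
    temperature=0.75,
)
idea = extract_json_between_markers(text)  # None if the reply contains no parseable JSON block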
8 changes: 4 additions & 4 deletions ai_scientist/perform_experiments.py
@@ -1,9 +1,9 @@
import json
import os.path as osp
import shutil
import subprocess
import sys
from subprocess import TimeoutExpired

MAX_ITERS = 4
MAX_RUNS = 5
@@ -151,7 +151,7 @@ def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
We will be running the command `python plot.py` to generate the plots.
"""
while True:
coder_out = coder.run(next_prompt)
_ = coder.run(next_prompt)
return_code, next_prompt = run_plotting(folder_name)
current_iter += 1
if return_code == 0 or current_iter >= MAX_ITERS:
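The hunk above only discards coder.run's unused return value. For context, a self-contained sketch of the retry pattern this loop implements; run_plotting_sketch and fix_plots_until_success are illustrative stand-ins for the repository's helpers, and the aider Coder object is assumed to expose the run(prompt) method used above.

import subprocess

MAX_ITERS = 4  # mirrors the constant defined at the top of perform_experiments.py


def run_plotting_sketch(folder_name, timeout=600):
    # Illustrative stand-in for run_plotting: run `python plot.py` in the run
    # folder and return (return_code, follow-up prompt describing any failure).
    try:
        result = subprocess.run(
            ["python", "plot.py"],
            cwd=folder_name,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return 1, f"Plotting timed out after {timeout} seconds. Please make plot.py faster."
    if result.returncode != 0:
        return 1, f"Plotting failed with the following error:\n{result.stderr}"
    return 0, ""


def fix_plots_until_success(coder, folder_name):
    # Mirror of the loop above: let the coder edit plot.py, re-run it, and feed
    # the error text back as the next prompt until it succeeds or MAX_ITERS is hit.
    next_prompt = "We will be running the command `python plot.py` to generate the plots."
    current_iter = 0
    return_code = 1
    while True:
        _ = coder.run(next_prompt)  # return value unused, hence the underscore
        return_code, next_prompt = run_plotting_sketch(folder_name)
        current_iter += 1
        if return_code == 0 or current_iter >= MAX_ITERS:
            break
    return return_code == 0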