Skip to content

Commit

Permalink
tot package
Browse files Browse the repository at this point in the history
  • Loading branch information
ysymyth committed Jul 4, 2023
1 parent 7382f24 commit 733b009
Show file tree
Hide file tree
Showing 33 changed files with 1,579 additions and 1,502 deletions.
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
include src/tot/data/24/24.csv
include src/tot/data/crosswords/mini0505_0_100_5.json
include src/tot/data/crosswords/mini0505.json
include src/tot/data/text/data_100_random_text.txt
File renamed without changes
File renamed without changes
35 changes: 35 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools >= 61.0.0"]
build-backend = "setuptools.build_meta"

[project]
name = "tot"
version = "0.1.0"
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
readme = "README.md"
requires-python = ">= 3.7"
authors = [{ name = "Shunyu Yao", email = "[email protected]" }]
license = { text = "MIT License" }
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
]
dynamic=["dependencies"]


[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}

[tool.setuptools.packages.find]
where = ["src"] # list of folders that contain the packages (["."] by default)

[project.urls]
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"
4 changes: 2 additions & 2 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
<details>
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>

In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
</details>





![teaser](teaser.png)
![teaser](pics/teaser.png)

Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ sympy==1.12
tqdm==4.65.0
urllib3==2.0.2
yarl==1.9.2
pandas==2.0.3
105 changes: 7 additions & 98 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,18 @@
import os
import json
import itertools
import argparse
import numpy as np
from functools import partial
from models import gpt, gpt_usage
from tasks import get_task

def get_value(task, x, y, n_evaluate_sample, cache_value=True):
value_prompt = task.value_prompt_wrap(x, y)
if cache_value and value_prompt in task.value_cache:
return task.value_cache[value_prompt]
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
value = task.value_outputs_unwrap(x, y, value_outputs)
if cache_value:
task.value_cache[value_prompt] = value
return value

def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
values = []
local_value_cache = {}
for y in ys: # each partial output
if y in local_value_cache: # avoid duplicate candidates
value = 0
else:
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
local_value_cache[y] = value
values.append(value)
return values

def get_votes(task, x, ys, n_evaluate_sample):
vote_prompt = task.vote_prompt_wrap(x, ys)
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
return values

def get_proposals(task, x, y):
propose_prompt = task.propose_prompt_wrap(x, y)
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
return [y + _ + '\n' for _ in proposals]

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
if prompt_sample == 'standard':
prompt = task.standard_prompt_wrap(x, y)
elif prompt_sample == 'cot':
prompt = task.cot_prompt_wrap(x, y)
else:
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
samples = gpt(prompt, n=n_generate_sample, stop=stop)
return [y + _ for _ in samples]

def solve(args, task, idx, to_print=True):
print(gpt)
x = task.get_input(idx) # input
ys = [''] # current output candidates
infos = []
for step in range(task.steps):
# generation
if args.method_generate == 'sample':
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
elif args.method_generate == 'propose':
new_ys = [get_proposals(task, x, y) for y in ys]
new_ys = list(itertools.chain(*new_ys))
ids = list(range(len(new_ys)))
# evaluation
if args.method_evaluate == 'vote':
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
elif args.method_evaluate == 'value':
values = get_values(task, x, new_ys, args.n_evaluate_sample)

# selection
if args.method_select == 'sample':
ps = np.array(values) / sum(values)
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
elif args.method_select == 'greedy':
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
select_new_ys = [new_ys[select_id] for select_id in select_ids]

# log
if to_print:
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
ys = select_new_ys

if to_print:
print(ys)
return ys, {'steps': infos}

def naive_solve(args, task, idx, to_print=True):
x = task.get_input(idx) # input
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
return ys, {}
from tot.tasks import get_task
from tot.methods.bfs import solve, naive_solve
from tot.models import gpt_usage

def run(args):
task = get_task(args.task, args.task_file_path)
task = get_task(args.task)
logs, cnt_avg, cnt_any = [], 0, 0
global gpt
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
if args.naive_run:
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
else:
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
os.makedirs(os.path.dirname(file), exist_ok=True)

for i in range(args.task_start_index, args.task_end_index):
Expand Down Expand Up @@ -136,7 +46,6 @@ def parse_args():
args.add_argument('--temperature', type=float, default=0.7)

args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
args.add_argument('--task_file_path', type=str, required=True)
args.add_argument('--task_start_index', type=int, default=900)
args.add_argument('--task_end_index', type=int, default=1000)

Expand All @@ -145,7 +54,7 @@ def parse_args():

args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
args.add_argument('--n_evaluate_sample', type=int, default=1)
args.add_argument('--n_select_sample', type=int, default=1)
Expand Down
1 change: 0 additions & 1 deletion scripts/crosswords/cot_sampling.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
python run.py \
--task crosswords \
--task_file_path mini0505_0_100_5.json \
--task_start_index 0 \
--task_end_index 20 \
--naive_run \
Expand Down
10 changes: 5 additions & 5 deletions scripts/crosswords/search_crosswords-dfs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"metadata": {},
"outputs": [],
"source": [
"cd ../.."
"cd .."
]
},
{
Expand All @@ -24,9 +24,9 @@
"outputs": [],
"source": [
"import json\n",
"from prompts.crosswords import propose_prompt, value_prompt\n",
"from models import gpt\n",
"from tasks.crosswords import MiniCrosswordsEnv\n",
"from tot.prompts.crosswords import propose_prompt, value_prompt\n",
"from tot.models import gpt\n",
"from tot.tasks.crosswords import MiniCrosswordsEnv\n",
"\n",
"env = MiniCrosswordsEnv()"
]
Expand Down Expand Up @@ -61,7 +61,7 @@
"source": [
"import re\n",
"import copy\n",
"from models import gpt\n",
"from tot.models import gpt\n",
"\n",
"def parse_line(input_str):\n",
" # regular expression pattern to match the input string format\n",
Expand Down
1 change: 0 additions & 1 deletion scripts/crosswords/standard_sampling.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
python run.py \
--task crosswords \
--task_file_path mini0505_0_100_5.json \
--task_start_index 0 \
--task_end_index 20 \
--naive_run \
Expand Down
1 change: 0 additions & 1 deletion scripts/game24/bfs.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
python run.py \
--task game24 \
--task_file_path 24.csv \
--task_start_index 900 \
--task_end_index 1000 \
--method_generate propose \
Expand Down
1 change: 0 additions & 1 deletion scripts/game24/cot_sampling.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
python run.py \
--task game24 \
--task_file_path 24.csv \
--task_start_index 900 \
--task_end_index 1000 \
--naive_run \
Expand Down
1 change: 0 additions & 1 deletion scripts/game24/standard_sampling.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
python run.py \
--task game24 \
--task_file_path 24.csv \
--task_start_index 900 \
--task_end_index 1000 \
--naive_run \
Expand Down
3 changes: 1 addition & 2 deletions scripts/text/bfs.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
python run.py \
--task text \
--task_file_path data_100_random_text.txt \
--task_start_index 0 \
--task_end_index 1 \
--task_end_index 100 \
--method_generate sample \
--method_evaluate vote \
--method_select greedy \
Expand Down
3 changes: 1 addition & 2 deletions scripts/text/cot_sampling.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
python run.py \
--task text \
--task_file_path data_100_random_text.txt \
--task_start_index 0 \
--task_end_index 1 \
--task_end_index 100 \
--naive_run \
--prompt_sample cot \
--n_generate_sample 10 \
Expand Down
3 changes: 1 addition & 2 deletions scripts/text/standard_sampling.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
python run.py \
--task text \
--task_file_path data_100_random_text.txt \
--task_start_index 0 \
--task_end_index 1 \
--task_end_index 100 \
--naive_run \
--prompt_sample standard \
--n_generate_sample 10 \
Expand Down
37 changes: 37 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import setuptools

with open('README.md', 'r', encoding='utf-8') as fh:
long_description = fh.read()


setuptools.setup(
name='tot',
author='Shunyu Yao',
author_email='[email protected]',
description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"',
keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/princeton-nlp/tree-of-thought-llm',
project_urls={
'Homepage': 'https://github.com/princeton-nlp/tree-of-thought-llm',
},
package_dir={'': 'src'},
packages=setuptools.find_packages(where='src'),
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
python_requires='>=3.7',
install_requires=[
'setuptools',
],
include_package_data=True,
)
1 change: 1 addition & 0 deletions src/tot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.1.0"
Loading

0 comments on commit 733b009

Please sign in to comment.