Dec branch commit #584

Open · wants to merge 25 commits into base: master

The diff below shows changes from 2 of the 25 commits.
fff7336
First Commit for SAT Solving
leyanpan Oct 14, 2023
087a633
Add updates for partial training
leyanpan Oct 19, 2023
1c2647e
Modify gitignore
leyanpan Oct 20, 2023
e9caadb
add large training set
leyanpan Nov 5, 2023
432ea20
add test for Large CDCL
leyanpan Nov 6, 2023
fe26dc3
add new prediction file
Nov 7, 2023
7260b24
add diff dataset
leyanpan Nov 7, 2023
209b732
Merge branch 'master' of github.com:leyanpan/nanoGPT_SAT
leyanpan Nov 7, 2023
3b40a81
add 20-layer model
Nov 10, 2023
983eeef
Fix prediction files
Nov 10, 2023
fa795f2
add LTL dataset
leyanpan Nov 29, 2023
e304737
Update Code for binary classification
leyanpan Jan 18, 2024
022e649
Update .gitignore and remove large files
leyanpan Jan 18, 2024
43ecb20
make server change
leyanpan Jan 18, 2024
7e80dad
Merge remote-tracking branch 'origin/main'
leyanpan Jan 18, 2024
32aeac3
Update code for classification evaluation
leyanpan Jan 23, 2024
02799e7
Merge branch 'main' of github.gatech.edu:LLM-Formal-Reasoning/nanoGPT…
leyanpan Jan 23, 2024
6596324
Windows Client Changes
leyanpan Jan 23, 2024
c2db691
Merge pull request #2 from LLM-Formal-Reasoning/main
leyanpan Jan 23, 2024
f62d523
Merge branch 'master' of github.gatech.edu:LLM-Formal-Reasoning/nanoG…
leyanpan Jan 23, 2024
d6e824d
Add debug option and debug logs
leyanpan Jan 27, 2024
01d2d3c
Update model
leyanpan Dec 20, 2024
ba3b143
Updates
leyanpan Jan 5, 2025
da59cf2
added requirements defined here - Option 1 - Flag
cesposo Jan 5, 2025
6fd8ccc
added the mechanisms for reqs from December 18 meeting
cesposo Jan 5, 2025
30 changes: 18 additions & 12 deletions eval_sat.py
@@ -25,6 +25,7 @@
 compile = False # use PyTorch 2.0 to compile the model to be faster
 eval = True
 sep = ' ' # separator between tokens during decoding, default: nothing, i.e. join with empty string
+class_only = False # only generate the class token
 exec(open('configurator.py').read()) # overrides from command line or config file
 if sep == 'SPACE':
     sep = ' '
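The `sep == 'SPACE'` check exists because the `exec`-based override mechanism takes `key=value` strings from the command line, where a literal space is awkward to pass, so `SPACE` serves as a sentinel. A minimal torch-free sketch of this pattern (a simplified stand-in, not the actual `configurator.py`):

```python
# Simplified stand-in for nanoGPT-style config overriding (assumption:
# illustrative only; the real configurator.py exec's key=value pairs
# directly into module globals).
def apply_overrides(config, args):
    """Apply command-line style 'key=value' overrides to a config dict."""
    for arg in args:
        key, val = arg.split('=', 1)
        if key not in config:
            raise KeyError(f"unknown config key: {key}")
        old = config[key]
        if isinstance(old, bool):
            # bool('False') would be True, so compare the string explicitly
            config[key] = (val == 'True')
        else:
            # coerce the string to the existing value's type
            config[key] = type(old)(val)
    # 'SPACE' is a sentinel: a literal ' ' is hard to pass through a shell
    if config.get('sep') == 'SPACE':
        config['sep'] = ' '
    return config
```

For example, `apply_overrides({'sep': ' ', 'class_only': False}, ['sep=SPACE', 'class_only=True'])` leaves `sep` as a single space and enables `class_only`.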
@@ -135,19 +136,24 @@ def line_sat(line, sep=' '):
             continue
         sample_cnt += 1
         prompt = (torch.tensor(prompt, dtype=torch.long, device=device)[None, ...])
-        y = model.generate(prompt, max_new_tokens, temperature=temperature, top_k=top_k, stop=encode(['[SEP]']))
-        res_str = decode(y[0].tolist())
-        print(res_str)
-        print('---------------')
-        if eval:
+        if class_only:
             true_label.append(label)
-            res = line_sat(res_str, sep)
-            if res is None:
-                res = not label
-            pred_label.append(res)
-        if output_file is not None:
-            with open(output_file, 'a', encoding='utf-8') as f:
-                f.write(res_str + '\n')
+            y = model.classify(prompt)
+            pred_label.append(y[0].item() == stoi['SAT'])
+        else:
+            y = model.generate(prompt, max_new_tokens, temperature=temperature, top_k=top_k, stop=encode(['[SEP]']))
+            res_str = decode(y[0].tolist())
+            print(res_str)
+            print('---------------')
+            if eval:
+                true_label.append(label)
+                res = line_sat(res_str, sep)
+                if res is None:
+                    res = not label
+                pred_label.append(res)
+            if output_file is not None:
+                with open(output_file, 'a', encoding='utf-8') as f:
+                    f.write(res_str + '\n')
 
 if eval_labels is not None:
     with open(eval_labels, 'r') as f:
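In the `class_only` path, evaluation reduces to a single token-id comparison: the model's predicted class-token id is checked against the id of `SAT`. A torch-free sketch of that comparison (the `stoi` mapping and logits here are illustrative; the real mapping is loaded from the dataset's metadata, not shown in this PR):

```python
# Illustrative vocabulary; assumption: a hypothetical stand-in for the
# dataset's real token-to-id mapping (stoi).
stoi = {'SAT': 0, 'UNSAT': 1, '[SEP]': 2}

def predict_is_sat(vocab_logits, stoi):
    """Mimic the class_only branch of eval_sat.py: argmax over the
    vocabulary logits, then compare the winning id to stoi['SAT']."""
    pred_id = max(range(len(vocab_logits)), key=lambda i: vocab_logits[i])
    return pred_id == stoi['SAT']
```

With the toy vocabulary above, `predict_is_sat([2.1, 0.3, -1.0], stoi)` returns `True` because index 0 (`SAT`) has the largest logit.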
9 changes: 9 additions & 0 deletions model.py
@@ -348,3 +348,12 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, stop=None):
                 break
 
         return idx
+
+    @torch.no_grad()
+    def classify(self, idx):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and use the first position to predict a single token.
+        """
+        logits, _ = self(idx)
+        logits = logits[:, 0, :]
+        return logits.argmax(dim=-1)
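Note on the `logits[:, 0, :]` indexing: in upstream nanoGPT, the forward pass with no targets typically returns logits only for the final sequence position (a length-1 time dimension), so index 0 selects the last token's prediction rather than the literal first input position. A shape-level, torch-free sketch of the final argmax step (plain nested lists standing in for a `(batch, 1, vocab)` tensor):

```python
# Sketch of classify's last two lines, with nested lists in place of
# tensors (assumption: the returned time dimension has length 1, as in
# upstream nanoGPT's inference path).
def argmax_class(logits_batch):
    """For each batch row, take position 0 of the (length-1) returned
    sequence and argmax over the vocabulary dimension."""
    preds = []
    for row in logits_batch:            # row: list of positions
        vocab_logits = row[0]           # position 0 = the only position
        preds.append(max(range(len(vocab_logits)), key=lambda i: vocab_logits[i]))
    return preds
```

For a batch of two rows, `argmax_class([[[0.1, 2.0, -1.0]], [[3.0, 0.0, 1.0]]])` yields `[1, 0]`, one predicted class-token id per sample, matching the shape `eval_sat.py` consumes via `y[0].item()`.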