Skip to content

Commit

Permalink
Detect completed datasets/tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
SyphonArch committed Jul 18, 2024
1 parent 8be696c commit c4169bf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
4 changes: 2 additions & 2 deletions any_precision/evaluate/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def evaluate_ppl(model, tokenizer, testcases, verbose=True, chunk_size=2048, tok
ppl = torch.exp(torch.stack(neg_log_likelihoods).mean())
logprint(verbose, f"Perplexity: {ppl.item()}")

results[f"{bit}-bit:{testcase_name}"] = ppl.item()
results[f"{testcase_name}:{bit}-bit"] = ppl.item()

if not is_anyprec:
break
Expand Down Expand Up @@ -261,7 +261,7 @@ def run_lm_eval(tokenizer, model, tasks, verbose=True):
logprint(verbose, json.dumps(eval_results['results'], indent=4))

for task in tasks:
results[f"{bit}-bit:{task}"] = eval_results['results'][task]
results[f"{task}:{bit}-bit"] = eval_results['results'][task]

if not is_anyprec:
break
Expand Down
21 changes: 16 additions & 5 deletions run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@
for model_path in model_paths:
model_name = os.path.basename(model_path)
model_jobs = {'to_print': [], 'ppl': [], 'lm-eval': []}
# This logic doesn't support dataset-level redoing for now, as it requires separate logic for Any-Precision models
datasets_with_results = [dataset for dataset in datasets if all_results.get(model_name, {}).get('ppl', {})]
tasks_with_results = [task for task in tasks if task in all_results.get(model_name, {}).get('lm-eval', {})]

# Check if all results already exist for any bit-width. If so, skip that dataset/task.
datasets_with_results = [testcase for testcase in datasets if
any(testcase == key.split(':')[0] for key in
all_results.get(model_name, {}).get('ppl', {}).keys())]
tasks_with_results = [task for task in tasks if
any(task == key.split(':')[0] for key in
all_results.get(model_name, {}).get('lm-eval', {}).keys())]
if not args.redo:
model_jobs['ppl'] = [testcase for testcase in datasets if testcase not in datasets_with_results]
model_jobs['lm-eval'] = [task for task in tasks if task not in tasks_with_results]
Expand Down Expand Up @@ -97,9 +102,15 @@


def save_results(results_dict):
def recursive_sort_dict(d):
if isinstance(d, dict):
return {k: recursive_sort_dict(v) for k, v in sorted(d.items())}
return d

sorted_results = recursive_sort_dict(results_dict)

with open(args.output_file, 'w') as f:
results_dict = dict(sorted(results_dict.items())) # sort by key
json.dump(results_dict, f, indent=2)
json.dump(sorted_results, f, indent=2)


# Run all tasks
Expand Down

0 comments on commit c4169bf

Please sign in to comment.