Skip to content

Commit

Permalink
Update utils.py
Browse files: browse the repository at this point in the history
  • Loading branch information
lsjlsj35 authored Jan 24, 2024
1 parent df63033 commit fced8c3
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions eval/src/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,7 +30,13 @@ def print_section(section_name, color):
print(colored(formatted_section, color))


def get_paths(basedir, prefix, suffix, dataset_name, model_name):
def get_paths(basedir, prefix, suffix, dataset_name, model_name, unique_dir=False):
if unique_dir:
return (
os.path.join(basedir, "datasets", f"{prefix}_{dataset_name}.{suffix}"),
os.path.join(basedir, "responses", model_name, f"{prefix}_{dataset_name}.{suffix}"),
os.path.join(basedir, "results", model_name, f"{prefix}_{dataset_name}.csv"),
)
return (
os.path.join(basedir, "datasets", f"{prefix}_{dataset_name}.{suffix}"),
os.path.join(basedir, "responses", f"{prefix}_{dataset_name}.{suffix}"),
Expand Down Expand Up @@ -78,19 +84,23 @@ def generate_and_evaluate(
evaluator_klasses,
evaluator_kwargses,
max_iter,
eval_only=False,
generate_only=False,
):
sec_formatter = f"[{task_name}] {{}} RESPONSES FOR {dataset_name.upper()}"

# Generate responses
print_section(sec_formatter.format("GENERATING"), "cyan")
generator = generator_klass(**generator_kwargs)
generate_until_completed(generator, max_iter=max_iter)
if not eval_only:
print_section(sec_formatter.format("GENERATING"), "cyan")
generator = generator_klass(**generator_kwargs)
generate_until_completed(generator, max_iter=max_iter)

# Evaluate responses
assert len(evaluator_klasses) == len(evaluator_kwargses)
avg_scores = {}
for evaluator_klass, evaluator_kwargs in zip(evaluator_klasses, evaluator_kwargses):
print_section(sec_formatter.format("EVALUATING"), "cyan")
evaluator = evaluator_klass(**evaluator_kwargs)
avg_scores.update(evaluate_until_completed(evaluator, max_iter=max_iter))
if not generate_only:
assert len(evaluator_klasses) == len(evaluator_kwargses)
for evaluator_klass, evaluator_kwargs in zip(evaluator_klasses, evaluator_kwargses):
print_section(sec_formatter.format("EVALUATING"), "cyan")
evaluator = evaluator_klass(**evaluator_kwargs)
avg_scores.update(evaluate_until_completed(evaluator, max_iter=max_iter))
return avg_scores

0 comments on commit fced8c3

Please sign in to comment.