Skip to content

Commit

Permalink
num examples param addition/fixes (facebookresearch#2496)
Browse files Browse the repository at this point in the history
* num examples param addition/fixes

* linto

* moar

* mora
  • Loading branch information
jaseweston authored Mar 21, 2020
1 parent 7399e58 commit 5469ac9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
10 changes: 8 additions & 2 deletions parlai/scripts/data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def setup_args(parser=None):
if parser is None:
parser = ParlaiParser(True, False, 'Lint for ParlAI tasks')
# Get command line arguments
parser.add_argument('-n', '-ne', '--num-examples', type=int, default=-1)
parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
parser.add_argument(
'--agent',
Expand Down Expand Up @@ -106,8 +107,13 @@ def keep_token(t):
return False
return True

# max number of examples to evaluate
max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
cnt = 0

# Show some example dialogs.
while not world.epoch_done():
while not world.epoch_done() and cnt < max_cnt:
cnt += opt.get('batchsize', 1)
world.parley()
act = world.get_acts()[opt.get('agent')]
for itype in {'input', 'labels'}:
Expand Down Expand Up @@ -172,4 +178,4 @@ def keep_token(t):
report_text, report_log = verify(
parser.parse_args(print_args=False), print_parser=parser
)
print(report_text)
print(report_text.replace('\\n', '\n'))
2 changes: 1 addition & 1 deletion parlai/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _eval_single_world(opt, agent, task):
if log_time.time() > log_every_n_secs:
report = world.report()
text, report = log_time.log(
report.get('exs', 0), world.num_examples(), report
report.get('exs', 0), min(max_cnt, world.num_examples()), report
)
print(text)

Expand Down
7 changes: 5 additions & 2 deletions parlai/scripts/eval_wordstat.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def eval_wordstat(opt, print_parser=None):
log_time = TimeLogger()

cnt = 0
max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
word_statistics = {
'mean_wlength': [],
'mean_clength': [],
Expand Down Expand Up @@ -187,7 +188,9 @@ def process_prediction(prediction, word_statistics):

if log_time.time() > log_every_n_secs:
report = world.report()
text, report = log_time.log(report['exs'], world.num_examples(), report)
text, report = log_time.log(
report['exs'], min(max_cnt, world.num_examples()), report
)
print(text)
stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
stat_str += ', '.join(
Expand All @@ -214,7 +217,7 @@ def process_prediction(prediction, word_statistics):
prec=2,
)
)
if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
if cnt >= max_cnt:
break
if world.epoch_done():
print("EPOCH DONE")
Expand Down

0 comments on commit 5469ac9

Please sign in to comment.