Skip to content

Commit

Permalink
fix num frames counting
Browse files Browse the repository at this point in the history
  • Loading branch information
rizar committed Feb 19, 2019
1 parent 62e4977 commit b58c756
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions babyai/imitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def run_epoch_recurrence(self, demos, is_training=False):
indices = np.random.choice(len(demos), self.args.epoch_length)
if is_training:
np.random.shuffle(indices)

batch_size = min(self.args.batch_size, len(demos))
offset = 0

Expand All @@ -161,6 +162,7 @@ def run_epoch_recurrence(self, demos, is_training=False):
log["entropy"].append(_log["entropy"])
log["policy_loss"].append(_log["policy_loss"])
log["accuracy"].append(_log["accuracy"])
log["frames"] = frames

offset += batch_size

Expand Down Expand Up @@ -340,8 +342,7 @@ def initial_status():
self.scheduler.step()

log = self.run_epoch_recurrence(train_demos, is_training=True)
total_len = sum([len(item[3]) for item in train_demos])
status['num_frames'] += total_len
status['num_frames'] += log['frames']

update_end_time = time.time()

Expand Down

0 comments on commit b58c756

Please sign in to comment.