From b58c7569dafe8bc9b0f81d7e638e8fba413eccc7 Mon Sep 17 00:00:00 2001 From: rizar Date: Thu, 14 Feb 2019 16:52:46 +0000 Subject: [PATCH] fix num frames counting --- babyai/imitation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/babyai/imitation.py b/babyai/imitation.py index 327c8594..313f69fc 100644 --- a/babyai/imitation.py +++ b/babyai/imitation.py @@ -139,6 +139,7 @@ def run_epoch_recurrence(self, demos, is_training=False): indices = np.random.choice(len(demos), self.args.epoch_length) if is_training: np.random.shuffle(indices) + batch_size = min(self.args.batch_size, len(demos)) offset = 0 @@ -161,6 +162,7 @@ def run_epoch_recurrence(self, demos, is_training=False): log["entropy"].append(_log["entropy"]) log["policy_loss"].append(_log["policy_loss"]) log["accuracy"].append(_log["accuracy"]) + log["frames"] = frames offset += batch_size @@ -340,8 +342,7 @@ def initial_status(): self.scheduler.step() log = self.run_epoch_recurrence(train_demos, is_training=True) - total_len = sum([len(item[3]) for item in train_demos]) - status['num_frames'] += total_len + status['num_frames'] += log['frames'] update_end_time = time.time()