Skip to content

Commit

Permalink
Convert tensors in stats dict into scalars (CarperAI#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHAOTING authored Apr 4, 2023
1 parent fdf4c16 commit d3d1cef
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def evaluate(self): # noqa: C901
stats["time/metric"] = time() - metric_time

mean_metrics = {
f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1) for k, xs in metrics.items()
f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() for k, xs in metrics.items()
}

stats.update(mean_metrics)
Expand Down
10 changes: 5 additions & 5 deletions trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
if self.ref_mean is None:
self.ref_mean, self.ref_std = scores.mean(), scores.std()
all_scores_mean, all_scores_std = self.running_moments.update(scores)
stats["exp_scores/mean"] = all_scores_mean
stats["exp_scores/std"] = all_scores_std
stats["exp_scores/running_mean"] = self.running_moments.mean
stats["exp_scores/running_std"] = self.running_moments.std
stats["exp_scores/mean"] = all_scores_mean.item()
stats["exp_scores/std"] = all_scores_std.item()
stats["exp_scores/running_mean"] = self.running_moments.mean.item()
stats["exp_scores/running_std"] = self.running_moments.std.item()

if self.config.method.scale_reward == "running":
scores /= self.running_moments.std
Expand Down Expand Up @@ -479,7 +479,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
if torch.distributed.is_initialized():
torch.distributed.all_reduce(self.mean_kl, torch.distributed.ReduceOp.AVG)

stats["policy/sqrt_kl"] = torch.sqrt(self.mean_kl)
stats["policy/sqrt_kl"] = torch.sqrt(self.mean_kl).item()
stats["kl_ctl_value"] = self.kl_ctl.value
stats["time/exp"] = exp_time

Expand Down
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def loss(self, batch):
labels[~batch.attention_mask.bool()] = -100

loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
stats = {"loss": loss}
stats = {"loss": loss.item()}

return loss, stats

Expand Down

0 comments on commit d3d1cef

Please sign in to comment.