Skip to content

Commit

Permalink
Keep model state
Browse files Browse the repository at this point in the history
  • Loading branch information
santiag0m committed Mar 25, 2022
1 parent 17e5aac commit 0ccb45a
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/utils/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def reset(self):
def compute_bias(model: nn.Module, dataloader: DataLoader) -> torch.Tensor:
target_avg = RunningAvg()
pred_avg = RunningAvg()
is_training = model.training
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
Expand All @@ -39,4 +40,6 @@ def compute_bias(model: nn.Module, dataloader: DataLoader) -> torch.Tensor:
target_avg.update(targets)
pred_avg.update(preds)
bias = target_avg.value - pred_avg.value
if is_training:
model.train()
return bias

0 comments on commit 0ccb45a

Please sign in to comment.