Skip to content

Commit

Permalink
boolean mode in module.train
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko authored and soumith committed Mar 2, 2017
1 parent f366e5f commit b5f7592
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,25 +357,22 @@ def modules(self, memo=None):
for m in module.modules(memo):
yield m

def train(self):
def train(self, mode=True):
"""Sets the module in training mode.
This has any effect only on modules such as Dropout or BatchNorm.
"""
self.training = True
self.training = mode
for module in self.children():
module.train()
module.train(mode)
return self

def eval(self):
"""Sets the module in evaluation mode.
This has any effect only on modules such as Dropout or BatchNorm.
"""
self.training = False
for module in self.children():
module.eval()
return self
return self.train(False)

def zero_grad(self):
"""Sets gradients of all model parameters to zero."""
Expand Down

0 comments on commit b5f7592

Please sign in to comment.