
Commit 515dff7

samdow authored and pytorchmergebot committed
[functorch] move batch_norm_replacement to torch.func (pytorch#91412)
Pull Request resolved: pytorch#91412
Approved by: https://github.com/zou3519
1 parent 7bdcf6d · commit 515dff7

5 files changed (+6, −2 lines changed)


docs/source/func.api.rst (+1)

@@ -66,6 +66,7 @@ Here's how we would compute the Jacobian over the parameters
 
     functional_call
     stack_module_state
+    replace_all_batch_norm_modules_
 
 If you're looking for information on fixing Batch Norm modules, please follow the
 guidance here

docs/source/func.batch_norm.rst (+1, −1)

@@ -69,7 +69,7 @@ have a net where you want the BatchNorm to not use running stats, you can run
 
 .. code-block:: python
 
-    from functorch.experimental import replace_all_batch_norm_modules_
+    from torch.func import replace_all_batch_norm_modules_
     replace_all_batch_norm_modules_(net)
 
 Option 4: eval mode
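Note: for context, a minimal usage sketch of the relocated helper (the toy network is invented for illustration; the behavior shown matches the function's docstring: running stats are set to None and stat tracking is disabled, in place):

import torch.nn as nn
from torch.func import replace_all_batch_norm_modules_

# Toy network containing a BatchNorm layer (hypothetical example).
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

replace_all_batch_norm_modules_(net)  # updates `net` in place

bn = net[1]
assert bn.track_running_stats is False
assert bn.running_mean is None and bn.running_var is None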

functorch/experimental/__init__.py (+1, −1)

@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
 from torch._functorch.eager_transforms import hessian, jacfwd, jvp
 from torch._functorch.vmap import chunk_vmap
-from .batch_norm_replacement import replace_all_batch_norm_modules_
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from functorch import functionalize
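Note: because functorch.experimental keeps re-exporting the function from its new home, the old import path should still resolve to the same object as the new one; a quick check:

from functorch.experimental import replace_all_batch_norm_modules_ as via_functorch
from torch.func import replace_all_batch_norm_modules_ as via_torch_func

# Both names are imported from torch._functorch.batch_norm_replacement,
# so they should be the very same function object.
assert via_functorch is via_torch_func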

functorch/experimental/batch_norm_replacement.py → torch/_functorch/batch_norm_replacement.py (renamed, +2)

@@ -1,4 +1,5 @@
 import torch.nn as nn
+from torch._functorch.utils import exposed_in
 
 
 def batch_norm_without_running_stats(module: nn.Module):
@@ -9,6 +10,7 @@ def batch_norm_without_running_stats(module: nn.Module):
     module.track_running_stats = False
 
 
+@exposed_in("torch.func")
 def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
     """
     In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and
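Note: exposed_in is an internal torch._functorch.utils helper; a minimal sketch of the pattern it implements (the body below is an assumption, not the verbatim source). The idea is to rewrite the function's reported module so introspection and documentation tooling see the public torch.func namespace rather than the private defining module:

# Assumed sketch of an "exposed_in"-style decorator.
def exposed_in(module: str):
    def decorator(fn):
        # help(), docs, and fn.__module__ now report e.g. "torch.func"
        # instead of the private module "torch._functorch.*".
        fn.__module__ = module
        return fn
    return decorator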

torch/func/__init__.py (+1)

@@ -9,4 +9,5 @@
     functionalize,
 )
 from torch._functorch.functional_call import functional_call, stack_module_state
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.vmap import vmap
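Note: this helper sits alongside vmap and grad because BatchNorm's running-stat updates are in-place side effects that the function transforms cannot handle. A hedged end-to-end sketch (model and shapes are invented) of fixing a net before transforming it:

import torch
import torch.nn as nn
from torch.func import functional_call, grad, replace_all_batch_norm_modules_, vmap

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
replace_all_batch_norm_modules_(net)  # BatchNorm now uses per-batch stats only

params = dict(net.named_parameters())

def loss(params, x):
    return functional_call(net, params, (x,)).sum()

# Gradients for 8 independent mini-batches of 2 samples each, in one vmap call;
# without the replacement above, the in-place stat updates would break vmap.
x = torch.randn(8, 2, 4)
grads = vmap(grad(loss), in_dims=(None, 0))(params, x)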
