Skip to content

Commit a70082a

Browse files
zhxchen17pytorchmergebot
authored andcommitted
[functorch] Move cond.py to _cond.py and expose cond() under functorch.experimental.control_flow. (pytorch#89819)
Summary: Similar to pytorch#88767 we want to reduce the chance that users accidentally import private functions from `functorch.experimental.cond` as if they were public interfaces. We also move `cond()` under `control_flow.py` to stay consistent with `map()` op. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#89819 Approved by: https://github.com/zou3519
1 parent d1760d7 commit a70082a

File tree

5 files changed

+7
-8
lines changed

5 files changed

+7
-8
lines changed

functorch/experimental/cond.py functorch/experimental/_cond.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# TODO(zhxchen17) Expose API through functorhc.experimental.control_flow
2-
# and rename this file to _cond.py.
31
import torch
42

53
import torch.utils._pytree as pytree
+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from ._map import map # noqa: F401
2+
from ._cond import cond # noqa: F401

test/dynamo/test_export.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,7 @@ def nop(x):
14341434

14351435
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
14361436
def test_export_with_module_layer(self):
1437-
from functorch.experimental.cond import cond
1437+
from functorch.experimental.control_flow import cond
14381438

14391439
def true_fn(layer, val):
14401440
return layer(val) * torch.tensor(2)

test/dynamo/test_misc.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2334,7 +2334,7 @@ def f_onnx(x):
23342334
self.assertEqual(f_onnx(input_two_dims), 8)
23352335

23362336
def test_cond(self):
2337-
from functorch.experimental.cond import cond
2337+
from functorch.experimental.control_flow import cond
23382338

23392339
def true_fn(x):
23402340
return x.sin()
@@ -2352,7 +2352,7 @@ def f(pred, x):
23522352
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
23532353

23542354
def test_cond_nested(self):
2355-
from functorch.experimental.cond import cond
2355+
from functorch.experimental.control_flow import cond
23562356

23572357
def true_fn_nested(x):
23582358
return x * 10
@@ -2397,7 +2397,7 @@ def f(pred, pred2, x):
23972397
self.assertTrue(cc.frame_count, 2)
23982398

23992399
def test_cond_export(self):
2400-
from functorch.experimental.cond import cond
2400+
from functorch.experimental.control_flow import cond
24012401

24022402
def true_fn_nested(x):
24032403
return x * 10
@@ -2442,7 +2442,7 @@ def f(pred, pred2, x):
24422442
) # * -1 then add x
24432443

24442444
def test_cond_export_single_arg(self):
2445-
from functorch.experimental.cond import cond
2445+
from functorch.experimental.control_flow import cond
24462446

24472447
def true_fn(x):
24482448
return x

test/functorch/test_control_flow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Owner(s): ["module: functorch"]
22
import torch
3-
from functorch.experimental.cond import cond
43
from functorch.experimental import control_flow
4+
from functorch.experimental.control_flow import cond
55
from torch.fx.experimental.proxy_tensor import make_fx
66

77
from torch.testing._internal.common_utils import run_tests, TestCase

0 commit comments

Comments
 (0)