Skip to content

Commit

Permalink
Add early validation logic to dynamic_dim (pytorch#102982)
Browse files Browse the repository at this point in the history
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Jun 8, 2023
1 parent f1f13a3 commit cea899c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
21 changes: 21 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,27 @@ def f(x, y):
):
export(f, example_inputs, constraints)

def test_not_correct_dim(self):
def f(x):
return x.cos()

def g(x):
return x + 4

inp_for_f = torch.tensor(5)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Cannot mark 0-dimension tensors to be dynamic"):
constraints = [dynamic_dim(inp_for_f, 0)]

inp_for_f_mul_dim = torch.ones(5, 5)
with self.assertRaisesRegex(
torchdynamo.exc.UserError,
"Expected the dimension passed to dynamic_dim to be in the range \\[0:1\\]"
):
constraints = [dynamic_dim(inp_for_f_mul_dim, 2)]

inp_for_g = 4
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Expected tensor as input to dynamic_dim"):
constraints = [dynamic_dim(inp_for_g, 0)]

if __name__ == '__main__':
run_tests()
1 change: 1 addition & 0 deletions torch/_dynamo/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class UserErrorType(Enum):
ANTI_PATTERN = auto()
STANDARD_LIBRARY = auto()
CONSTRAIN_VIOLATION = auto()
DYNAMIC_DIM = auto()


class UserError(Unsupported):
Expand Down
19 changes: 19 additions & 0 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@
# ]
# )
def dynamic_dim(t: torch.Tensor, index: int):
if not isinstance(t, torch.Tensor):
raise UserError(
UserErrorType.DYNAMIC_DIM,
f"Expected tensor as input to dynamic_dim but got {type(t)}"
)

if t.dim() < 1:
raise UserError(
UserErrorType.DYNAMIC_DIM,
"Cannot mark 0-dimension tensors to be dynamic"
)

if index >= t.dim():
raise UserError(
UserErrorType.DYNAMIC_DIM,
f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
f" but got {index}, which is out of bounds for the given tensor."
)

return Constraint(
weakref.ref(t),
id(t),
Expand Down

0 comments on commit cea899c

Please sign in to comment.