Skip to content

Commit

Permalink
Allow strided layout in torch.normal (pytorch#111205)
Browse files Browse the repository at this point in the history
Fixes pytorch#111119

Pull Request resolved: pytorch#111205
Approved by: https://github.com/ezyang
  • Loading branch information
lezcano authored and pytorchmergebot committed Oct 13, 2023
1 parent b1db959 commit 2fd546a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5927,7 +5927,7 @@ def log_normal(self, mean=1, std=2, generator=None):
def normal(
mean=0,
std=1,
shape=None,
size=None,
*,
generator=None,
dtype=None,
Expand All @@ -5936,37 +5936,37 @@ def normal(
pin_memory=None,
):
assert generator is None
assert layout is None
assert layout is None or layout == torch.strided

if not isinstance(std, TensorLike):
torch._check(
std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}"
)

if shape is None:
if size is None:
tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike))
torch._check(
len(tensors) > 0,
lambda: "normal expects that either mean or std is a tensor, or shape is defined",
lambda: "normal expects that either mean or std is a tensor, or size is defined",
)
torch._check(
layout is None and pin_memory is None,
lambda: "Cannot pass layout, or pin_memory without shape",
lambda: "Cannot pass layout, or pin_memory without size",
)

shape = _broadcast_shapes(*(t.shape for t in tensors))
size = _broadcast_shapes(*(t.shape for t in tensors))
dtype = tensors[0].dtype
device = tensors[0].device
else:
torch._check(
not isinstance(mean, TensorLike) and not isinstance(std, TensorLike),
lambda: "normal expects mean and std to be scalars when shape is defined",
lambda: "normal expects mean and std to be scalars when size is defined",
)
dtype = torch.get_default_dtype() if dtype is None else dtype
device = torch.device("cpu") if device is None else device

normal_samples = prims.normal(
shape,
size,
mean=0.0,
std=1.0,
dtype=dtype,
Expand Down
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,7 @@ def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwar

def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs):
yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, device=device)
yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device)
yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad))

def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs):
Expand Down

0 comments on commit 2fd546a

Please sign in to comment.