[jit] Fix _batch_norm_with_update shape function (pytorch#122430)
Summary: We used `native_batch_norm`'s shape function before,
but the schemas are actually different. We need to create new
shape functions for `_batch_norm_with_update` specifically.
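
For context, the two operator schemas (quoted verbatim in the mapping table below) differ in both arguments and output arity:

aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)

`_batch_norm_with_update` drops the `training` flag (it always updates the running stats, hence the `(a!)`/`(b!)` mutation annotations) and returns a fourth tensor, so reusing `native_batch_norm`'s three-output shape graph produced the wrong number of output shapes.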

Test Plan:
buck2 test '@fbcode//mode/opt-tsan' fbcode//caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - TestShapeGraphLinting.Basic'

Reviewers: bdhirsh, davidberard98, eellison

Differential Revision: [D55211182](https://our.internmc.facebook.com/intern/diff/D55211182)
Pull Request resolved: pytorch#122430
Approved by: https://github.com/eellison, https://github.com/bdhirsh
andrewor14 authored and pytorchmergebot committed Mar 22, 2024
1 parent 23a6d74 commit 6e6891e
Showing 2 changed files with 25 additions and 2 deletions.
14 changes: 13 additions & 1 deletion torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
@@ -3014,6 +3014,18 @@ def native_batch_norm(input: List[int],
     _1 = torch.append(out, elem)
   return (out, _size, _size)
+def _batch_norm_with_update(input: List[int],
+    weight: Optional[List[int]],
+    bias: Optional[List[int]],
+    running_mean: Optional[List[int]],
+    running_var: Optional[List[int]]) -> Tuple[List[int], List[int], List[int], List[int]]:
+  _size = [input[1]]
+  out = annotate(List[int], [])
+  for _0 in range(torch.len(input)):
+    elem = input[_0]
+    _1 = torch.append(out, elem)
+  return (out, _size, _size, [0])
 )=====")
 + std::string(R"=====(def cross_entropy_loss(self: List[int],
     target: List[int],
@@ -3312,7 +3324,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
       {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
       {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
       {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
-      {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "native_batch_norm"},
+      {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "_batch_norm_with_update"},
       {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
       {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
       {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
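
Conceptually, `GetShapeFunctionMappings()` is a table from full operator schema strings to the names of serialized TorchScript shape functions, and the fix above repoints a single entry at the newly added function. A minimal illustrative sketch in Python (schemas trimmed; not the real C++ API):

# Illustrative only: schema string -> serialized shape-function name.
SHAPE_FUNCTION_MAPPINGS = {
    "aten::native_batch_norm(..., bool training, ...) -> (Tensor, Tensor, Tensor)": "native_batch_norm",
    # Before this fix, the entry below also mapped to "native_batch_norm",
    # whose shape graph returns only three outputs.
    "aten::_batch_norm_with_update(...) -> (Tensor, Tensor, Tensor, Tensor)": "_batch_norm_with_update",
}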
13 changes: 12 additions & 1 deletion torch/jit/_shape_functions.py
@@ -1145,6 +1145,17 @@ def native_batch_norm(
     return _copy(input), _size, _size


+def _batch_norm_with_update(
+    input: List[int],
+    weight: Optional[List[int]],
+    bias: Optional[List[int]],
+    running_mean: Optional[List[int]],
+    running_var: Optional[List[int]],
+) -> Tuple[List[int], List[int], List[int], List[int]]:
+    _size = [input[1]]
+    return _copy(input), _size, _size, [0]


 def cross_entropy_loss(
     self: List[int],
     target: List[int],
@@ -1432,7 +1443,7 @@ def add_bounded_compute_mapping(
 )
 add_shape_compute_mapping(
     "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
-    native_batch_norm,
+    _batch_norm_with_update,
 )
 
 add_shape_compute_mapping(
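
To sanity-check the Python-side registration, one could call the shape function exported by `torch.jit._shape_functions` directly (assuming a PyTorch build that includes this change) and inspect the four inferred shapes:

# Assumes a PyTorch build containing this fix; the shapes follow from the
# function body above: out mirrors the input, stats are [num_features].
from torch.jit._shape_functions import _batch_norm_with_update

result = _batch_norm_with_update([8, 16, 28, 28], [16], [16], [16], [16])
print(result)  # expected: ([8, 16, 28, 28], [16], [16], [0])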
