diff --git a/functorch/csrc/VmapGeneratedPlumbing.h b/functorch/csrc/VmapGeneratedPlumbing.h index 9f2a66f5c..2623f6552 100644 --- a/functorch/csrc/VmapGeneratedPlumbing.h +++ b/functorch/csrc/VmapGeneratedPlumbing.h @@ -6042,13 +6042,13 @@ at::Tensor grid_sampler_3d_generated_plumbing(const at::Tensor & input, const at return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); } template -::std::tuple grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { +::std::tuple grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, std::array output_mask) { c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); auto maybe_layer = maybeCurrentDynamicLayer(); TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { - return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners); + return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); } Tensor grad_output_value; optional grad_output_bdim; @@ -6059,7 +6059,7 @@ ::std::tuple grid_sampler_3d_backward_generated_plumbing( Tensor grid_value; optional grid_bdim; std::tie(grid_value, grid_bdim) = unwrapTensorAtLevel(grid, cur_level); - auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask); return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); } template