Skip to content

Commit

Permalink
Updated plumbing (manually)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chillee committed Mar 1, 2022
1 parent 736cccd commit 8acf2d1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions functorch/csrc/VmapGeneratedPlumbing.h
Original file line number Diff line number Diff line change
Expand Up @@ -6042,13 +6042,13 @@ at::Tensor grid_sampler_3d_generated_plumbing(const at::Tensor & input, const at
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}
template <typename batch_rule_t, batch_rule_t batch_rule>
::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, std::array<bool, 2> output_mask) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
auto maybe_layer = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
int64_t cur_level = maybe_layer->layerId();
if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) {
return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners);
return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask);
}
Tensor grad_output_value;
optional<int64_t> grad_output_bdim;
Expand All @@ -6059,7 +6059,7 @@ ::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(
Tensor grid_value;
optional<int64_t> grid_bdim;
std::tie(grid_value, grid_bdim) = unwrapTensorAtLevel(grid, cur_level);
auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners);
auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask);
return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level));
}
template <typename batch_rule_t, batch_rule_t batch_rule>
Expand Down

0 comments on commit 8acf2d1

Please sign in to comment.