Skip to content

Commit

Permalink
[SR] Fix aten::split schema (pytorch#68135)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#68135

Update the schema to reflect the changes in  D31935573 (pytorch@6b44e75).

Test Plan:
`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Confirmed native implementation is used.

Reviewed By: hlu1

Differential Revision: D32326865

fbshipit-source-id: 7f607f57ceb6690a2782d94d9ee736ba64e7d242
  • Loading branch information
Mike Iovine authored and facebook-github-bot committed Nov 11, 2021
1 parent 47bc47f commit 1f07efd
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions torch/csrc/jit/runtime/static/native_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,23 +541,20 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
};
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::split,
aten_split,
[](Node* n) -> SROperator {
if (!n->matches(torch::schema(
"aten::split(Tensor self, int split_size, int dim=0) -> Tensor[]"))) {
LogAndDumpSchema(n);
return nullptr;
}
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
"aten::split(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]"))) {
LogAndDumpSchema(n);
return nullptr;
}

return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
const auto split_size = p_node->Input(1).toInt();
const auto dim = p_node->Input(2).toInt();
p_node->Output(0) = at::native::split(self, split_size, dim);
};
});
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
const auto split_size = p_node->Input(1).toInt();
const auto dim = p_node->Input(2).toInt();
p_node->Output(0) = at::native::split(self, split_size, dim);
};
});

} // namespace jit
} // namespace torch

0 comments on commit 1f07efd

Please sign in to comment.