Skip to content

Commit

Permalink
Fix greater than op (PaddlePaddle#1329)
Browse files Browse the repository at this point in the history
* add repeat_interleave op

* change code format

* fix greater than op
  • Loading branch information
jiuyuedeyu156 authored Jul 15, 2024
1 parent 40fcf3f commit f30eb9f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
24 changes: 21 additions & 3 deletions paddle2onnx/mapper/tensor/greater_than.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,32 @@

namespace paddle2onnx {
REGISTER_MAPPER(greater_than, GreaterThanMapper)
int32_t GreaterThanMapper::GetMinOpsetVersion(bool verbose) {
// NHWC is not supported
auto x_info = GetInput("X");
auto y_info = GetInput("Y");

if (x_info[0].dtype == P2ODataType::BOOL || y_info[0].dtype == P2ODataType::BOOL) {
Logger(verbose, 9) << "While the type of input is (bool), " << RequireOpset(9) << std::endl;
return 9;
}
return 7;
}
void GreaterThanMapper::Opset7() {
auto x_info = GetInput("X");
auto y_info = GetInput("Y");
auto out_info = GetOutput("Out");


int out_dtype = 0;
std::vector<std::string> aligned_inputs =
helper_->DtypeAlignment({x_info[0], y_info[0]}, &out_dtype);
std::vector<std::string> aligned_inputs = helper_->DtypeAlignment({x_info[0], y_info[0]}, &out_dtype);

if (out_dtype == P2ODataType::BOOL){
std::string new_x_name = helper_->AutoCast(x_info[0].name, x_info[0].dtype, P2ODataType::INT32);
std::string new_y_name = helper_->AutoCast(y_info[0].name, y_info[0].dtype, P2ODataType::INT32);
helper_->MakeNode("Greater", {new_x_name, new_y_name}, {out_info[0].name});
return ;
}
if (out_dtype != P2ODataType::FP32 && out_dtype != P2ODataType::FP64 &&
helper_->GetOpsetVersion() < 11) {
aligned_inputs[0] =
Expand All @@ -33,7 +50,8 @@ void GreaterThanMapper::Opset7() {
helper_->AutoCast(aligned_inputs[1], out_dtype, P2ODataType::FP32);
}


helper_->MakeNode("Greater", aligned_inputs, {out_info[0].name});
}

} // namespace paddle2onnx
} // namespace paddle2onnx
1 change: 1 addition & 0 deletions paddle2onnx/mapper/tensor/greater_than.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class GreaterThanMapper : public Mapper {
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
void Opset7() override;
int32_t GetMinOpsetVersion(bool verbose);
};

} // namespace paddle2onnx
37 changes: 20 additions & 17 deletions tests/test_greater_than.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,68 +30,71 @@ def forward(self, inputs, inputs_):
forward
"""
x = paddle.greater_than(inputs, inputs_)
print(x)
return x


def test_greater_than_9():
def test_greater_than_7():
"""
api: paddle.greater_than
op version: 9
op version: 7
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'greater_than', [9])
obj = APIOnnx(op, 'greater_than', [7])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32')),
paddle.to_tensor(randtool("float", 0, 1, [3, 10]).astype('float32')))
obj.run()


def test_greater_than_10():


def test_greater_than_9():
"""
api: paddle.greater_than
op version: 9
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'greater_than', [10])
obj = APIOnnx(op, 'greater_than', [9])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32')),
paddle.to_tensor(randtool("float", 0, 1, [3, 10]).astype('float32')))
obj.run()


def test_greater_than_11():
def test_greater_than_9_bool():
"""
api: paddle.greater_than
op version: 11
op version: 9
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'greater_than', [11])
obj = APIOnnx(op, 'greater_than', [9])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32')),
paddle.to_tensor(randtool("float", 0, 1, [3, 10]).astype('float32')))
paddle.to_tensor([True, False, True, False, True], dtype=paddle.bool),
paddle.to_tensor([False], dtype=paddle.bool))
obj.run()
# x[0] = 1


def test_greater_than_12():
def test_greater_than_9_bool_matrix():
"""
api: paddle.greater_than
op version: 12
op version: 9
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'greater_than', [12])
obj = APIOnnx(op, 'greater_than', [9])
obj.set_input_data(
"input_data",
paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32')),
paddle.to_tensor(randtool("float", 0, 1, [3, 10]).astype('float32')))
paddle.to_tensor([[True, False, True, False, True],[False, False, True, True, True]], dtype=paddle.bool),
paddle.to_tensor([[False],[True]], dtype=paddle.bool))
obj.run()
# x[0] = 1

0 comments on commit f30eb9f

Please sign in to comment.