Skip to content

Commit

Permalink
[Operator][New] Add einsum (PaddlePaddle#1314)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zheng-Bicheng authored Jul 11, 2024
1 parent 36a79af commit 880c4a6
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 1 deletion.
47 changes: 47 additions & 0 deletions paddle2onnx/mapper/tensor/einsum.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/einsum.h"

namespace paddle2onnx {
REGISTER_MAPPER(einsum, EinsumMapper)

int32_t EinsumMapper::GetMinOpset(bool verbose)
{
constexpr int op_version = 12;
Logger(verbose, op_version) << RequireOpset(op_version) << std::endl;
return op_version;
}

void EinsumMapper::Opset12() {
auto input_info = GetInput("Operands");
auto output_info = GetOutput("Out");
GetAttr("equation", &equation_);

std::vector<std::string> input_info_names;
for (size_t i = 0; i < input_info.size(); i++)
{
input_info_names.emplace_back(input_info[i].name);
}

std::vector<std::string> output_info_names;
for (size_t i = 0; i < output_info.size(); i++)
{
output_info_names.emplace_back(output_info[i].name);
}
auto node = helper_->MakeNode("Einsum", input_info_names, output_info_names);
AddAttribute(node, "equation", equation_);
}

} // namespace paddle2onnx
35 changes: 35 additions & 0 deletions paddle2onnx/mapper/tensor/einsum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class EinsumMapper : public Mapper {
public:
EinsumMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
int32_t GetMinOpset(bool verbose) override;
void Opset12() override;

private:
std::string equation_;
};

} // namespace paddle2onnx
2 changes: 1 addition & 1 deletion paddle2onnx/mapper/tensor/empty.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class EmptyMapper : public Mapper {
GetAttr("dtype", &output_dtype_);
}

int32_t GetMinOpset(bool verbose = false) override;
int32_t GetMinOpset(bool verbose) override;
void Opset11() override;
private:
int64_t output_dtype_;
Expand Down
233 changes: 233 additions & 0 deletions tests/test_einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from onnxbase import APIOnnx
from onnxbase import randtool

def test_einsum_sum():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, input):
"""
forward
"""
x = paddle.einsum('i->', input)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_sum', [12])
obj.set_input_data("input_data", paddle.rand([4]))
obj.run()


def test_einsum_dot():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, x, y):
"""
forward
"""
x = paddle.einsum("i,i->", x, y)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_dot', [12])
input_x = paddle.rand([4])
obj.set_input_data("input_data", input_x, input_x)
obj.run()


def test_einsum_outer():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, x, y):
"""
forward
"""
x = paddle.einsum("i,j->ij", x, y)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_dot', [12])
input_x = paddle.rand([4])
input_y = paddle.rand([5])
obj.set_input_data("input_data", input_x, input_y)
obj.run()


def test_einsum_transpose():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, x):
"""
forward
"""
x = paddle.einsum("ijk->kji", x)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_dot', [12])
input_x = paddle.rand([2, 3, 2])
obj.set_input_data("input_data", input_x)
obj.run()


def test_einsum_batch_matrix_multiplication():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, x, y):
"""
forward
"""
x = paddle.einsum("ijk, ikl->ijl", x, y)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_dot', [12])
input_x = paddle.rand([2, 3, 2])
input_y = paddle.rand([2, 2, 3])
obj.set_input_data("input_data", input_x, input_y)
obj.run()


def test_einsum_ellipsis_transpose():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, x):
"""
forward
"""
x = paddle.einsum("...jk->...kj", x)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_dot', [12])
input_x = paddle.rand([2, 3, 2])
obj.set_input_data("input_data", input_x)
obj.run()


def test_einsum_ellipsis_batch_matrix_multiplication():
"""
api: paddle.einsum
op version: 12
"""

class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, x, y):
"""
forward
"""
x = paddle.einsum("...jk, ...kl->...jl", x, y)
return x

op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'einsum_dot', [12])
input_x = paddle.rand([2, 3, 2])
input_y = paddle.rand([2, 2, 3])
obj.set_input_data("input_data", input_x, input_y)
obj.run()


if __name__ == "__main__":
test_einsum_sum()
test_einsum_dot()

0 comments on commit 880c4a6

Please sign in to comment.