[tests] Add tests for ONNX multidirectional broadcasting in arithmetic operators (op7)
tlepley-cadence authored and bertmaher committed Dec 4, 2018
1 parent 712913b commit 9702b7d
Showing 5 changed files with 290 additions and 0 deletions.
51 changes: 51 additions & 0 deletions tests/models/onnxModels/addMultiBroadcastOp7.onnxtxt
@@ -0,0 +1,51 @@
ir_version: 3
producer_name: "onnx-arith-broadcast"
opset_import {
  version: 7
}

graph {
  node {
    input: "data"
    input: "const"
    output: "out"
    name: "op"
    op_type: "Add"
  }
  name: "test-model"
  initializer {
    dims: 4
    dims: 1
    data_type: FLOAT
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    name: "const"
  }
  input {
    name: "data"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "out"
  }
}
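
These models exercise ONNX opset-7 multidirectional (numpy-style) broadcasting: the operand shapes are right-aligned, the shorter shape is padded with leading 1s, and every size-1 dimension is stretched to match the other operand. For the files in this commit, the "data" input of shape [1, 3, 1, 2] combined with the [4, 1] "const" initializer therefore yields a [1, 3, 4, 2] result. Below is a minimal standalone sketch of that shape rule; broadcastShape is an illustrative helper written for this note, not part of the commit.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Compute the multidirectional broadcast shape of two operands per the
// ONNX opset-7 rule: right-align the shapes, pad the shorter one with
// leading 1s, and take the maximum along each axis (a size of 1
// broadcasts against anything).
static std::vector<size_t> broadcastShape(std::vector<size_t> a,
                                          std::vector<size_t> b) {
  if (a.size() < b.size())
    std::swap(a, b);
  // Pad the shorter shape with leading 1s so both have the same rank.
  b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<size_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    assert((a[i] == b[i] || a[i] == 1 || b[i] == 1) && "incompatible dims");
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

// For the test models in this commit: data is [1, 3, 1, 2] and "const" is
// [4, 1], so broadcastShape({1, 3, 1, 2}, {4, 1}) returns {1, 3, 4, 2}.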
52 changes: 52 additions & 0 deletions tests/models/onnxModels/divMultiBroadcastOp7.onnxtxt
@@ -0,0 +1,52 @@
ir_version: 3
producer_name: "onnx-arith-broadcast"
opset_import {
  version: 7
}

graph {
  node {
    input: "data"
    input: "const"
    output: "out"
    name: "op"
    op_type: "Div"
  }
  name: "test-model"
  initializer {
    dims: 4
    dims: 1
    data_type: FLOAT
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    name: "const"
  }
  input {
    name: "data"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "out"
  }
}

51 changes: 51 additions & 0 deletions tests/models/onnxModels/mulMultiBroadcastOp7.onnxtxt
@@ -0,0 +1,51 @@
ir_version: 3
producer_name: "onnx-arith-broadcast"
opset_import {
  version: 7
}

graph {
  node {
    input: "data"
    input: "const"
    output: "out"
    name: "op"
    op_type: "Mul"
  }
  name: "test-model"
  initializer {
    dims: 4
    dims: 1
    data_type: FLOAT
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    name: "const"
  }
  input {
    name: "data"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "out"
  }
}
51 changes: 51 additions & 0 deletions tests/models/onnxModels/subMultiBroadcastOp7.onnxtxt
@@ -0,0 +1,51 @@
ir_version: 3
producer_name: "onnx-arith-broadcast"
opset_import {
  version: 7
}

graph {
  node {
    input: "data"
    input: "const"
    output: "out"
    name: "op"
    op_type: "Sub"
  }
  name: "test-model"
  initializer {
    dims: 4
    dims: 1
    data_type: FLOAT
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    float_data: 2.0
    name: "const"
  }
  input {
    name: "data"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "out"
  }
}
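
All four models share the same shapes; only the op_type differs. The Glow importer makes the implicit broadcast explicit: each broadcast operand is reshaped to the output rank and then tiled once per stretched dimension ("data" needs one Tile along the H axis, while "const" needs a Reshape to [1, 1, 4, 1] plus two Tiles along the C and W axes), which is the Reshape/Tile structure the unit test below asserts. Assuming the getNCHWData helper fills the data tensor with the running indices 0..5 (an assumption about that helper, not shown in this diff), the expected output table in the test can be reproduced with a plain nested loop; the following is a self-contained sketch written for this note, not code from the commit.

#include <cstdio>
#include <functional>
#include <vector>

int main() {
  // data has shape [1, 3, 1, 2]; assume it holds 0, 1, 2, 3, 4, 5 in
  // row-major (NCHW) order, as the getNCHWData test helper is assumed to do.
  // The "const" initializer has shape [4, 1] and holds 2.0 everywhere.
  std::function<float(float, float)> op = [](float a, float b) { return a + b; };

  // The broadcast result shape is [1, 3, 4, 2]; the H axis of data (size 1)
  // and the N/C/W axes of const (implicitly size 1) are stretched.
  std::vector<float> expected;
  for (int c = 0; c < 3; ++c)
    for (int h = 0; h < 4; ++h)
      for (int w = 0; w < 2; ++w)
        expected.push_back(op(/*data[0][c][0][w]=*/c * 2 + w, /*const=*/2.0f));

  // Under the assumption above, this prints the same 24 values as the
  // expectedValues table in the test below (for the Add case): 2 3 2 3 ... 7.
  for (float v : expected)
    printf("%g ", v);
  printf("\n");
  return 0;
}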
85 changes: 85 additions & 0 deletions tests/unittests/onnxImporterTest.cpp
@@ -22,6 +22,91 @@

using namespace glow;

template <class OpType>
static void
importArithMultiBroadcastTest(std::string fileName,
                              const std::function<float(float, float)> &op) {
  ExecutionEngine EE{BackendKind::Interpreter};
  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");

  std::string NetFilename = std::string("tests/models/onnxModels/") + fileName;
  Context ctx;
  Placeholder *graphOutputVar;
  // Destroy the loader after the graph is loaded since the following execution
  // will not depend on anything from the loader.
  {
    Tensor data;
    getNCHWData(&data, 1, 3, 1, 2);
    ONNXModelLoader onnxLD(NetFilename, {"data"}, {&data.getType()}, *F);
    graphOutputVar = EXIT_ON_ERR(onnxLD.getSingleOutput());
    ctx.allocate(mod.getPlaceholders());
    updateInputPlaceholdersByName(ctx, &mod, {"data"}, {&data});
  }

  // The ONNX importer loads an arithmetic node and inserts:
  // - a Reshape node for each broadcasted operand
  // - a Tile node for each broadcasted dimension
  // Check the graph structure.
  auto *saveNode = getSaveNodeFromDest(graphOutputVar);
  auto *node = saveNode->getInput().getNode();
  auto *opNode = llvm::dyn_cast<OpType>(node);
  EXPECT_NE(nullptr, opNode);

  // Left operand (1 dimension to broadcast)
  auto *lhsTileNode = llvm::dyn_cast<TileNode>(opNode->getLHS().getNode());
  EXPECT_NE(nullptr, lhsTileNode);
  auto *lhsReshape =
      llvm::dyn_cast<ReshapeNode>(lhsTileNode->getInput().getNode());
  EXPECT_NE(nullptr, lhsReshape);

  // Right operand (2 dimensions to broadcast)
  auto *rhsNode = opNode->getRHS().getNode();
  EXPECT_NE(nullptr, rhsNode);
  auto *rhsTileNode = llvm::dyn_cast<TileNode>(opNode->getRHS().getNode());
  EXPECT_NE(nullptr, rhsTileNode);
  auto *rhsTileNode2 =
      llvm::dyn_cast<TileNode>(rhsTileNode->getInput().getNode());
  EXPECT_NE(nullptr, rhsTileNode2);
  auto *rhsReshape =
      llvm::dyn_cast<ReshapeNode>(rhsTileNode2->getInput().getNode());
  EXPECT_NE(nullptr, rhsReshape);

  // Compile & run the graph, and check the output.
  EE.compile(CompilationMode::Infer, F);
  EE.run(ctx);
  auto result = ctx.get(graphOutputVar)->getHandle();
  std::vector<size_t> expectedDims = {1, 3, 4, 2};
  std::vector<float> expectedValues = {
      op(0, 2), op(1, 2), op(0, 2), op(1, 2), op(0, 2), op(1, 2),
      op(0, 2), op(1, 2), op(2, 2), op(3, 2), op(2, 2), op(3, 2),
      op(2, 2), op(3, 2), op(2, 2), op(3, 2), op(4, 2), op(5, 2),
      op(4, 2), op(5, 2), op(4, 2), op(5, 2), op(4, 2), op(5, 2)};
  EXPECT_TRUE(result.dims().vec() == expectedDims);
  for (size_t i = 0; i < result.getType().size(); i++)
    EXPECT_FLOAT_EQ(result.raw(i), expectedValues[i]);
}

TEST(onnx, importAddMultiBroadcastOp7) {
  importArithMultiBroadcastTest<AddNode>(
      "addMultiBroadcastOp7.onnxtxt", [](float a, float b) { return a + b; });
}

TEST(onnx, importSubMultiBroadcastOp7) {
  importArithMultiBroadcastTest<SubNode>(
      "subMultiBroadcastOp7.onnxtxt", [](float a, float b) { return a - b; });
}

TEST(onnx, importMulMultiBroadcastOp7) {
  importArithMultiBroadcastTest<MulNode>(
      "mulMultiBroadcastOp7.onnxtxt", [](float a, float b) { return a * b; });
}

TEST(onnx, importDivMultiBroadcastOp7) {
  importArithMultiBroadcastTest<DivNode>(
      "divMultiBroadcastOp7.onnxtxt", [](float a, float b) { return a / b; });
}

/// Test loading conv op from an ONNX model.
/// The input is N*C*H*W (1*1*3*3), the kernels is {2, 2},
/// strides is {1, 1}, pads is {1, 1, 1, 1}, group is 1.
