Skip to content

Commit

Permalink
Add back lrn test (#8134)
Browse files Browse the repository at this point in the history
* Revert "Skip OnnxBackendNodeModelTest::test_lrn_default_cuda that causes segfault (#8127)"

This reverts commit 410191c.

* Fix mismatched default values
  • Loading branch information
bddppq authored Jun 4, 2018
1 parent 94e197c commit ec4a0f3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
19 changes: 18 additions & 1 deletion caffe2/onnx/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ Caffe2Backend::get_special_operators() const {
{"Reciprocal", &Caffe2Backend::CreateReciprocal},
{"BatchNormalization", &Caffe2Backend::CreateBatchNormalization},
{"MatMul", &Caffe2Backend::CreateMatMul},
{"Upsample", &Caffe2Backend::CreateUpsample}};
{"Upsample", &Caffe2Backend::CreateUpsample},
{"LRN", &Caffe2Backend::CreateLRN}};
return kSpecialOperators;
}

Expand Down Expand Up @@ -905,6 +906,22 @@ Caffe2Ops Caffe2Backend::CreateUpsample(OnnxNode* onnx_node, int opset_version)
return c2_op;
}

Caffe2Ops Caffe2Backend::CreateLRN(OnnxNode* onnx_node, int opset_version) {
auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
const auto& attributes = onnx_node->attributes;
if (!attributes.HasAttribute("alpha")) {
auto* arg = c2_op.ops.Mutable(0)->add_arg();
arg->set_name("alpha");
arg->set_f(1e-4);
}
if (!attributes.HasAttribute("beta")) {
auto* arg = c2_op.ops.Mutable(0)->add_arg();
arg->set_name("beta");
arg->set_f(0.75);
}
return c2_op;
}

//==============================================
// Rest of the member functions for Caffe2Backend
//==============================================
Expand Down
2 changes: 2 additions & 0 deletions caffe2/onnx/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ class Caffe2Backend {

Caffe2Ops CreateUpsample(OnnxNode* onnx_node, int opset_version);

Caffe2Ops CreateLRN(OnnxNode* onnx_node, int opset_version);


// LUT related getters
const std::unordered_map<std::string, std::string>& get_renamed_operators()
Expand Down
3 changes: 1 addition & 2 deletions caffe2/python/onnx/tests/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@
'|test_rnn_seq_length'
'|test_operator_add.*_cuda'
'|test_operator_lstm_cuda'
'|test_operator_rnn.*_cuda'
'|test_lrn_default_cuda)')
'|test_operator_rnn.*_cuda)')

# Temporarily skip some ONNX backend tests with broadcasting.
backend_test.exclude('(test_xor_bcast'
Expand Down

0 comments on commit ec4a0f3

Please sign in to comment.