Skip to content

Commit

Permalink
cpu: gtests: add test for int16->int32 fwd+relu convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
nshustrov committed Jun 22, 2017
1 parent 999b051 commit 77b8ed1
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 39 deletions.
3 changes: 2 additions & 1 deletion tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ file(GLOB PRIM_TEST_CASES_SRC
test_convolution_format_any.cpp
test_convolution_forward_f32.cpp
test_convolution_forward_s16s16s32.cpp
test_convolution_relu_forward.cpp
test_convolution_backward_data.cpp
test_convolution_relu_forward_f32.cpp
test_convolution_relu_forward_s16s16s32.cpp
test_convolution_backward_weights.cpp
) #temporary

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@

namespace mkldnn {

template <typename data_t>
template <typename data_t_src, typename data_t_wei,
typename data_t_acc, typename data_t_dst>
void compute_ref_conv_relu_fwd(const test_convolution_sizes_t &c,
const memory &src, const memory &weights, const memory &bias,
const memory &dst, bool w_bias)
{
data_t *src_data = (data_t *)src.get_data_handle();
data_t *weights_data = (data_t *)weights.get_data_handle();
data_t *bias_data
= (data_t *)(w_bias ? bias.get_data_handle() : nullptr);
data_t *dst_data = (data_t *)dst.get_data_handle();
data_t_src *src_data = (data_t_src *)src.get_data_handle();
data_t_wei *weights_data = (data_t_wei *)weights.get_data_handle();
data_t_dst *bias_data
= (data_t_dst *)(w_bias ? bias.get_data_handle() : nullptr);
data_t_dst *dst_data = (data_t_dst *)dst.get_data_handle();

const memory::desc src_d = src.get_primitive_desc().desc();
const memory::desc weights_d = weights.get_primitive_desc().desc();
Expand Down Expand Up @@ -89,7 +90,8 @@ void compute_ref_conv_relu_fwd(const test_convolution_sizes_t &c,
}
}

template <typename data_t>
template <typename data_t_src, typename data_t_wei,
typename data_t_acc, typename data_t_dst>
class convolution_relu_test
: public ::testing::TestWithParam<test_convolution_params_t> {
protected:
Expand All @@ -102,51 +104,53 @@ class convolution_relu_test
ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
ASSERT_EQ(p.aalgorithm, convolution_direct);
auto eng = engine(p.engine_kind, 0);
memory::data_type data_type = data_traits<data_t>::data_type;
ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);

memory::data_type data_type_src = data_traits<data_t_src>::data_type;
memory::data_type data_type_dst = data_traits<data_t_dst>::data_type;
memory::data_type data_type_wei = data_traits<data_t_wei>::data_type;

test_convolution_sizes_t cd = p.sizes;

auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw },
data_type, p.formats.src_format);
data_type_src, p.formats.src_format);
auto c_weights_desc = cd.ng > 1 ?
create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw },
data_type, p.formats.weights_format) :
data_type_wei, p.formats.weights_format) :
create_md({ cd.oc, cd.ic, cd.kh, cd.kw },
data_type, p.formats.weights_format);
data_type_wei, p.formats.weights_format);
auto c_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow },
data_type, p.formats.dst_format);
data_type_dst, p.formats.dst_format);

auto c_src = memory({c_src_desc, eng});
auto c_weights = memory({c_weights_desc, eng});
auto c_dst = memory({c_dst_desc, eng});

auto dst_ref = memory({c_dst_desc, eng});

fill_data<data_t>(c_src.get_primitive_desc().get_size()
/ sizeof(data_t), (data_t *)c_src.get_data_handle());
fill_data<data_t_src>(c_src.get_primitive_desc().get_size()
/ sizeof(data_t_src), (data_t_src *)c_src.get_data_handle());
// TODO: Temporary workaround for testing of convolution + relu
data_t *src_data = (data_t *)c_src.get_data_handle();
data_t_src *src_data = (data_t_src *)c_src.get_data_handle();
const int mb_chunk =
(c_src.get_primitive_desc().get_size() / sizeof(data_t))
(c_src.get_primitive_desc().get_size() / sizeof(data_t_src))
/ cd.mb;
for (int i = 0; i < cd.mb * mb_chunk; ++i) {
if ((i / mb_chunk) % 2) src_data[i] *= -1.;
}

fill_data<data_t>(
fill_data<data_t_wei>(
c_weights.get_primitive_desc().get_size()
/ sizeof(data_t), (data_t *)c_weights.get_data_handle());
/ sizeof(data_t_wei),(data_t_wei *)c_weights.get_data_handle());

bool with_bias = p.formats.bias_format != memory::format::format_undef;
auto c_bias_desc = with_bias ?
create_md({ cd.oc }, data_type, p.formats.bias_format) :
create_md({}, data_type, p.formats.bias_format);
create_md({ cd.oc }, data_type_dst, p.formats.bias_format) :
create_md({}, data_type_dst, p.formats.bias_format);
auto c_bias = memory({c_bias_desc, eng});
if (with_bias) {
fill_data<data_t>(
c_bias.get_primitive_desc().get_size() / sizeof(data_t),
(data_t *)c_bias.get_data_handle(), 1., true);
fill_data<data_t_dst>(
c_bias.get_primitive_desc().get_size() / sizeof(data_t_dst),
(data_t_dst *)c_bias.get_data_handle(), 1., true);
}

std::vector<int> padR = { cd.padh, cd.padw };
Expand Down Expand Up @@ -184,21 +188,10 @@ class convolution_relu_test

stream(stream::kind::lazy).submit(pipeline).wait();

compute_ref_conv_relu_fwd<data_t>(cd, c_src, c_weights, c_bias,
dst_ref, with_bias);
compare_data<data_t>(dst_ref, c_dst);
compute_ref_conv_relu_fwd<data_t_src, data_t_wei, data_t_wei,
data_t_dst>(cd, c_src, c_weights, c_bias, dst_ref, with_bias);
compare_data<data_t_dst>(dst_ref, c_dst);
}
};

using convolution_test = convolution_relu_test<float>;

TEST_P(convolution_test, TestConvolution)
{
}

#define FP32
#define DIRECTION_FORWARD
#include "convolution_common.h"
#include "diluted_convolution.h"

}
34 changes: 34 additions & 0 deletions tests/gtests/test_convolution_relu_forward_f32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright 2016-2017 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
+*******************************************************************************/

#include "mkldnn_test_common.hpp"
#include "gtest/gtest.h"

#include "mkldnn.hpp"
#include "test_convolution_relu_forward_common.hpp"
namespace mkldnn {

using convolution_test = convolution_relu_test<float, float, float, float>;

TEST_P(convolution_test, TestConvolution)
{
}

#define FP32
#define DIRECTION_FORWARD
#include "convolution_common.h"

}
35 changes: 35 additions & 0 deletions tests/gtests/test_convolution_relu_forward_s16s16s32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*******************************************************************************
* Copyright 2016-2017 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <stdint.h>
#include "mkldnn_test_common.hpp"
#include "gtest/gtest.h"

#include "mkldnn.hpp"
#include "test_convolution_relu_forward_common.hpp"
namespace mkldnn {

using convolution_test = convolution_relu_test<int16_t, int16_t,
int32_t, int32_t>;

TEST_P(convolution_test, TestConvolution)
{
}

#define S16S16S32
#define DIRECTION_FORWARD
#include "convolution_common.h"

}

0 comments on commit 77b8ed1

Please sign in to comment.