From 42f0a7f5fb63b0d7d0775c72578c845a900baf0d Mon Sep 17 00:00:00 2001 From: Andrey Kharitonchik Date: Thu, 3 Jan 2019 10:35:57 -0800 Subject: [PATCH] tests: benchdnn: fixed number of operations count for deconv --- tests/benchdnn/conv/bench_deconv.cpp | 2 +- tests/benchdnn/conv/conv_aux.cpp | 18 ++++++++++++------ tests/benchdnn/conv/conv_common.hpp | 5 +++-- tests/benchdnn/conv/deconv.cpp | 2 +- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/benchdnn/conv/bench_deconv.cpp b/tests/benchdnn/conv/bench_deconv.cpp index e38b48336f4..937d50e4061 100644 --- a/tests/benchdnn/conv/bench_deconv.cpp +++ b/tests/benchdnn/conv/bench_deconv.cpp @@ -53,7 +53,7 @@ void reset_parameters() { } void check_correctness(const desc_t *c) { - const prb_t p(*c, dir, cfg, alg, attr, mb); + const prb_t p(*c, dir, cfg, alg, attr, mb, true); char pstr[max_prb_len]; prb2str(&p, pstr); diff --git a/tests/benchdnn/conv/conv_aux.cpp b/tests/benchdnn/conv/conv_aux.cpp index 395a8199f1b..ae7e9e37117 100644 --- a/tests/benchdnn/conv/conv_aux.cpp +++ b/tests/benchdnn/conv/conv_aux.cpp @@ -226,19 +226,25 @@ void desc2str(const desc_t *d, char *buffer, bool canonical) { void prb_t::count_ops() { if (ops > 0) return; + int od_t = is_deconv ? this->id : this->od; + int oh_t = is_deconv ? this->ih : this->oh; + int ow_t = is_deconv ? this->iw : this->ow; + int id_t = is_deconv ? this->od : this->id; + int ih_t = is_deconv ? this->oh : this->ih; + int iw_t = is_deconv ? this->ow : this->iw; double sp_ops = 0; - for (int od = 0; od < this->od; ++od) { - for (int oh = 0; oh < this->oh; ++oh) { - for (int ow = 0; ow < this->ow; ++ow) { + for (int od = 0; od < od_t; ++od) { + for (int oh = 0; oh < oh_t; ++oh) { + for (int ow = 0; ow < ow_t; ++ow) { for (int kd = 0; kd < this->kd; ++kd) { const int id = od * this->sd - this->pd + kd * (this->dd + 1); - if (id < 0 || id >= this->id) continue; + if (id < 0 || id >= id_t) continue; for (int kh = 0; kh < this->kh; ++kh) { const int ih = oh * this->sh - this->ph + kh * (this->dh + 1); - if (ih < 0 || ih >= this->ih) continue; + if (ih < 0 || ih >= ih_t) continue; for (int kw = 0; kw < this->kw; ++kw) { const int iw = ow * this->sw - this->pw + kw * (this->dw + 1); - if (iw < 0 || iw >= this->iw) continue; + if (iw < 0 || iw >= iw_t) continue; sp_ops += 1; } } diff --git a/tests/benchdnn/conv/conv_common.hpp b/tests/benchdnn/conv/conv_common.hpp index 6319a8096b2..2d6a08528c6 100644 --- a/tests/benchdnn/conv/conv_common.hpp +++ b/tests/benchdnn/conv/conv_common.hpp @@ -104,9 +104,9 @@ const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg); struct prb_t: public desc_t { prb_t(const desc_t &desc, dir_t dir, const dt_conf_t *cfg, alg_t alg, - const attr_t &attr, int mb = 0) + const attr_t &attr, int mb = 0, bool is_deconv = false) : desc_t(desc), dir(dir), cfg(cfg), alg(alg), attr(attr) - , ops(0), scales(NULL) { + , ops(0), scales(NULL), is_deconv(is_deconv) { if (mb) this->mb = mb; count_ops(); generate_oscales(); @@ -117,6 +117,7 @@ struct prb_t: public desc_t { const dt_conf_t *cfg; alg_t alg; attr_t attr; + bool is_deconv; double ops; float *scales; diff --git a/tests/benchdnn/conv/deconv.cpp b/tests/benchdnn/conv/deconv.cpp index 9ab14dae9ad..66637973477 100644 --- a/tests/benchdnn/conv/deconv.cpp +++ b/tests/benchdnn/conv/deconv.cpp @@ -180,7 +180,7 @@ int doit(const prb_t *p, res_t *r) { *r = res_zero; bool with_groups = 1; - prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb); + prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb, true); swap(p_tr.ic, p_tr.oc); swap(p_tr.ih, p_tr.oh); swap(p_tr.id, p_tr.od);