Skip to content

Commit

Permalink
tests: benchdnn: fixed number of operations count for deconv
Browse files Browse the repository at this point in the history
  • Loading branch information
akharito authored and Roman Dubtsov committed Jan 6, 2019
1 parent 3456dd6 commit 42f0a7f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tests/benchdnn/conv/bench_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void reset_parameters() {
}

void check_correctness(const desc_t *c) {
const prb_t p(*c, dir, cfg, alg, attr, mb);
const prb_t p(*c, dir, cfg, alg, attr, mb, true);
char pstr[max_prb_len];
prb2str(&p, pstr);

Expand Down
18 changes: 12 additions & 6 deletions tests/benchdnn/conv/conv_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,25 @@ void desc2str(const desc_t *d, char *buffer, bool canonical) {
void prb_t::count_ops() {
if (ops > 0) return;

int od_t = is_deconv ? this->id : this->od;
int oh_t = is_deconv ? this->ih : this->oh;
int ow_t = is_deconv ? this->iw : this->ow;
int id_t = is_deconv ? this->od : this->id;
int ih_t = is_deconv ? this->oh : this->ih;
int iw_t = is_deconv ? this->ow : this->iw;
double sp_ops = 0;
for (int od = 0; od < this->od; ++od) {
for (int oh = 0; oh < this->oh; ++oh) {
for (int ow = 0; ow < this->ow; ++ow) {
for (int od = 0; od < od_t; ++od) {
for (int oh = 0; oh < oh_t; ++oh) {
for (int ow = 0; ow < ow_t; ++ow) {
for (int kd = 0; kd < this->kd; ++kd) {
const int id = od * this->sd - this->pd + kd * (this->dd + 1);
if (id < 0 || id >= this->id) continue;
if (id < 0 || id >= id_t) continue;
for (int kh = 0; kh < this->kh; ++kh) {
const int ih = oh * this->sh - this->ph + kh * (this->dh + 1);
if (ih < 0 || ih >= this->ih) continue;
if (ih < 0 || ih >= ih_t) continue;
for (int kw = 0; kw < this->kw; ++kw) {
const int iw = ow * this->sw - this->pw + kw * (this->dw + 1);
if (iw < 0 || iw >= this->iw) continue;
if (iw < 0 || iw >= iw_t) continue;
sp_ops += 1;
}
}
Expand Down
5 changes: 3 additions & 2 deletions tests/benchdnn/conv/conv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg);

struct prb_t: public desc_t {
prb_t(const desc_t &desc, dir_t dir, const dt_conf_t *cfg, alg_t alg,
const attr_t &attr, int mb = 0)
const attr_t &attr, int mb = 0, bool is_deconv = false)
: desc_t(desc), dir(dir), cfg(cfg), alg(alg), attr(attr)
, ops(0), scales(NULL) {
, ops(0), scales(NULL), is_deconv(is_deconv) {
if (mb) this->mb = mb;
count_ops();
generate_oscales();
Expand All @@ -117,6 +117,7 @@ struct prb_t: public desc_t {
const dt_conf_t *cfg;
alg_t alg;
attr_t attr;
bool is_deconv;

double ops;
float *scales;
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/conv/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ int doit(const prb_t *p, res_t *r) {
*r = res_zero;
bool with_groups = 1;

prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb);
prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb, true);
swap(p_tr.ic, p_tr.oc);
swap(p_tr.ih, p_tr.oh);
swap(p_tr.id, p_tr.od);
Expand Down

0 comments on commit 42f0a7f

Please sign in to comment.