Skip to content

Commit

Permalink
Allow datatypes besides fp32 in conv2d_transpose for cuda. (apache#6593)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwfromm authored Oct 1, 2020
1 parent e78aa61 commit e31564e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv2d_transpose_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
tvm.tir.indexdiv(y - pad_top, stride_height),
tvm.tir.indexdiv(x - pad_left, stride_width),
],
tvm.tir.const(0.0, "float32"),
tvm.tir.const(0.0, data.dtype),
),
name="data_pad",
)
Expand Down

0 comments on commit e31564e

Please sign in to comment.