diff --git a/python/tvm/topi/cuda/conv2d_transpose_nchw.py b/python/tvm/topi/cuda/conv2d_transpose_nchw.py index 46ee685a7d1a..915e6cdecae2 100644 --- a/python/tvm/topi/cuda/conv2d_transpose_nchw.py +++ b/python/tvm/topi/cuda/conv2d_transpose_nchw.py @@ -96,7 +96,7 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_ tvm.tir.indexdiv(y - pad_top, stride_height), tvm.tir.indexdiv(x - pad_left, stride_width), ], - tvm.tir.const(0.0, "float32"), + tvm.tir.const(0.0, data.dtype), ), name="data_pad", )