diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index ad07d401ea74..786d50234ca2 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -245,9 +245,17 @@ void CodeGenC::PrintVecStore(const Variable* buffer, stream << ref << " = " << value << ";\n"; } +std::string CodeGenC::CastFromTo(std::string value, Type from, Type target) { + if (from == target) return value; + std::ostringstream os; + os << "(("; + this->PrintType(target, os); + os << ")" << value << ")"; + return os.str(); +} + void CodeGenC::BindThreadIndex(const IterVar& iv) { - CHECK(!var_idmap_.count(iv->var.get())); - var_idmap_[iv->var.get()] = iv->thread_tag; + LOG(FATAL) << "not implemented"; } void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*) diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 510f6ddc9070..f311ddb2c0af 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -150,6 +150,8 @@ class CodeGenC : // print reference to a buffer as type t in index. std::string GetBufferRef( Type t, const Variable* buffer, Expr index); + // Get a cast type from to + std::string CastFromTo(std::string value, Type from, Type target); /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 9847ba4661a6..5bb689df996d 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -35,6 +35,12 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) { CodeGenC::VisitStmt_(op); } +void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { + CHECK(!var_idmap_.count(iv->var.get())); + var_idmap_[iv->var.get()] = + CastFromTo(iv->thread_tag, UInt(32), iv->var.type()); +} + void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 6955b5bf77d0..7974e2c0dbb8 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -30,7 +30,7 @@ class CodeGenCUDA final : public CodeGenC { const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*) void PrintVecElemStore( const std::string& vec, Type t, int i, const std::string& value) final; - + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const Evaluate *op) final; diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc index 7f8c1dd9476d..ee360c8ece50 100644 --- a/src/codegen/codegen_metal.cc +++ b/src/codegen/codegen_metal.cc @@ -126,6 +126,12 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { this->stream << "}\n\n"; } +void CodeGenMetal::BindThreadIndex(const IterVar& iv) { + CHECK(!var_idmap_.count(iv->var.get())); + var_idmap_[iv->var.get()] = + CastFromTo(iv->thread_tag, UInt(16), iv->var.type()); +} + void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { diff --git a/src/codegen/codegen_metal.h b/src/codegen/codegen_metal.h index 7331670d47c7..ebf0b9ad8319 100644 --- a/src/codegen/codegen_metal.h +++ b/src/codegen/codegen_metal.h @@ -24,7 +24,7 @@ class CodeGenMetal final : public CodeGenC { void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) - + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) }; diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index f2ad3fe55b58..dd18b3a060a0 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -35,7 +35,8 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { } else { os << "get_group_id(" << ts.dim_index << ")"; } - var_idmap_[iv->var.get()] = os.str(); + var_idmap_[iv->var.get()] = + CastFromTo(os.str(), UInt(64), iv->var.type()); } void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*) diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index af855cef060d..1b84555cc855 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") ->GetTimeEvaluator(args[1], ctx, args[4]); } else { *rv = WrapTimeEvaluator( - m.GetFunction(args[1], false), ctx, args[3]); + m.GetFunction(args[1], false), ctx, args[4]); } });