Skip to content

Commit

Permalink
[CODEGEN] Concise typecast for threadIdx (apache#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jul 3, 2017
1 parent bf97724 commit b0e41b9
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 6 deletions.
12 changes: 10 additions & 2 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,17 @@ void CodeGenC::PrintVecStore(const Variable* buffer,
stream << ref << " = " << value << ";\n";
}

std::string CodeGenC::CastFromTo(std::string value, Type from, Type target) {
if (from == target) return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")" << value << ")";
return os.str();
}

void CodeGenC::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = iv->thread_tag;
LOG(FATAL) << "not implemented";
}

void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ class CodeGenC :
// print reference to a buffer as type t in index.
std::string GetBufferRef(
Type t, const Variable* buffer, Expr index);
// Get a cast type from to
std::string CastFromTo(std::string value, Type from, Type target);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) {
CodeGenC::VisitStmt_(op);
}

void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
}

void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CodeGenCUDA final : public CodeGenC {
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;

void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const Evaluate *op) final;
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
this->stream << "}\n\n";
}

void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, UInt(16), iv->var.type());
}

void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CodeGenMetal final : public CodeGenC {
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)

void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
};
Expand Down
3 changes: 2 additions & 1 deletion src/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
} else {
os << "get_group_id(" << ts.dim_index << ")";
}
var_idmap_[iv->var.get()] = os.str();
var_idmap_[iv->var.get()] =
CastFromTo(os.str(), UInt(64), iv->var.type());
}

void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
->GetTimeEvaluator(args[1], ctx, args[4]);
} else {
*rv = WrapTimeEvaluator(
m.GetFunction(args[1], false), ctx, args[3]);
m.GetFunction(args[1], false), ctx, args[4]);
}
});

Expand Down

0 comments on commit b0e41b9

Please sign in to comment.