Skip to content

Commit

Permalink
[misc] Improve type checking error messages (taichi-dev#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Apr 16, 2020
1 parent b6b2ff2 commit 588e504
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 172 deletions.
6 changes: 3 additions & 3 deletions taichi/common/dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ class Dict {
int64 ptr_ll;
std::getline(ss, t, '\t');
ss >> ptr_ll;
assert_info(t == typeid(T).name(),
"Pointer type mismatch: " + t + " and " + typeid(T).name());
TI_ASSERT_INFO(t == typeid(T).name(),
"Pointer type mismatch: " + t + " and " + typeid(T).name());
return reinterpret_cast<T *>(ptr_ll);
}

Expand Down Expand Up @@ -347,7 +347,7 @@ inline bool Dict::get<bool>(std::string key) const {
{"true", true}, {"True", true}, {"t", true}, {"1", true},
{"false", false}, {"False", false}, {"f", false}, {"0", false},
};
assert_info(dict.find(s) != dict.end(), "Unkown identifer for bool: " + s);
TI_ASSERT_INFO(dict.find(s) != dict.end(), "Unkown identifer for bool: " + s);
return dict[s];
}

Expand Down
237 changes: 121 additions & 116 deletions taichi/common/interface.h

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion taichi/common/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ TI_NAMESPACE_BEGIN
class Task : public Unit {
public:
virtual std::string run(const std::vector<std::string> &parameters) {
assert_info(parameters.size() == 0, "No parameters supported.");
TI_ASSERT_INFO(parameters.size() == 0, "No parameters supported.");
return this->run();
}

Expand Down
10 changes: 4 additions & 6 deletions taichi/common/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,15 @@ static_assert(__cplusplus >= 201402L, "C++14 required.");
#define DEBUG_TRIGGER
#endif

#define assert_info(x, info) \
#define TI_STATIC_ASSERT(x) static_assert((x), #x);
#define TI_ASSERT(x) TI_ASSERT_INFO((x), #x)
#define TI_ASSERT_INFO(x, ...) \
{ \
bool ___ret___ = static_cast<bool>(x); \
if (!___ret___) { \
TI_ERROR(info); \
TI_ERROR(__VA_ARGS__); \
} \
}

#define TI_STATIC_ASSERT(x) static_assert((x), #x);
#define TI_ASSERT(x) TI_ASSERT_INFO((x), #x)
#define TI_ASSERT_INFO assert_info
#define TI_NOT_IMPLEMENTED TI_ERROR("Not supported.");

#define TI_NAMESPACE_BEGIN namespace taichi {
Expand Down
3 changes: 2 additions & 1 deletion taichi/math/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ TI_FORCE_INLINE bool abnormal(T m) noexcept {
}

inline int64 get_largest_pot(int64 a) noexcept {
assert_info(a > 0, "a should be positive, instead of " + std::to_string(a));
TI_ASSERT_INFO(a > 0,
"a should be positive, instead of " + std::to_string(a));
// TODO: optimize
int64 i = 1;
while (i <= a / 2) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/python/export_math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ void array2d_to_ndarray(Array2D<Vector3> *arr,
uint64 output) // 'output' is actually a pointer...
{
int width = arr->get_width(), height = arr->get_height();
assert_info(width > 0, "");
assert_info(height > 0, "");
TI_ASSERT(width > 0);
TI_ASSERT(height > 0);
for (auto &ind : arr->get_region()) {
for (int k = 0; k < channels; k++) {
const Vector3 entry = (*arr)[ind];
Expand Down
2 changes: 1 addition & 1 deletion taichi/system/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class Profiler {
}

void stop() {
assert_info(!stopped, "Profiler already stopped.");
TI_ASSERT_INFO(!stopped, "Profiler already stopped.");
float64 elapsed = Time::get_time() - start_time;
if ((int64)elements != -1) {
ProfilerRecords::get_instance().insert_sample(elapsed, elements);
Expand Down
62 changes: 26 additions & 36 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ class TypeCheck : public IRVisitor {
}

void visit(AllocaStmt *stmt) {
// Do nothing.
// Alloca type is determined by the first LocalStore in IR visiting order,
// at compile time
// Do nothing. Alloca type is determined by the first LocalStore in IR
// visiting order, at compile time.

// ret_type stands for its element type.
stmt->ret_type.set_is_pointer(false);
Expand All @@ -40,10 +39,9 @@ class TypeCheck : public IRVisitor {

void visit(Block *stmt_list) {
std::vector<Stmt *> stmts;
// Make a copy since type casts may be inserted for type promotion
// Make a copy since type casts may be inserted for type promotion.
for (auto &stmt : stmt_list->statements) {
stmts.push_back(stmt.get());
// stmt->accept(this);
}
for (auto stmt : stmts)
stmt->accept(this);
Expand All @@ -52,7 +50,7 @@ class TypeCheck : public IRVisitor {
void visit(AtomicOpStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
if (stmt->val->ret_type.data_type != stmt->dest->ret_type.data_type) {
TI_WARN("Atomic add ({} to {}) may lose precision.",
TI_WARN("[{}] Atomic add ({} to {}) may lose precision.", stmt->name(),
data_type_name(stmt->val->ret_type.data_type),
data_type_name(stmt->dest->ret_type.data_type));
stmt->val = insert_type_cast_before(stmt, stmt->val,
Expand Down Expand Up @@ -84,10 +82,9 @@ class TypeCheck : public IRVisitor {
}
if (stmt->ptr->ret_type.data_type != common_container_type) {
TI_WARN(
"Local store may lose precision (target = {}, value = {}, "
"stmt_id = {}) at",
stmt->ptr->ret_data_type_name(), stmt->data->ret_data_type_name(),
stmt->id);
"[{}] Local store may lose precision (target = {}, value = {}, at",
stmt->name(), stmt->ptr->ret_data_type_name(),
stmt->data->ret_data_type_name(), stmt->id);
fmt::print(stmt->tb);
}
stmt->ret_type = stmt->ptr->ret_type;
Expand All @@ -108,11 +105,11 @@ class TypeCheck : public IRVisitor {
if (stmt->snodes)
stmt->ret_type.data_type = stmt->snodes[0]->dt;
else
TI_WARN("Type inference failed: snode is nullptr.");
TI_WARN("[{}] Type inference failed: snode is nullptr.", stmt->name());
for (int l = 0; l < stmt->snodes.size(); l++) {
if (stmt->snodes[l]->parent->num_active_indices != 0 &&
stmt->snodes[l]->parent->num_active_indices != stmt->indices.size()) {
TI_ERROR("{} has {} indices. Indexed with {}.",
TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
stmt->snodes[l]->parent->node_type_name,
stmt->snodes[l]->parent->num_active_indices,
stmt->indices.size());
Expand All @@ -121,10 +118,11 @@ class TypeCheck : public IRVisitor {
for (int i = 0; i < stmt->indices.size(); i++) {
TI_ASSERT_INFO(
is_integral(stmt->indices[i]->ret_type.data_type),
"Taichi tensors must be accessed with integral indices (e.g., "
"[{}] Taichi tensors must be accessed with integral indices (e.g., "
"i32/i64). It seems that you have used a float point number as "
"an index. You can cast that to an integer using int(). Also note "
"that ti.floor(ti.f32) returns f32.");
"that ti.floor(ti.f32) returns f32.",
stmt->name());
TI_ASSERT(stmt->indices[i]->ret_type.width == stmt->snodes.size());
}
}
Expand All @@ -138,23 +136,16 @@ class TypeCheck : public IRVisitor {
stmt->ptr->ret_type.data_type);
}
if (stmt->ptr->ret_type.data_type != promoted) {
TI_WARN("Global store may lose precision: {} <- {}, at",
stmt->ptr->ret_data_type_name(), input_type, stmt->tb);
TI_WARN("[{}] Global store may lose precision: {} <- {}, at",
stmt->name(), stmt->ptr->ret_data_type_name(), input_type,
stmt->tb);
}
stmt->ret_type = stmt->ptr->ret_type;
}

void visit(RangeForStmt *stmt) {
/*
TI_ASSERT(block->local_variables.find(stmt->loop_var) ==
block->local_variables.end());
*/
mark_as_if_const(stmt->begin, VectorType(1, DataType::i32));
mark_as_if_const(stmt->end, VectorType(1, DataType::i32));
/*
block->local_variables.insert(
std::make_pair(stmt->loop_var, VectorType(1, DataType::i32)));
*/
stmt->body->accept(this);
}

Expand All @@ -173,13 +164,14 @@ class TypeCheck : public IRVisitor {
}
if (is_trigonometric(stmt->op_type) &&
!is_real(stmt->operand->ret_type.data_type)) {
TI_ERROR("Trigonometric operator takes real inputs only. At {}",
stmt->tb);
TI_ERROR("[{}] Trigonometric operator takes real inputs only. At {}",
stmt->name(), stmt->tb);
}
if ((stmt->op_type == UnaryOpType::floor ||
stmt->op_type == UnaryOpType::ceil) &&
!is_real(stmt->operand->ret_type.data_type)) {
TI_ERROR("floor/ceil takes real inputs only. At {}", stmt->tb);
TI_ERROR("[{}] floor/ceil takes real inputs only. At {}", stmt->name(),
stmt->tb);
}
}

Expand Down Expand Up @@ -215,15 +207,17 @@ class TypeCheck : public IRVisitor {
void visit(BinaryOpStmt *stmt) {
auto error = [&](std::string comment = "") {
if (comment == "") {
TI_WARN("Error: type mismatch (left = {}, right = {}, stmt_id = {}) at",
stmt->lhs->ret_data_type_name(),
stmt->rhs->ret_data_type_name(), stmt->id);
TI_WARN(
"[{}] Error: type mismatch (left = {}, right = {}, stmt_id = {}) "
"at",
stmt->name(), stmt->lhs->ret_data_type_name(),
stmt->rhs->ret_data_type_name(), stmt->id);
} else {
TI_WARN(comment + " at");
TI_WARN("[{}] {} at", stmt->name(), comment);
}
fmt::print(stmt->tb);
TI_WARN("Compilation stopped due to type mismatch.");
exit(-1);
throw std::runtime_error("Binary operator type mismatch");
};
if (stmt->lhs->ret_type.data_type == DataType::unknown &&
stmt->rhs->ret_type.data_type == DataType::unknown)
Expand Down Expand Up @@ -414,10 +408,6 @@ namespace irpass {
void typecheck(IRNode *root) {
analysis::check_fields_registered(root);
TypeCheck::run(root);
// if (root->is<Block>() && root->as<Block>()->parent == nullptr) {
// fix_block_parents(root); // hot fix
// verify(root);
// }
}

} // namespace irpass
Expand Down
12 changes: 6 additions & 6 deletions taichi/util/image_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ template <typename T>
void Array2D<T>::load_image(const std::string &filename, bool linearize) {
int channels;
FILE *f = fopen(filename.c_str(), "rb");
assert_info(f != nullptr, "Image file not found: " + filename);
TI_ASSERT_INFO(f != nullptr, "Image file not found: " + filename);
stbi_ldr_to_hdr_gamma(1.0_f);
float32 *data =
stbi_loadf(filename.c_str(), &this->res[0], &this->res[1], &channels, 0);
assert_info(data != nullptr,
"Image file load failed: " + filename +
" # Msg: " + std::string(stbi_failure_reason()));
assert_info(channels == 1 || channels == 3 || channels == 4,
"Image must have channel 1, 3 or 4: " + filename);
TI_ASSERT_INFO(data != nullptr,
"Image file load failed: " + filename +
" # Msg: " + std::string(stbi_failure_reason()));
TI_ASSERT_INFO(channels == 1 || channels == 3 || channels == 4,
"Image must have channel 1, 3 or 4: " + filename);
this->initialize(Vector2i(this->res[0], this->res[1]));

for (int i = 0; i < this->res[0]; i++) {
Expand Down

0 comments on commit 588e504

Please sign in to comment.