From 588e504b02d6ca6b314d89fb2cd7fc0e41e70467 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Thu, 16 Apr 2020 14:50:47 -0400 Subject: [PATCH] [misc] Improve type checking error messages (#797) --- taichi/common/dict.h | 6 +- taichi/common/interface.h | 237 ++++++++++++++++--------------- taichi/common/task.h | 2 +- taichi/common/util.h | 10 +- taichi/math/scalar.h | 3 +- taichi/python/export_math.cpp | 4 +- taichi/system/profiler.h | 2 +- taichi/transforms/type_check.cpp | 62 ++++---- taichi/util/image_buffer.cpp | 12 +- 9 files changed, 166 insertions(+), 172 deletions(-) diff --git a/taichi/common/dict.h b/taichi/common/dict.h index 1ef7b444c7006..78cdecfe925d8 100644 --- a/taichi/common/dict.h +++ b/taichi/common/dict.h @@ -176,8 +176,8 @@ class Dict { int64 ptr_ll; std::getline(ss, t, '\t'); ss >> ptr_ll; - assert_info(t == typeid(T).name(), - "Pointer type mismatch: " + t + " and " + typeid(T).name()); + TI_ASSERT_INFO(t == typeid(T).name(), + "Pointer type mismatch: " + t + " and " + typeid(T).name()); return reinterpret_cast(ptr_ll); } @@ -347,7 +347,7 @@ inline bool Dict::get(std::string key) const { {"true", true}, {"True", true}, {"t", true}, {"1", true}, {"false", false}, {"False", false}, {"f", false}, {"0", false}, }; - assert_info(dict.find(s) != dict.end(), "Unkown identifer for bool: " + s); + TI_ASSERT_INFO(dict.find(s) != dict.end(), "Unkown identifer for bool: " + s); return dict[s]; } diff --git a/taichi/common/interface.h b/taichi/common/interface.h index 822e7e815321c..fde608c706e3f 100644 --- a/taichi/common/interface.h +++ b/taichi/common/interface.h @@ -113,122 +113,127 @@ class InterfaceHolder { } }; -#define TI_INTERFACE(T) \ - extern void *get_implementation_holder_instance_##T(); \ - class TI_IMPLEMENTATION_HOLDER_NAME(T) final \ - : public ImplementationHolderBase { \ - public: \ - TI_IMPLEMENTATION_HOLDER_NAME(T)(const std::string &name) { \ - this->name = name; \ - } \ - using FactoryMethod = std::function()>; \ - using FactoryUniqueMethod = std::function()>; \ - using FactoryUniqueCtorMethod = \ - std::function(const Dict &config)>; \ - using FactoryRawMethod = std::function; \ - using FactoryPlacementMethod = std::function; \ - std::map implementation_factories; \ - std::map \ - implementation_unique_factories; \ - std::map \ - implementation_unique_ctor_factories; \ - std::map implementation_raw_factories; \ - std::map \ - implementation_placement_factories; \ - std::vector get_implementation_names() const override { \ - std::vector names; \ - for (auto &kv : implementation_factories) { \ - names.push_back(kv.first); \ - } \ - return names; \ - } \ - template \ - void insert(const std::string &alias) { \ - implementation_factories.insert( \ - std::make_pair(alias, [&]() { return std::make_shared(); })); \ - implementation_unique_factories.insert( \ - std::make_pair(alias, [&]() { return std::make_unique(); })); \ - implementation_raw_factories.insert( \ - std::make_pair(alias, [&]() { return new G(); })); \ - implementation_placement_factories.insert(std::make_pair( \ - alias, [&](void *place) { return new (place) G(); })); \ - } \ - template \ - void insert_new(const std::string &alias) { \ - /*with ctor*/ \ - implementation_factories.insert( \ - std::make_pair(alias, [&]() { return std::make_shared(); })); \ - implementation_unique_factories.insert( \ - std::make_pair(alias, [&]() { return std::make_unique(); })); \ - implementation_unique_ctor_factories.insert(std::make_pair( \ - alias, \ - [&](const Dict &config) { return std::make_unique(config); })); \ - implementation_raw_factories.insert( \ - std::make_pair(alias, [&]() { return new G(); })); \ - implementation_placement_factories.insert(std::make_pair( \ - alias, [&](void *place) { return new (place) G(); })); \ - } \ - void insert(const std::string &alias, const FactoryMethod &f) { \ - implementation_factories.insert(std::make_pair(alias, f)); \ - } \ - bool has(const std::string &alias) const override { \ - return implementation_factories.find(alias) != \ - implementation_factories.end(); \ - } \ - void remove(const std::string &alias) override { \ - assert_info(has(alias), \ - std::string("Implemetation ") + alias + " not found!"); \ - implementation_factories.erase(alias); \ - } \ - void update(const std::string &alias, const FactoryMethod &f) { \ - if (has(alias)) { \ - remove(alias); \ - } \ - insert(alias, f); \ - } \ - template \ - void update(const std::string &alias) { \ - if (has(alias)) { \ - remove(alias); \ - } \ - insert(alias); \ - } \ - std::shared_ptr create(const std::string &alias) { \ - auto factory = implementation_factories.find(alias); \ - assert_info(factory != implementation_factories.end(), \ - "Implementation [" + name + "::" + alias + "] not found!"); \ - return (factory->second)(); \ - } \ - std::unique_ptr create_unique(const std::string &alias) { \ - auto factory = implementation_unique_factories.find(alias); \ - assert_info(factory != implementation_unique_factories.end(), \ - "Implementation [" + name + "::" + alias + "] not found!"); \ - return (factory->second)(); \ - } \ - std::unique_ptr create_unique_ctor(const std::string &alias, \ - const Dict &config) { \ - auto factory = implementation_unique_ctor_factories.find(alias); \ - assert_info(factory != implementation_unique_ctor_factories.end(), \ - "Implementation [" + name + "::" + alias + "] not found!"); \ - return (factory->second)(config); \ - } \ - T *create_raw(const std::string &alias) { \ - auto factory = implementation_raw_factories.find(alias); \ - assert_info(factory != implementation_raw_factories.end(), \ - "Implementation [" + name + "::" + alias + "] not found!"); \ - return (factory->second)(); \ - } \ - T *create_placement(const std::string &alias, void *place) { \ - auto factory = implementation_placement_factories.find(alias); \ - assert_info(factory != implementation_placement_factories.end(), \ - "Implementation [" + name + "::" + alias + "] not found!"); \ - return (factory->second)(place); \ - } \ - static TI_IMPLEMENTATION_HOLDER_NAME(T) * get_instance() { \ - return static_cast( \ - get_implementation_holder_instance_##T()); \ - } \ - }; \ +#define TI_INTERFACE(T) \ + extern void *get_implementation_holder_instance_##T(); \ + class TI_IMPLEMENTATION_HOLDER_NAME(T) final \ + : public ImplementationHolderBase { \ + public: \ + TI_IMPLEMENTATION_HOLDER_NAME(T)(const std::string &name) { \ + this->name = name; \ + } \ + using FactoryMethod = std::function()>; \ + using FactoryUniqueMethod = std::function()>; \ + using FactoryUniqueCtorMethod = \ + std::function(const Dict &config)>; \ + using FactoryRawMethod = std::function; \ + using FactoryPlacementMethod = std::function; \ + std::map implementation_factories; \ + std::map \ + implementation_unique_factories; \ + std::map \ + implementation_unique_ctor_factories; \ + std::map implementation_raw_factories; \ + std::map \ + implementation_placement_factories; \ + std::vector get_implementation_names() const override { \ + std::vector names; \ + for (auto &kv : implementation_factories) { \ + names.push_back(kv.first); \ + } \ + return names; \ + } \ + template \ + void insert(const std::string &alias) { \ + implementation_factories.insert( \ + std::make_pair(alias, [&]() { return std::make_shared(); })); \ + implementation_unique_factories.insert( \ + std::make_pair(alias, [&]() { return std::make_unique(); })); \ + implementation_raw_factories.insert( \ + std::make_pair(alias, [&]() { return new G(); })); \ + implementation_placement_factories.insert(std::make_pair( \ + alias, [&](void *place) { return new (place) G(); })); \ + } \ + template \ + void insert_new(const std::string &alias) { \ + /*with ctor*/ \ + implementation_factories.insert( \ + std::make_pair(alias, [&]() { return std::make_shared(); })); \ + implementation_unique_factories.insert( \ + std::make_pair(alias, [&]() { return std::make_unique(); })); \ + implementation_unique_ctor_factories.insert(std::make_pair( \ + alias, \ + [&](const Dict &config) { return std::make_unique(config); })); \ + implementation_raw_factories.insert( \ + std::make_pair(alias, [&]() { return new G(); })); \ + implementation_placement_factories.insert(std::make_pair( \ + alias, [&](void *place) { return new (place) G(); })); \ + } \ + void insert(const std::string &alias, const FactoryMethod &f) { \ + implementation_factories.insert(std::make_pair(alias, f)); \ + } \ + bool has(const std::string &alias) const override { \ + return implementation_factories.find(alias) != \ + implementation_factories.end(); \ + } \ + void remove(const std::string &alias) override { \ + TI_ASSERT_INFO(has(alias), \ + std::string("Implemetation ") + alias + " not found!"); \ + implementation_factories.erase(alias); \ + } \ + void update(const std::string &alias, const FactoryMethod &f) { \ + if (has(alias)) { \ + remove(alias); \ + } \ + insert(alias, f); \ + } \ + template \ + void update(const std::string &alias) { \ + if (has(alias)) { \ + remove(alias); \ + } \ + insert(alias); \ + } \ + std::shared_ptr create(const std::string &alias) { \ + auto factory = implementation_factories.find(alias); \ + TI_ASSERT_INFO( \ + factory != implementation_factories.end(), \ + "Implementation [" + name + "::" + alias + "] not found!"); \ + return (factory->second)(); \ + } \ + std::unique_ptr create_unique(const std::string &alias) { \ + auto factory = implementation_unique_factories.find(alias); \ + TI_ASSERT_INFO( \ + factory != implementation_unique_factories.end(), \ + "Implementation [" + name + "::" + alias + "] not found!"); \ + return (factory->second)(); \ + } \ + std::unique_ptr create_unique_ctor(const std::string &alias, \ + const Dict &config) { \ + auto factory = implementation_unique_ctor_factories.find(alias); \ + TI_ASSERT_INFO( \ + factory != implementation_unique_ctor_factories.end(), \ + "Implementation [" + name + "::" + alias + "] not found!"); \ + return (factory->second)(config); \ + } \ + T *create_raw(const std::string &alias) { \ + auto factory = implementation_raw_factories.find(alias); \ + TI_ASSERT_INFO( \ + factory != implementation_raw_factories.end(), \ + "Implementation [" + name + "::" + alias + "] not found!"); \ + return (factory->second)(); \ + } \ + T *create_placement(const std::string &alias, void *place) { \ + auto factory = implementation_placement_factories.find(alias); \ + TI_ASSERT_INFO( \ + factory != implementation_placement_factories.end(), \ + "Implementation [" + name + "::" + alias + "] not found!"); \ + return (factory->second)(place); \ + } \ + static TI_IMPLEMENTATION_HOLDER_NAME(T) * get_instance() { \ + return static_cast( \ + get_implementation_holder_instance_##T()); \ + } \ + }; \ extern TI_IMPLEMENTATION_HOLDER_NAME(T) * TI_IMPLEMENTATION_HOLDER_PTR(T); #define TI_INTERFACE_DEF(class_name, base_alias) \ diff --git a/taichi/common/task.h b/taichi/common/task.h index b178ff307af1a..4da0e30f7a81e 100644 --- a/taichi/common/task.h +++ b/taichi/common/task.h @@ -14,7 +14,7 @@ TI_NAMESPACE_BEGIN class Task : public Unit { public: virtual std::string run(const std::vector ¶meters) { - assert_info(parameters.size() == 0, "No parameters supported."); + TI_ASSERT_INFO(parameters.size() == 0, "No parameters supported."); return this->run(); } diff --git a/taichi/common/util.h b/taichi/common/util.h index b3df178d57642..d65b984f55976 100644 --- a/taichi/common/util.h +++ b/taichi/common/util.h @@ -140,17 +140,15 @@ static_assert(__cplusplus >= 201402L, "C++14 required."); #define DEBUG_TRIGGER #endif -#define assert_info(x, info) \ +#define TI_STATIC_ASSERT(x) static_assert((x), #x); +#define TI_ASSERT(x) TI_ASSERT_INFO((x), #x) +#define TI_ASSERT_INFO(x, ...) \ { \ bool ___ret___ = static_cast(x); \ if (!___ret___) { \ - TI_ERROR(info); \ + TI_ERROR(__VA_ARGS__); \ } \ } - -#define TI_STATIC_ASSERT(x) static_assert((x), #x); -#define TI_ASSERT(x) TI_ASSERT_INFO((x), #x) -#define TI_ASSERT_INFO assert_info #define TI_NOT_IMPLEMENTED TI_ERROR("Not supported."); #define TI_NAMESPACE_BEGIN namespace taichi { diff --git a/taichi/math/scalar.h b/taichi/math/scalar.h index 2b606e542c3dc..0643dbfd8529e 100644 --- a/taichi/math/scalar.h +++ b/taichi/math/scalar.h @@ -159,7 +159,8 @@ TI_FORCE_INLINE bool abnormal(T m) noexcept { } inline int64 get_largest_pot(int64 a) noexcept { - assert_info(a > 0, "a should be positive, instead of " + std::to_string(a)); + TI_ASSERT_INFO(a > 0, + "a should be positive, instead of " + std::to_string(a)); // TODO: optimize int64 i = 1; while (i <= a / 2) { diff --git a/taichi/python/export_math.cpp b/taichi/python/export_math.cpp index acaccecaf6d6c..76e5fb2ca76c7 100644 --- a/taichi/python/export_math.cpp +++ b/taichi/python/export_math.cpp @@ -76,8 +76,8 @@ void array2d_to_ndarray(Array2D *arr, uint64 output) // 'output' is actually a pointer... { int width = arr->get_width(), height = arr->get_height(); - assert_info(width > 0, ""); - assert_info(height > 0, ""); + TI_ASSERT(width > 0); + TI_ASSERT(height > 0); for (auto &ind : arr->get_region()) { for (int k = 0; k < channels; k++) { const Vector3 entry = (*arr)[ind]; diff --git a/taichi/system/profiler.h b/taichi/system/profiler.h index 3c193b58113c7..c2cb42677f6f0 100644 --- a/taichi/system/profiler.h +++ b/taichi/system/profiler.h @@ -138,7 +138,7 @@ class Profiler { } void stop() { - assert_info(!stopped, "Profiler already stopped."); + TI_ASSERT_INFO(!stopped, "Profiler already stopped."); float64 elapsed = Time::get_time() - start_time; if ((int64)elements != -1) { ProfilerRecords::get_instance().insert_sample(elapsed, elements); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 70c8437a7053e..b35269265cce8 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -20,9 +20,8 @@ class TypeCheck : public IRVisitor { } void visit(AllocaStmt *stmt) { - // Do nothing. - // Alloca type is determined by the first LocalStore in IR visiting order, - // at compile time + // Do nothing. Alloca type is determined by the first LocalStore in IR + // visiting order, at compile time. // ret_type stands for its element type. stmt->ret_type.set_is_pointer(false); @@ -40,10 +39,9 @@ class TypeCheck : public IRVisitor { void visit(Block *stmt_list) { std::vector stmts; - // Make a copy since type casts may be inserted for type promotion + // Make a copy since type casts may be inserted for type promotion. for (auto &stmt : stmt_list->statements) { stmts.push_back(stmt.get()); - // stmt->accept(this); } for (auto stmt : stmts) stmt->accept(this); @@ -52,7 +50,7 @@ class TypeCheck : public IRVisitor { void visit(AtomicOpStmt *stmt) { TI_ASSERT(stmt->width() == 1); if (stmt->val->ret_type.data_type != stmt->dest->ret_type.data_type) { - TI_WARN("Atomic add ({} to {}) may lose precision.", + TI_WARN("[{}] Atomic add ({} to {}) may lose precision.", stmt->name(), data_type_name(stmt->val->ret_type.data_type), data_type_name(stmt->dest->ret_type.data_type)); stmt->val = insert_type_cast_before(stmt, stmt->val, @@ -84,10 +82,9 @@ class TypeCheck : public IRVisitor { } if (stmt->ptr->ret_type.data_type != common_container_type) { TI_WARN( - "Local store may lose precision (target = {}, value = {}, " - "stmt_id = {}) at", - stmt->ptr->ret_data_type_name(), stmt->data->ret_data_type_name(), - stmt->id); + "[{}] Local store may lose precision (target = {}, value = {}, at", + stmt->name(), stmt->ptr->ret_data_type_name(), + stmt->data->ret_data_type_name(), stmt->id); fmt::print(stmt->tb); } stmt->ret_type = stmt->ptr->ret_type; @@ -108,11 +105,11 @@ class TypeCheck : public IRVisitor { if (stmt->snodes) stmt->ret_type.data_type = stmt->snodes[0]->dt; else - TI_WARN("Type inference failed: snode is nullptr."); + TI_WARN("[{}] Type inference failed: snode is nullptr.", stmt->name()); for (int l = 0; l < stmt->snodes.size(); l++) { if (stmt->snodes[l]->parent->num_active_indices != 0 && stmt->snodes[l]->parent->num_active_indices != stmt->indices.size()) { - TI_ERROR("{} has {} indices. Indexed with {}.", + TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(), stmt->snodes[l]->parent->node_type_name, stmt->snodes[l]->parent->num_active_indices, stmt->indices.size()); @@ -121,10 +118,11 @@ class TypeCheck : public IRVisitor { for (int i = 0; i < stmt->indices.size(); i++) { TI_ASSERT_INFO( is_integral(stmt->indices[i]->ret_type.data_type), - "Taichi tensors must be accessed with integral indices (e.g., " + "[{}] Taichi tensors must be accessed with integral indices (e.g., " "i32/i64). It seems that you have used a float point number as " "an index. You can cast that to an integer using int(). Also note " - "that ti.floor(ti.f32) returns f32."); + "that ti.floor(ti.f32) returns f32.", + stmt->name()); TI_ASSERT(stmt->indices[i]->ret_type.width == stmt->snodes.size()); } } @@ -138,23 +136,16 @@ class TypeCheck : public IRVisitor { stmt->ptr->ret_type.data_type); } if (stmt->ptr->ret_type.data_type != promoted) { - TI_WARN("Global store may lose precision: {} <- {}, at", - stmt->ptr->ret_data_type_name(), input_type, stmt->tb); + TI_WARN("[{}] Global store may lose precision: {} <- {}, at", + stmt->name(), stmt->ptr->ret_data_type_name(), input_type, + stmt->tb); } stmt->ret_type = stmt->ptr->ret_type; } void visit(RangeForStmt *stmt) { - /* - TI_ASSERT(block->local_variables.find(stmt->loop_var) == - block->local_variables.end()); - */ mark_as_if_const(stmt->begin, VectorType(1, DataType::i32)); mark_as_if_const(stmt->end, VectorType(1, DataType::i32)); - /* - block->local_variables.insert( - std::make_pair(stmt->loop_var, VectorType(1, DataType::i32))); - */ stmt->body->accept(this); } @@ -173,13 +164,14 @@ class TypeCheck : public IRVisitor { } if (is_trigonometric(stmt->op_type) && !is_real(stmt->operand->ret_type.data_type)) { - TI_ERROR("Trigonometric operator takes real inputs only. At {}", - stmt->tb); + TI_ERROR("[{}] Trigonometric operator takes real inputs only. At {}", + stmt->name(), stmt->tb); } if ((stmt->op_type == UnaryOpType::floor || stmt->op_type == UnaryOpType::ceil) && !is_real(stmt->operand->ret_type.data_type)) { - TI_ERROR("floor/ceil takes real inputs only. At {}", stmt->tb); + TI_ERROR("[{}] floor/ceil takes real inputs only. At {}", stmt->name(), + stmt->tb); } } @@ -215,15 +207,17 @@ class TypeCheck : public IRVisitor { void visit(BinaryOpStmt *stmt) { auto error = [&](std::string comment = "") { if (comment == "") { - TI_WARN("Error: type mismatch (left = {}, right = {}, stmt_id = {}) at", - stmt->lhs->ret_data_type_name(), - stmt->rhs->ret_data_type_name(), stmt->id); + TI_WARN( + "[{}] Error: type mismatch (left = {}, right = {}, stmt_id = {}) " + "at", + stmt->name(), stmt->lhs->ret_data_type_name(), + stmt->rhs->ret_data_type_name(), stmt->id); } else { - TI_WARN(comment + " at"); + TI_WARN("[{}] {} at", stmt->name(), comment); } fmt::print(stmt->tb); TI_WARN("Compilation stopped due to type mismatch."); - exit(-1); + throw std::runtime_error("Binary operator type mismatch"); }; if (stmt->lhs->ret_type.data_type == DataType::unknown && stmt->rhs->ret_type.data_type == DataType::unknown) @@ -414,10 +408,6 @@ namespace irpass { void typecheck(IRNode *root) { analysis::check_fields_registered(root); TypeCheck::run(root); - // if (root->is() && root->as()->parent == nullptr) { - // fix_block_parents(root); // hot fix - // verify(root); - // } } } // namespace irpass diff --git a/taichi/util/image_buffer.cpp b/taichi/util/image_buffer.cpp index 034ad84289642..9d4dde18bef0b 100644 --- a/taichi/util/image_buffer.cpp +++ b/taichi/util/image_buffer.cpp @@ -21,15 +21,15 @@ template void Array2D::load_image(const std::string &filename, bool linearize) { int channels; FILE *f = fopen(filename.c_str(), "rb"); - assert_info(f != nullptr, "Image file not found: " + filename); + TI_ASSERT_INFO(f != nullptr, "Image file not found: " + filename); stbi_ldr_to_hdr_gamma(1.0_f); float32 *data = stbi_loadf(filename.c_str(), &this->res[0], &this->res[1], &channels, 0); - assert_info(data != nullptr, - "Image file load failed: " + filename + - " # Msg: " + std::string(stbi_failure_reason())); - assert_info(channels == 1 || channels == 3 || channels == 4, - "Image must have channel 1, 3 or 4: " + filename); + TI_ASSERT_INFO(data != nullptr, + "Image file load failed: " + filename + + " # Msg: " + std::string(stbi_failure_reason())); + TI_ASSERT_INFO(channels == 1 || channels == 3 || channels == 4, + "Image must have channel 1, 3 or 4: " + filename); this->initialize(Vector2i(this->res[0], this->res[1])); for (int i = 0; i < this->res[0]; i++) {