[IR] Support frontend type inference in simple cases (taichi-dev#3302)
* Add ret_type in Expression

* Fix Expr::Var

* Remove dt in test_matrix

* Support Const, ArgLoad, Rand, Id, BinaryOp; Add C++ tests; Remove dt in some Python tests

* Remove boilerplate test

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
strongoier and taichi-gardener authored Nov 1, 2021
1 parent 3b179fb commit 68bea56
Showing 12 changed files with 108 additions and 13 deletions.
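In user-facing terms, the commit lets the element type of a local vector/matrix be inferred from its entries in simple cases, so the explicit dt argument can be dropped, as the test updates below show. A minimal sketch of the new behavior (illustrative only; assumes a default ti.init() configuration):

    import taichi as ti
    ti.init()

    @ti.kernel
    def foo(k: ti.i32):
        v = ti.Vector([k, k * 2, k * 3])           # inferred as i32 (ArgLoad)
        u = ti.Vector([ti.random(), ti.random()])  # inferred as f32 (Rand)

    foo(10)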
7 changes: 7 additions & 0 deletions python/taichi/lang/matrix.py
@@ -4,6 +4,7 @@

import numpy as np
import taichi.lang
from taichi.core import ti_core
from taichi.lang import expr, impl
from taichi.lang import kernel_impl as kern_mod
from taichi.lang import ops as ops_mod
@@ -71,6 +72,12 @@ def __init__(self,
dt = impl.get_runtime().default_ip
elif isinstance(n[0], float):
dt = impl.get_runtime().default_fp
elif isinstance(n[0], expr.Expr):
dt = n[0].ptr.get_ret_type()
if dt == ti_core.DataType_unknown:
raise TypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
else:
raise Exception(
'dt required when using dynamic_index for local tensor'
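If the first element is an Expr whose type is still unknown, the constructor now raises TypeError rather than guessing; passing dt explicitly remains the fallback. A hedged one-liner of that fallback:

    v = ti.Vector([x, y], dt=ti.f32)  # still needed when x's type cannot be inferred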
5 changes: 5 additions & 0 deletions taichi/ir/expr.cpp
@@ -29,6 +29,10 @@ std::string Expr::get_attribute(const std::string &key) const {
return expr->get_attribute(key);
}

DataType Expr::get_ret_type() const {
return expr->ret_type;
}

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) {
return Expr::make<TernaryOpExpression>(TernaryOpType::select, cond, true_val,
false_val);
@@ -192,6 +196,7 @@ Expr Var(const Expr &x) {
std::static_pointer_cast<IdExpression>(var.expr)->id,
PrimitiveType::unknown));
var = x;
var->ret_type = x->ret_type;
return var;
}

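The extra assignment in Var() copies the initializer's inferred type onto the fresh IdExpression, which is what lets locally assigned variables participate in inference. Roughly, at the Python level (an illustrative sketch, not code from this commit):

    @ti.kernel
    def bar():
        x = -(1 << 20)         # IdExpression; ret_type copied from the i32 const
        v = ti.Vector([x, x])  # element type inferred as i32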
2 changes: 2 additions & 0 deletions taichi/ir/expr.h
@@ -107,6 +107,8 @@ class Expr {
void set_attribute(const std::string &key, const std::string &value);

std::string get_attribute(const std::string &key) const;

DataType get_ret_type() const;
};

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val);
1 change: 1 addition & 0 deletions taichi/ir/expression.h
@@ -14,6 +14,7 @@ class Expression {
Stmt *stmt;
std::string tb;
std::map<std::string, std::string> attributes;
DataType ret_type;

struct FlattenContext {
VecStatement stmts;
32 changes: 32 additions & 0 deletions taichi/ir/frontend_ir.cpp
@@ -128,6 +128,38 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
ctx->push_back(std::move(unary));
}

BinaryOpExpression::BinaryOpExpression(const BinaryOpType &type,
const Expr &lhs,
const Expr &rhs)
: type(type) {
this->lhs.set(load_if_ptr(lhs));
this->rhs.set(load_if_ptr(rhs));
auto lhs_type = this->lhs->ret_type;
auto rhs_type = this->rhs->ret_type;
// TODO: report error messages for unsuccessful inference
if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
return;
if (lhs_type == PrimitiveType::unknown || rhs_type == PrimitiveType::unknown)
return;
if (binary_is_bitwise(type) &&
(!is_integral(lhs_type) || !is_integral(rhs_type)))
return;
if (is_comparison(type)) {
ret_type = PrimitiveType::i32;
return;
}
if (type == BinaryOpType::truediv) {
auto default_fp = get_current_program().config.default_fp;
if (!is_real(lhs_type)) {
lhs_type = default_fp;
}
if (!is_real(rhs_type)) {
rhs_type = default_fp;
}
}
ret_type = promoted_type(lhs_type, rhs_type);
}

void BinaryOpExpression::flatten(FlattenContext *ctx) {
// if (stmt)
// return;
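For reference, the inference rules implemented by this constructor, restated as a Python-style sketch (helper names are hypothetical stand-ins for the C++ functions above):

    def infer_binary_ret_type(op, lhs_type, rhs_type, default_fp):
        # Give up (leave ret_type unknown) on non-primitive or unknown operands.
        if not is_primitive(lhs_type) or not is_primitive(rhs_type):
            return unknown
        if lhs_type == unknown or rhs_type == unknown:
            return unknown
        # Bitwise ops are only inferred for integral operands.
        if is_bitwise(op) and not (is_integral(lhs_type) and is_integral(rhs_type)):
            return unknown
        # Comparisons always produce i32.
        if is_comparison(op):
            return i32
        # True division promotes any non-real operand to the default float type.
        if op == 'truediv':
            if not is_real(lhs_type):
                lhs_type = default_fp
            if not is_real(rhs_type):
                rhs_type = default_fp
        # Otherwise apply the usual arithmetic type promotion.
        return promoted_type(lhs_type, rhs_type)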
11 changes: 6 additions & 5 deletions taichi/ir/frontend_ir.h
@@ -240,6 +240,7 @@ class ArgLoadExpression : public Expression {
DataType dt;

ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) {
ret_type = dt;
}

void serialize(std::ostream &ss) override {
@@ -254,6 +255,7 @@ class RandExpression : public Expression {
DataType dt;

RandExpression(DataType dt) : dt(dt) {
ret_type = dt;
}

void serialize(std::ostream &ss) override {
@@ -286,11 +288,9 @@ class BinaryOpExpression : public Expression {
BinaryOpType type;
Expr lhs, rhs;

-  BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
-      : type(type) {
-    this->lhs.set(load_if_ptr(lhs));
-    this->rhs.set(load_if_ptr(rhs));
-  }
+  BinaryOpExpression(const BinaryOpType &type,
+                     const Expr &lhs,
+                     const Expr &rhs);

void serialize(std::ostream &ss) override {
ss << '(';
@@ -677,6 +677,7 @@ class ConstExpression : public Expression {

template <typename T>
ConstExpression(const T &x) : val(x) {
ret_type = val.dt;
}

void serialize(std::ostream &ss) override {
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
@@ -465,6 +465,7 @@ void export_lang(py::module &m) {
})
.def("set_grad", &Expr::set_grad)
.def("set_attribute", &Expr::set_attribute)
.def("get_ret_type", &Expr::get_ret_type)
.def("get_expr_name",
[](Expr *expr) {
return expr->cast<GlobalVariableExpression>()->name;
44 changes: 44 additions & 0 deletions tests/cpp/ir/frontend_type_inference_test.cpp
@@ -0,0 +1,44 @@
#include "gtest/gtest.h"

#include "taichi/ir/frontend_ir.h"
#include "taichi/program/program.h"

namespace taichi {
namespace lang {

TEST(FrontendTypeInference, Const) {
auto const_i64 = Expr::make<ConstExpression, int64>(1LL << 63);
EXPECT_EQ(const_i64->ret_type, PrimitiveType::i64);
}

TEST(FrontendTypeInference, ArgLoad) {
auto arg_load_u64 = Expr::make<ArgLoadExpression>(2, PrimitiveType::u64);
EXPECT_EQ(arg_load_u64->ret_type, PrimitiveType::u64);
}

TEST(FrontendTypeInference, Rand) {
auto rand_f16 = Expr::make<RandExpression>(PrimitiveType::f16);
EXPECT_EQ(rand_f16->ret_type, PrimitiveType::f16);
}

TEST(FrontendTypeInference, Id) {
auto prog = std::make_unique<Program>(Arch::x64);
auto func = []() {};
auto kernel = std::make_unique<Kernel>(*prog, func, "fake_kernel");
Callable::CurrentCallableGuard _(kernel->program, kernel.get());
auto const_i32 = Expr::make<ConstExpression, int32>(-(1 << 20));
auto id_i32 = Var(const_i32);
EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32);
}

TEST(FrontendTypeInference, BinaryOp) {
auto prog = std::make_unique<Program>(Arch::x64);
prog->config.default_fp = PrimitiveType::f64;
auto const_i32 = Expr::make<ConstExpression, int32>(-(1 << 20));
auto const_f32 = Expr::make<ConstExpression, float32>(5.0);
auto truediv_f64 = expr_truediv(const_i32, const_f32);
EXPECT_EQ(truediv_f64->ret_type, PrimitiveType::f64);
}

} // namespace lang
} // namespace taichi
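A rough Python-level counterpart of the BinaryOp case above (illustrative; assumes literals are typed by the frontend as in the C++ test):

    import taichi as ti
    ti.init(default_fp=ti.f64)

    @ti.kernel
    def div(k: ti.i32):
        a = k / 5.0            # truediv: the integral side is promoted to default_fp
        v = ti.Vector([a, a])  # element type inferred as f64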
2 changes: 1 addition & 1 deletion tests/python/test_gc.py
@@ -23,7 +23,7 @@ def init():
for i in x:
x[i] = ti.Vector(
[ti.random() * 0.1 + 0.5,
-             ti.random() * 0.1 + 0.5], dt=ti.f32)
+             ti.random() * 0.1 + 0.5])

init()

12 changes: 7 additions & 5 deletions tests/python/test_matrix.py
@@ -190,7 +190,7 @@ def func2(b: ti.any_arr(element_dim=1, layout=ti.Layout.SOA)):
assert v[9][1] == 9


-@ti.test(require=ti.extension.dynamic_index, dynamic_index=True)
+@ti.test(require=ti.extension.dynamic_index, dynamic_index=True, debug=True)
def test_matrix_non_constant_index():
m = ti.Matrix.field(2, 2, ti.i32, 5)
v = ti.Vector.field(10, ti.i32, 5)
@@ -221,10 +221,10 @@ def func2():

@ti.kernel
def func3():
-    tmp = ti.Vector([1, 2, 3], dt=ti.i32)
+    tmp = ti.Vector([1, 2, 3])
for i in range(3):
tmp[i] = i * i
-    vec = ti.Vector([4, 5, 6], dt=ti.i32)
+    vec = ti.Vector([4, 5, 6])
for j in range(3):
vec[tmp[i] % 3] += vec[j % 3]
assert tmp[0] == 0
@@ -236,9 +236,11 @@ def func3():
@ti.kernel
def func4(k: ti.i32):
tmp = ti.Vector([k, k * 2, k * 3])
+        assert tmp[0] == k
+        assert tmp[1] == k * 2
+        assert tmp[2] == k * 3

-    with pytest.raises(Exception):
-        func4(10)
+    func4(10)
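Note: func4 previously had no way to infer the element type of [k, k * 2, k * 3] (k is an ArgLoad expression), so the call was expected to raise; with ArgLoadExpression now carrying ret_type, the kernel compiles, and the new asserts (checked under the added debug=True) verify the values directly.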


@ti.test(arch=ti.cpu)
2 changes: 1 addition & 1 deletion tests/python/test_random.py
@@ -75,7 +75,7 @@ def test_random_2d_dist():
@ti.kernel
def gen():
for i in range(n):
-        x[i] = ti.Vector([ti.random(), ti.random()], dt=ti.f32)
+        x[i] = ti.Vector([ti.random(), ti.random()])

gen()

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_ssa.py
@@ -40,7 +40,7 @@ def test_random_vector_dup_eval():

@ti.kernel
def func():
-    a[None] = ti.Vector([ti.random(), 1], dt=ti.f32).normalized()
+    a[None] = ti.Vector([ti.random(), 1]).normalized()

for i in range(4):
func()
