Commit 314272e: separate emit_extra_unary
yuanming-hu committed Oct 22, 2019 (1 parent: 6184d07)
Showing 4 changed files with 41 additions and 33 deletions.
lang/runtime/runtime.cpp (3 changes: 1 addition & 2 deletions)

@@ -54,10 +54,9 @@ int printf(const char *, ...);
     return std::F(x); \
   }
 
+// sin and cos are already included in llvm intrinsics
 DEFINE_UNARY_REAL_FUNC(exp)
 DEFINE_UNARY_REAL_FUNC(log)
-DEFINE_UNARY_REAL_FUNC(sin)
-DEFINE_UNARY_REAL_FUNC(cos)
 DEFINE_UNARY_REAL_FUNC(tan)
 DEFINE_UNARY_REAL_FUNC(tanh)
 DEFINE_UNARY_REAL_FUNC(abs)
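Editor's note: the DEFINE_UNARY_REAL_FUNC definition itself sits above this hunk, so only its tail (`return std::F(x);`) is visible. A plausible reconstruction, inferred from that tail and from the `_f32`/`_f64` symbol names that codegen_llvm.h resolves below, might look like this (the extern "C" linkage is an assumption, made so that get_runtime_function() could look the wrappers up by unmangled names):

    // Sketch only; the actual macro body lies outside the diff hunk.
    #include <cmath>

    #define DEFINE_UNARY_REAL_FUNC(F)        \
      extern "C" float F##_f32(float x) {    \
        return std::F(x);                    \
      }                                      \
      extern "C" double F##_f64(double x) {  \
        return std::F(x);                    \
      }

    DEFINE_UNARY_REAL_FUNC(exp)  // emits exp_f32 and exp_f64
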
lang/src/backends/codegen_llvm.h (66 changes: 37 additions & 29 deletions)

@@ -244,6 +244,42 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
         stmt->ret_data_type_name());
   }
 
+  virtual void emit_extra_unary(UnaryOpStmt *stmt) {
+    auto input = stmt->operand->value;
+    auto input_taichi_type = stmt->operand->ret_type.data_type;
+    auto input_type = input->getType();
+    auto op = stmt->op_type;
+
+#define UNARY_STD(x)                                                   \
+  else if (op == UnaryOpType::x) {                                     \
+    if (input_taichi_type == DataType::f32) {                          \
+      stmt->value =                                                    \
+          builder->CreateCall(get_runtime_function(#x "_f32"), input); \
+    } else if (input_taichi_type == DataType::f64) {                   \
+      stmt->value =                                                    \
+          builder->CreateCall(get_runtime_function(#x "_f64"), input); \
+    } else if (input_taichi_type == DataType::i32) {                   \
+      stmt->value =                                                    \
+          builder->CreateCall(get_runtime_function(#x "_i32"), input); \
+    } else {                                                           \
+      TC_NOT_IMPLEMENTED                                               \
+    }                                                                  \
+  }
+    if (false) {
+    }
+    UNARY_STD(abs)
+    UNARY_STD(exp)
+    UNARY_STD(log)
+    UNARY_STD(tan)
+    UNARY_STD(tanh)
+    UNARY_STD(sgn)
+    else {
+      TC_P(unary_op_type_name(op));
+      TC_NOT_IMPLEMENTED
+    }
+#undef UNARY_STD
+  }
+
   void visit(UnaryOpStmt *stmt) {
     auto input = stmt->operand->value;
     auto input_taichi_type = stmt->operand->ret_type.data_type;
@@ -265,40 +301,12 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
       stmt->value = builder->CreateFNeg(input, "neg");
     }
     UNARY_INTRINSIC(sin)
-    UNARY_INTRINSIC(sin)
     UNARY_INTRINSIC(cos)
     UNARY_INTRINSIC(sqrt)
     UNARY_INTRINSIC(floor)
     UNARY_INTRINSIC(ceil)
+    else emit_extra_unary(stmt);
 #undef UNARY_INTRINSIC
-#define UNARY_STD(x)                                                   \
-  else if (op == UnaryOpType::x) {                                     \
-    if (input_taichi_type == DataType::f32) {                          \
-      stmt->value =                                                    \
-          builder->CreateCall(get_runtime_function(#x "_f32"), input); \
-    } else if (input_taichi_type == DataType::f64) {                   \
-      stmt->value =                                                    \
-          builder->CreateCall(get_runtime_function(#x "_f64"), input); \
-    } else if (input_taichi_type == DataType::i32) {                   \
-      stmt->value =                                                    \
-          builder->CreateCall(get_runtime_function(#x "_i32"), input); \
-    } else {                                                           \
-      TC_NOT_IMPLEMENTED                                               \
-    }                                                                  \
-  }
-    UNARY_STD(abs)
-    UNARY_STD(exp)
-    UNARY_STD(log)
-    UNARY_STD(sin)
-    UNARY_STD(cos)
-    UNARY_STD(tan)
-    UNARY_STD(tanh)
-    UNARY_STD(sgn)
-#undef UNARY_STD
-    else {
-      TC_P(unary_op_type_name(op));
-      TC_NOT_IMPLEMENTED
-    }
   } else {
     // op = cast
     if (stmt->cast_by_value) {
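Editor's note: the UNARY_INTRINSIC macro used above is defined before this hunk, so its body is not visible here. Judging by the commit message and the new comment in runtime.cpp ("sin and cos are already included in llvm intrinsics"), each arm presumably lowers the op to the matching LLVM intrinsic, which is why sin and cos no longer need C++ runtime wrappers. An assumed shape, using the IRBuilder API:

    // Assumed shape only; the real macro body lies outside the hunk.
    #define UNARY_INTRINSIC(x)                                 \
      else if (op == UnaryOpType::x) {                         \
        stmt->value = builder->CreateIntrinsic(                \
            llvm::Intrinsic::x, {input->getType()}, {input});  \
      }
    // e.g. UNARY_INTRINSIC(sin) would emit a call to llvm.sin.f32 or
    // llvm.sin.f64, depending on the operand type.

Note also the empty `if (false) {}` in emit_extra_unary: it exists only so that every UNARY_STD arm can begin with `else if`.
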
lang/src/taichi_llvm_context.cpp (2 changes: 1 addition & 1 deletion)

@@ -143,7 +143,7 @@ std::unique_ptr<llvm::Module> TaichiLLVMContext::clone_runtime_module() {

   bool failed = llvm::Linker::linkModules(
       *runtime_module, llvm::CloneModule(*libdevice_module));
-  runtime_module->print(llvm::errs(), nullptr);
+  // runtime_module->print(llvm::errs(), nullptr);
   if (failed) {
     TC_ERROR("CUDA libdevice linking failure.");
   }
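Editor's note: since emit_extra_unary is declared virtual and this file links CUDA's libdevice into the runtime module, one plausible motivation is letting a GPU backend override the lowering to call libdevice symbols instead of the CPU runtime wrappers. A hypothetical sketch follows; the class name and the use of get_runtime_function for libdevice lookup are assumptions, though __nv_expf is the actual libdevice name for single-precision exp:

    // Illustrative only, not code from this commit.
    class CodeGenLLVMCUDA : public CodeGenLLVM {
      void emit_extra_unary(UnaryOpStmt *stmt) override {
        auto input = stmt->operand->value;
        if (stmt->op_type == UnaryOpType::exp &&
            stmt->operand->ret_type.data_type == DataType::f32) {
          // Resolved from the libdevice module linked in above.
          stmt->value =
              builder->CreateCall(get_runtime_function("__nv_expf"), input);
        } else {
          // Anything else falls back to the CPU lowering.
          CodeGenLLVM::emit_extra_unary(stmt);
        }
      }
    };
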
tests/python/test_ad_basics.py (3 changes: 2 additions & 1 deletion)

@@ -6,6 +6,7 @@

 @ti.program_test
 def grad_test(tifunc, npfunc=None):
+    ti.cfg.print_ir = True
     if npfunc is None:
         npfunc = tifunc

@@ -54,9 +55,9 @@ def test_poly():
     grad_test(lambda x: (x - 3) * (x - 1) + x * x)
 
 def test_trigonometric():
-    grad_test(lambda x: ti.tanh(x), lambda x: np.tanh(x))
     grad_test(lambda x: ti.sin(x), lambda x: np.sin(x))
     grad_test(lambda x: ti.cos(x), lambda x: np.cos(x))
+    grad_test(lambda x: ti.tanh(x), lambda x: np.tanh(x))
 
 def test_frac():
     grad_test(lambda x: 1 / x)