Skip to content

Commit

Permalink
[type] [bug] Add rounding for atomic adding of CustomFloatType (taich…
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanke98 authored Dec 15, 2020
1 parent dba8d6b commit 32eb284
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 34 deletions.
63 changes: 29 additions & 34 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,21 +1039,39 @@ llvm::Value *CodeGenLLVM::atomic_add_custom_float(AtomicOpStmt *stmt,
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[stmt->dest], byte_ptr, bit_offset);
auto cit = cft->get_digits_type()->as<CustomIntType>();
auto val_store = float_to_custom_int(cft, cit, llvm_val[stmt->val]);
auto physical_type = cit->get_physical_type();
auto compute_type = cft->get_compute_type();

auto s = builder->CreateFPCast(
llvm::ConstantFP::get(*llvm_context,
llvm::APFloat(1.0 / cft->get_scale())),
llvm_type(compute_type));
auto val_scaled = builder->CreateFMul(llvm_val[stmt->val], s);
auto val_scaled_int =
builder->CreateFPToSI(val_scaled, llvm_type(physical_type));

return create_call(
fmt::format("atomic_add_partial_bits_b{}", data_type_bits(physical_type)),
{builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type)),
bit_offset, tlctx->get_constant(cit->get_num_bits()), val_scaled_int});
bit_offset, tlctx->get_constant(cit->get_num_bits()), val_store});
}

llvm::Value *CodeGenLLVM::float_to_custom_int(CustomFloatType *cft,
CustomIntType *cit,
llvm::Value *real) {
llvm::Value *s = nullptr;

// Compute int(input * (1.0 / scale) + 0.5)
auto s_numeric = 1.0 / cft->get_scale();
auto compute_type = cft->get_compute_type();
s = builder->CreateFPCast(
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(s_numeric)),
llvm_type(compute_type));
auto input_real = builder->CreateFPCast(real, llvm_type(compute_type));
auto scaled = builder->CreateFMul(input_real, s);

// Add/minus the 0.5 offset for rounding
scaled = create_call(
fmt::format("rounding_prepare_f{}", data_type_bits(compute_type)),
{scaled});

if (cit->get_is_signed()) {
return builder->CreateFPToSI(scaled, llvm_type(cit->get_compute_type()));
} else {
return builder->CreateFPToUI(scaled, llvm_type(cit->get_compute_type()));
}
}

void CodeGenLLVM::visit(AtomicOpStmt *stmt) {
Expand Down Expand Up @@ -1168,30 +1186,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
store_value = llvm_val[stmt->data];
} else if (auto cft = pointee_type->cast<CustomFloatType>()) {
cit = cft->get_digits_type()->as<CustomIntType>();
llvm::Value *s = nullptr;

// Compute int(input * (1.0 / scale) + 0.5)
auto s_numeric = 1.0 / cft->get_scale();
auto compute_type = cft->get_compute_type();
s = builder->CreateFPCast(
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(s_numeric)),
llvm_type(compute_type));
auto input_real =
builder->CreateFPCast(llvm_val[stmt->data], llvm_type(compute_type));
auto scaled = builder->CreateFMul(input_real, s);

// Add/minus the 0.5 offset for rounding
scaled = create_call(
fmt::format("rounding_prepare_f{}", data_type_bits(compute_type)),
{scaled});

if (cit->get_is_signed()) {
store_value =
builder->CreateFPToSI(scaled, llvm_type(cit->get_compute_type()));
} else {
store_value =
builder->CreateFPToUI(scaled, llvm_type(cit->get_compute_type()));
}
store_value = float_to_custom_int(cft, cit, llvm_val[stmt->data]);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down
4 changes: 4 additions & 0 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *atomic_add_custom_float(AtomicOpStmt *stmt,
CustomFloatType *cft);

llvm::Value *float_to_custom_int(CustomFloatType *cft,
CustomIntType *cit,
llvm::Value *real);

void visit(AtomicOpStmt *stmt) override;

void visit(GlobalPtrStmt *stmt) override;
Expand Down

0 comments on commit 32eb284

Please sign in to comment.