[te] Add BitCast to the IR (pytorch#49184)
Summary:
Pull Request resolved: pytorch#49184

Adds BitCast to NNC. This will enable fast approximation algorithms to be implemented directly in TensorExpressions.
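
For a concrete flavor of the "fast approximation" use case, here is a minimal sketch (not taken from this diff) of the classic reciprocal-square-root initial guess written with the new bitcast<T> helper. It assumes integer subtraction and right-shift operators are available on ExprHandle:

  // Hypothetical sketch: bit-level rsqrt first guess as a TensorExpr compute.
  // Assumes ExprHandle supports operator- and operator>> on integer expressions.
  KernelScope kernel_scope;
  Placeholder a(BufHandle("A", {128}, kFloat));
  Tensor* rsqrt_guess = Compute("rsqrt_guess", {{128, "i"}}, [&](const VarHandle& i) {
    ExprHandle x = a.load(i);
    // Reinterpret the float bits as int32, apply the magic-constant trick,
    // then reinterpret the adjusted bits back as a float.
    ExprHandle bits = bitcast<int32_t>(x);
    ExprHandle guess_bits = IntImm::make(0x5f3759df) - (bits >> IntImm::make(1));
    return bitcast<float>(guess_bits);
  });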

Test Plan: buck test mode/no-gpu //caffe2/test/cpp/tensorexpr:tensorexpr

Reviewed By: bertmaher

Differential Revision: D25466476

fbshipit-source-id: f063ab29ba7bab2dcce463e499f2d4a16bdc1f0e
bwasti authored and facebook-github-bot committed Dec 12, 2020
1 parent 5716b7d commit 6b78644
Showing 11 changed files with 327 additions and 0 deletions.
83 changes: 83 additions & 0 deletions test/cpp/tensorexpr/test_llvm.cpp
@@ -160,6 +160,63 @@ TEST(LLVM, ByteToDoubleCastTest) {
ASSERT_EQ(cg.value<double>(), 2);
}

TEST(LLVM, BitCast) {
constexpr int16_t ref16 = 1337;
constexpr int32_t ref32 = 1337;
constexpr int64_t ref64 = 1337;
at::Half reff16 = 1337.0f;
constexpr float reff32 = 1337.0f;
constexpr double reff64 = 1337.0f;

// this is broken
/*{
KernelScope kernel_scope;
at::Half k_;
at::Half* k = &k_;
*reinterpret_cast<int16_t*>(k) = ref16;
auto a = HalfImm::make(k);
auto b = BitCast::make(kShort, a);
LLVMExprEval cg(b);
ASSERT_EQ(cg.value<int16_t>(), ref16);
}*/

{
KernelScope kernel_scope;
float k = raw_bitcast<float>(ref32);
auto a = FloatImm::make(k);
auto b = BitCast::make(kInt, a);
LLVMExprEval cg(b);
ASSERT_EQ(cg.value<int32_t>(), ref32);
}

{
KernelScope kernel_scope;
double k = raw_bitcast<double>(ref64);
auto a = DoubleImm::make(k);
auto b = BitCast::make(kLong, a);
LLVMExprEval cg(b);
ASSERT_EQ(cg.value<int64_t>(), ref64);
}

{
KernelScope kernel_scope;
int64_t k = raw_bitcast<int64_t>(reff64);
auto a = LongImm::make(k);
auto b = BitCast::make(kDouble, a);
LLVMExprEval cg(b);
ASSERT_EQ(cg.value<double>(), reff64);
}

{
KernelScope kernel_scope;
int32_t k = raw_bitcast<int32_t>(reff32);
auto a = IntImm::make(k);
auto b = BitCast::make(kFloat, a);
LLVMExprEval cg(b);
ASSERT_EQ(cg.value<float>(), reff32);
}
}

TEST(LLVM, LetTest01) {
KernelScope kernel_scope;

@@ -514,6 +571,32 @@ TEST(LLVM, VectorizerLoadStoreTest) {
assertAllEqual(c_vec, 21);
}

TEST(LLVM, VectorizeBitCast) {
KernelScope kernel_scope;
Placeholder a(BufHandle("A", {128}, kInt));

Tensor* c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) {
return bitcast<float>(a.load(i));
});

Placeholder c_buf(BufHandle(c->buf()));
LoopNest l({c});
Stmt* s = l.root_stmt();
l.vectorize(dynamic_cast<Block*>(s)->front());
ASSERT_TRUE(dynamic_cast<For*>(dynamic_cast<Block*>(s)->front()) == nullptr);

LLVMCodeGen cg(s, {a, c_buf});

std::vector<int> a_vec(128);
std::vector<float> c_vec(128);
for (auto i = 0; i < 128; ++i) {
a_vec[i] = raw_bitcast<int>(1337.f);
}
std::vector<void*> args({a_vec.data(), c_vec.data()});
ASSERT_EQ(cg.value<int>(args), 0);
assertAllEqual(c_vec, 1337.f);
}

TEST(LLVM, MemcpyTest) {
KernelScope kernel_scope;
constexpr int N = 32;
110 changes: 110 additions & 0 deletions test/cpp/tensorexpr/test_type.cpp
@@ -1,5 +1,6 @@
#include <gtest/gtest.h>

#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include "torch/csrc/jit/tensorexpr/tensor.h"

@@ -42,6 +43,115 @@ TEST(Type, Test01) {
}
}

TEST(Type, BitCasting) {
{
KernelScope kernel_scope;
VarHandle x("x", kFloat);
ExprHandle y = bitcast<int32_t>(x);
ASSERT_EQ(y.dtype(), kInt);
}
{
KernelScope kernel_scope;
VarHandle x("x", kInt);
ExprHandle y = bitcast<float>(x);
ASSERT_EQ(y.dtype(), kFloat);
}
{
KernelScope kernel_scope;
VarHandle x("x", kShort);
ExprHandle y = bitcast<at::Half>(x);
ASSERT_EQ(y.dtype(), kHalf);
}
{
KernelScope kernel_scope;
VarHandle x("x", kHalf);
ExprHandle y = bitcast<int16_t>(x);
ASSERT_EQ(y.dtype(), kShort);
}

constexpr int16_t ref16 = 1337;
constexpr int32_t ref32 = 1337;
constexpr int64_t ref64 = 1337;
at::Half reff16 = 1337.0f;
constexpr float reff32 = 1337.0f;
constexpr double reff64 = 1337.0f;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
// this is broken
/*{
KernelScope kernel_scope;
at::Half k_;
at::Half* k = &k_;
*reinterpret_cast<int16_t*>(k) = ref16;
auto a = HalfImm::make(*k);
auto b = BitCast::make(kShort, a);
SimpleIRExprEval cg(b);
ASSERT_EQ(cg.value<int16_t>(), ref16);
}*/

{
KernelScope kernel_scope;
float k = raw_bitcast<float>(ref32);
auto a = FloatImm::make(k);
auto b = BitCast::make(kInt, a);
SimpleIRExprEval cg(b);
ASSERT_EQ(cg.value<int32_t>(), ref32);
}

{
KernelScope kernel_scope;
double k = raw_bitcast<double>(ref64);
auto a = DoubleImm::make(k);
auto b = BitCast::make(kLong, a);
SimpleIRExprEval cg(b);
ASSERT_EQ(cg.value<int64_t>(), ref64);
}

{
KernelScope kernel_scope;
int64_t k = raw_bitcast<int64_t>(reff64);
auto a = LongImm::make(k);
auto b = BitCast::make(kDouble, a);
SimpleIRExprEval cg(b);
ASSERT_EQ(cg.value<double>(), reff64);
}

{
KernelScope kernel_scope;
int32_t k = raw_bitcast<int32_t>(reff32);
auto a = IntImm::make(k);
auto b = BitCast::make(kFloat, a);
SimpleIRExprEval cg(b);
ASSERT_EQ(cg.value<float>(), reff32);
}

// This segfaults :(
/*{
KernelScope kernel_scope;
VarHandle x("x", kDouble);
ASSERT_ANY_THROW(ExprHandle y = bitcast<int32_t>(x));
}
{
KernelScope kernel_scope;
VarHandle x("x", kFloat);
ASSERT_ANY_THROW(ExprHandle y = bitcast<int64_t>(x));
}
{
KernelScope kernel_scope;
VarHandle x("x", kLong);
ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
}
{
KernelScope kernel_scope;
VarHandle x("x", kShort);
ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
}
{
KernelScope kernel_scope;
VarHandle x("x", kInt);
ASSERT_ANY_THROW(ExprHandle y = bitcast<at::Half>(x));
}*/
}

TEST(Type, Propagation) {
// Same types:
{
60 changes: 60 additions & 0 deletions torch/csrc/jit/tensorexpr/eval.h
@@ -1,6 +1,7 @@
#pragma once

#include <cmath>
#include <cstring>
#include <unordered_map>
#include <vector>

@@ -124,6 +125,14 @@ inline c10::Half div_value(c10::Half lhs, c10::Half rhs) {
return lhs / rhs;
}

template <typename To, typename From>
To raw_bitcast(const From& src) {
TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation");
To storage;
std::memcpy(&storage, &src, sizeof(From));
return reinterpret_cast<To&>(storage);
}

class SimpleIREvaluator : public CodeGen, public IRVisitor {
public:
template <typename... Ts>
@@ -573,6 +582,57 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
}
}

template <typename SrcType, typename DstType>
std::vector<DstType> bitcastValues(const Dtype& src_dtype, const Value& v) {
const std::vector<SrcType>& src_values = v.as_vec<SrcType>();
std::vector<DstType> dst_values(src_values.size());
for (int i = 0; i < src_dtype.lanes(); ++i) {
dst_values[i] = raw_bitcast<DstType>(src_values[i]);
}
return dst_values;
}

template <typename SrcType>
void doBitCastFromSrc(
const Dtype& src_dtype,
const Dtype& dst_dtype,
const Value& v) {
switch (dst_dtype.scalar_type()) {
#define DST_TYPE_CASE(Type, Name) \
case ScalarType::Name: \
this->value_ = Value(bitcastValues<SrcType, Type>(src_dtype, v)); \
break;
// bool/half not supported
AT_FORALL_SCALAR_TYPES(DST_TYPE_CASE);
#undef DST_TYPE_CASE
default:
throw unsupported_dtype();
}
}

TORCH_API void visit(const BitCast* v) override {
const Expr* src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value->dtype();
if (src_dtype.byte_size() != dst_dtype.byte_size()) {
throw malformed_input("lane mismatch in Cast", v);
}
if (src_dtype != dst_dtype) {
switch (src_dtype.scalar_type()) {
#define SRC_TYPE_CASE(Type, Name) \
case ScalarType::Name: \
doBitCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
break;
// bool/half not supported
AT_FORALL_SCALAR_TYPES(SRC_TYPE_CASE);
#undef SRC_TYPE_CASE
default:
throw unsupported_dtype();
}
}
}

TORCH_API void visit(const For* v) override {
const Expr* var_node = v->var();
v->start()->accept(this);
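
A side note on the raw_bitcast helper added to eval.h above: it goes through std::memcpy rather than a pointer cast because memcpy-based type punning is the well-defined way in C++ to reinterpret an object's bytes as a different type of the same size. A minimal standalone illustration of the same idiom (plain C++, hypothetical names, not part of this diff):

  #include <cassert>
  #include <cstdint>
  #include <cstring>

  // Same idiom as raw_bitcast: copy the bytes of one object into an object
  // of a different, equally sized type.
  template <typename To, typename From>
  To bit_copy(const From& src) {
    static_assert(sizeof(To) == sizeof(From), "size mismatch");
    To dst;
    std::memcpy(&dst, &src, sizeof(From));
    return dst;
  }

  int main() {
    float f = 1337.0f;
    std::uint32_t bits = bit_copy<std::uint32_t>(f);
    assert(bit_copy<float>(bits) == 1337.0f);  // the round trip is exact
    return 0;
  }
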
1 change: 1 addition & 0 deletions torch/csrc/jit/tensorexpr/expr.h
@@ -31,6 +31,7 @@ enum IRNodeType {
kCompareSelect,
kLet,
kCast,
kBitCast,
kBroadcast,
kRamp,
kPolynomial,
29 changes: 29 additions & 0 deletions torch/csrc/jit/tensorexpr/ir.h
@@ -28,6 +28,7 @@ inline int getPrecedence(IRNodeType ty) {
case kPrimitive:
return 0;
case kCast:
case kBitCast:
return 2;
case kAdd:
case kSub:
@@ -81,6 +82,34 @@ ExprHandle cast(const ExprHandle& src_value) {
return Cast::make(Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
}

// This is a bitwise cast, akin to bitcast in LLVM
class BitCast : public ExprNode<BitCast> {
public:
const Expr* src_value() const {
return src_value_;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
return ExprHandle(new BitCast(dtype, src_value.node()));
}
BitCast(Dtype dtype, const Expr* src_value)
: ExprNodeBase(dtype, kBitCast), src_value_(src_value) {
TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
}

bool isConstant() const override {
return src_value_->isConstant();
}

private:
const Expr* src_value_;
};

template <typename T>
ExprHandle bitcast(const ExprHandle& src_value) {
return BitCast::make(
Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
}

// Represent the expression node for binary operators.
// A CRTP pattern to share common code among the operators.
template <typename Op>
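
Two notes on the ir.h additions above: BitCast::make takes an explicit destination Dtype, while the bitcast<T>() wrapper re-derives the lane count from the source expression, so it composes with vectorized expressions; and the constructor's TORCH_CHECK enforces equal byte sizes between source and destination. A short usage sketch (not from this diff), mirroring the tests:

  KernelScope kernel_scope;
  VarHandle x("x", kFloat);
  ExprHandle via_wrapper = bitcast<int32_t>(x);  // Dtype(kInt, x.dtype().lanes())
  ExprHandle via_node = BitCast::make(kInt, x);  // explicit destination dtype
  // Both yield a kInt-typed expression over the same bits; a width mismatch
  // such as BitCast::make(kLong, x) trips the TORCH_CHECK in the constructor.
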
9 changes: 9 additions & 0 deletions torch/csrc/jit/tensorexpr/ir_mutator.cpp
@@ -139,6 +139,15 @@ const Expr* IRMutator::mutate(const Cast* v) {
return new Cast(v->dtype(), src_value_new);
}

const Expr* IRMutator::mutate(const BitCast* v) {
const Expr* src_value = v->src_value();
const Expr* src_value_new = src_value->accept_mutator(this);
if (src_value_new == v->src_value()) {
return v;
}
return new BitCast(v->dtype(), src_value_new);
}

const Expr* IRMutator::mutate(const Var* v) {
return v;
}
2 changes: 2 additions & 0 deletions torch/csrc/jit/tensorexpr/ir_mutator.h
@@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
#undef IMM_DECLARE

class Cast;
class BitCast;
class Var;
class Buf;
class Ramp;
@@ -75,6 +76,7 @@ class TORCH_API IRMutator {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
virtual const Expr* mutate(const Cast* v);
virtual const Expr* mutate(const BitCast* v);
virtual const Expr* mutate(const Var* v);
virtual const Expr* mutate(const Buf* v);
virtual const Expr* mutate(const Ramp* v);
3 changes: 3 additions & 0 deletions torch/csrc/jit/tensorexpr/ir_visitor.cpp
@@ -79,6 +79,9 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
void IRVisitor::visit(const Cast* v) {
v->src_value()->accept(this);
}
void IRVisitor::visit(const BitCast* v) {
v->src_value()->accept(this);
}
void IRVisitor::visit(const Var* v) {}

void IRVisitor::visit(const Ramp* v) {