[static runtime] use NNC to generate logit, relu and tanh (pytorch#52322)

Summary:
Pull Request resolved: pytorch#52322
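
This switches the static runtime's aten::logit, aten::relu, and aten::tanh implementations to NNC-generated kernels, JIT-compiled with LLVM, for contiguous float tensors; other inputs fall back to the existing ATen out-variants.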

diff BS=1
```
C2 run finished. Milliseconds per iter: 0.0564008. Iters per second: 17730.3
PyTorch run finished. Milliseconds per iter: 0.0677778. Iters per second: 14754.1
```
diff BS=20
```
C2 run finished. Milliseconds per iter: 0.51086. Iters per second: 1957.48
PyTorch run finished. Milliseconds per iter: 0.510077. Iters per second: 1960.49
```

master BS=1
```
C2 run finished. Milliseconds per iter: 0.0567362. Iters per second: 17625.4
PyTorch run finished. Milliseconds per iter: 0.0706478. Iters per second: 14154.7
```

master BS=20
```
C2 run finished. Milliseconds per iter: 0.510943. Iters per second: 1957.17
PyTorch run finished. Milliseconds per iter: 0.516338. Iters per second: 1936.72
```
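
Read against master: at BS=1 the PyTorch static-runtime latency improves from 0.0706 to 0.0678 ms/iter (roughly 4%), and at BS=20 from 0.5163 to 0.5101 ms/iter (roughly 1%); the C2 baseline is unchanged, as expected.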

Reviewed By: bertmaher

Differential Revision: D25407106

fbshipit-source-id: 08595ba5e4be59e2ef95fb9b24da7e7671692395
bwasti authored and facebook-github-bot committed Feb 17, 2021
1 parent 4156588 commit fa393b5
torch/csrc/jit/runtime/static/ops.cpp: 178 additions & 19 deletions
@@ -9,6 +9,10 @@
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>

namespace at {
namespace native {
@@ -266,17 +270,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
at::native::_cat_out_cpu(out_t, in0_tl, in1_i);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::tanh_out(out_t, in0_t);
};
});

// Split out into a function to appease MSVC's pre-processor
SROperator aten_stack(Node* n) {
@@ -335,30 +328,196 @@ REGISTER_OPERATOR_FUNCTOR(
};
}
});

namespace {

#ifdef TORCH_ENABLE_LLVM

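// TEWrapper owns the TensorExpr kernel arena/scope and the JIT-compiled LLVM
// kernel; operator() packs its arguments into CodeGen::CallArgs and invokes
// the generated code. supports() gates the fast path to contiguous float
// tensors.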
struct TEWrapper {
tensorexpr::KernelArena ka;
tensorexpr::KernelScope ks;
std::unique_ptr<tensorexpr::LLVMCodeGen> cg;
TEWrapper() = default;
void update(std::unique_ptr<tensorexpr::LLVMCodeGen>&& cg_) {
cg = std::move(cg_);
}
template <typename... Ts>
void operator()(const Ts&... ts) {
std::vector<tensorexpr::CodeGen::CallArg> args(
{tensorexpr::CodeGen::CallArg(ts)...});
cg->call(args);
}

inline bool supports(const at::Tensor& t) {
return t.is_contiguous() && t.dtype().Match<float>();
}
};

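// Pointwise schedule: split the flat loop into 128-element (16 * 8) blocks
// and vectorize each block, then split the block loop by 8 and unroll the
// resulting inner loop.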
void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target) {
using namespace torch::jit::tensorexpr;
std::vector<For*> loops = ln->getLoopStmtsFor(target);
For *outer, *inner, *tail;
ln->splitWithTail(loops[0], 16 * 8, &outer, &inner, &tail);
ln->vectorize(inner);
ln->splitWithTail(outer, 8, &outer, &inner, &tail);
Stmt* unrolled;
LoopNest::unroll(inner, &unrolled);
}

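// Lowers the tensor expression to a loop nest, applies the pointwise
// schedule, simplifies the IR, and JIT-compiles it with LLVM. The generated
// kernel's argument order is (out, in, dim).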
std::shared_ptr<TEWrapper> wrapTECompute(
std::shared_ptr<TEWrapper> wrap,
tensorexpr::Placeholder& in,
tensorexpr::Tensor* out,
tensorexpr::VarHandle& dim) {
using namespace torch::jit::tensorexpr;
LoopNest ln({out});
optimizePointwise(&ln, out);
ln.prepareForCodegen();
Stmt* s = ln.root_stmt();
s = tensorexpr::IRSimplifier::simplify(s);
std::vector<CodeGen::BufferArg> args;
args.emplace_back(out);
args.emplace_back(in);
args.emplace_back(dim);
auto cg = std::make_unique<LLVMCodeGen>(s, args);
wrap->update(std::move(cg));
return wrap;
};

#else

struct TEWrapper {
TEWrapper() = default;
template <typename... Ts>
void operator()(const Ts&... ts) {
DCHECK(0 && "Invalid call");
}

inline bool supports(const at::Tensor& t) {
return false;
}
};

std::shared_ptr<TEWrapper> wrapTECompute(
std::shared_ptr<TEWrapper> wrap,
tensorexpr::Placeholder& in,
tensorexpr::Tensor* out,
tensorexpr::VarHandle& dim) {
return wrap;
};

#endif

} // namespace

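// logit(x) = log(x / (1 - x)); when a clamp value eps is given, x is first
// clamped to [eps, 1 - eps] before the transform.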
std::shared_ptr<TEWrapper> createLogit(c10::optional<float> clamp) {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
auto A_elem = [&]() {
if (!clamp) {
return A.load(i);
} else {
auto elem = A.load(i);
auto min = FloatImm::make(*clamp);
auto max = FloatImm::make(1.0f - *clamp);
return ifThenElse(elem < min, min, ifThenElse(elem > max, max, elem));
}
}();
return fast_log(A_elem / (FloatImm::make(1.0f) - A_elem));
});
return wrapTECompute(wrap, A, B, N);
}

std::shared_ptr<TEWrapper> createRelu() {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
auto zero = FloatImm::make(0.f);
auto a = A.load(i);
return ifThenElse(a < zero, zero, a);
});
return wrapTECompute(wrap, A, B, N);
}

std::shared_ptr<TEWrapper> createTanh() {
using namespace torch::jit::tensorexpr;
auto wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
Placeholder A("A", kFloat, {N});
tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) {
auto a = A.load(i);
return fast_tanh(a);
});
return wrapTECompute(wrap, A, B, N);
}

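// The functors below build and compile the NNC kernel once, when the
// SROperator is created for a node, and capture it in the returned lambda;
// at run time they use it for supported (contiguous float) tensors and fall
// back to the ATen out-variant otherwise.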
REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto te = createRelu();
return [te](ProcessedNode* p_node) {
auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::threshold_out(out_t, in0_t, 0, 0);
if (!te->supports(in0_t)) {
fastResizeToZero(out_t);
at::native::threshold_out(out_t, in0_t, 0, 0);
} else {
at::native::resize_as_(out_t, in0_t, c10::nullopt);
(*te)(out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});

REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
auto te = createTanh();
return [te](ProcessedNode* p_node) {
auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
if (!te->supports(in0_t)) {
out_t.resize_({0});
at::native::tanh_out(out_t, in0_t);
} else {
out_t.resize_as_(in0_t);
(*te)(out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});

REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
c10::optional<float> clamp;
if (n->inputs().size() > 1) {
TORCH_CHECK(n->inputs().at(1)->node()->kind() == prim::Constant);
clamp = toIValue(n->inputs().at(1))->toDouble();
}
auto te = createLogit(clamp);
return [te](ProcessedNode* p_node) {
auto& in0_t = p_node->Input(0).toTensor();
double in1_d =
p_node->inputs().size() > 1 ? p_node->Input(1).toDouble() : -1.0;
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::logit_out(out_t, in0_t, in1_d);
if (!te->supports(in0_t)) {
auto in0_t = p_node->Input(0).toTensor();
double in1_d =
p_node->inputs().size() > 1 ? p_node->Input(1).toDouble() : -1.0;
fastResizeToZero(out_t);
at::native::logit_out(out_t, in0_t, in1_d);
} else {
out_t.resize_as_(in0_t);
(*te)(out_t.data_ptr<float>(), in0_t.data_ptr<float>(), in0_t.numel());
}
};
});

REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto& in0_t = p_node->Input(0).toTensor();

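For context, a minimal sketch (not part of the commit) of how the new helpers fit together, written as if it lived in ops.cpp next to the code above; createTanh and TEWrapper are the helpers from this diff, and TORCH_ENABLE_LLVM is assumed to be defined:

```cpp
// Sketch only: drive the NNC fast path the same way the registered functors do.
void tanh_example() {
  auto te = createTanh();             // build the expression; JIT-compile once
  at::Tensor in = at::randn({1024});  // contiguous float tensor -> supported
  at::Tensor out = at::empty_like(in);
  if (te->supports(in)) {
    // The generated kernel takes (out_ptr, in_ptr, numel); see wrapTECompute.
    (*te)(out.data_ptr<float>(), in.data_ptr<float>(), in.numel());
  } else {
    at::native::tanh_out(out, in);    // ATen fallback, as in the functor above
  }
}
```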