[PyTorch] Return to one CachingGraphRunner per pt node, Refactor symbol registration (pytorch#3695)

Summary:
Pull Request resolved: pytorch#3695

Sharing a single CachingGraphRunner across multiple fused Glow nodes means no graph is available for any of the nodes, so this change returns to one CachingGraphRunner per PyTorch node.
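A minimal sketch of the per-node setup this restores (the constructor and run() signatures come from the diff below; the wrapper function and its name are illustrative only):

// Sketch only: one CachingGraphRunner per fused PyTorch node, caching
// compiled Glow functions keyed by a hash of the runtime inputs.
#include "CachingGraphRunner.h"

#include <memory>

std::shared_ptr<glow::CachingGraphRunner>
makeRunnerForFusionNode(std::shared_ptr<torch::jit::Graph> subgraph,
                        std::shared_ptr<glow::runtime::HostManager> hostManager) {
  auto runner =
      std::make_shared<glow::CachingGraphRunner>(subgraph, hostManager);
  // Later, the JIT execution hook calls runner->run(stack); the first run
  // for a given input hash compiles a Glow function, and subsequent runs
  // hit the cache.
  return runner;
}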

Documentation: Doxygen

Reviewed By: qizzzh

Differential Revision: D18204370

fbshipit-source-id: f3ec57e7f142c45b51b074dc5e8336e7be3895f1
jackm321 authored and facebook-github-bot committed Oct 29, 2019
1 parent 4a09831 commit a82bdc3
Showing 14 changed files with 297 additions and 254 deletions.
3 changes: 2 additions & 1 deletion torch_glow/src/CMakeLists.txt
@@ -14,9 +14,10 @@ link_directories(${PYTORCH_DIR}/lib)
add_library(PyTorchModelLoader
CachingGraphRunner.cpp
GlowFuser.cpp
FuseKnownPatterns.cpp
GlowIValue.cpp
PyTorchCommon.cpp
FusePrepack.cpp
Registration.cpp
PyTorchModelLoader.cpp)
target_compile_options(PyTorchModelLoader
PRIVATE
110 changes: 22 additions & 88 deletions torch_glow/src/CachingGraphRunner.cpp
@@ -43,10 +43,6 @@ size_t CachingGraphRunner::computeGraphHash(
return hash;
}

namespace {
static std::mutex graphCacheMutex;
}

Expected<CachingGraphRunner::PerGlowGraphInfo *>
CachingGraphRunner::loadImpl(torch::jit::Stack &stack) {
const auto inputs = torch::jit::last(stack, graph_->inputs().size());
@@ -55,7 +51,6 @@ CachingGraphRunner::loadImpl(torch::jit::Stack &stack) {

// If we already have a Glow function compiled for this graph with the
// given inputs then use that.
std::lock_guard<std::mutex> guard(graphCacheMutex);
auto it = perGlowGraphInfoMap_.find(hash);
if (it != perGlowGraphInfoMap_.end()) {
return it->second.get();
@@ -82,17 +77,27 @@

Error CachingGraphRunner::runImpl(const PerGlowGraphInfo &info,
torch::jit::Stack &stack) const {
size_t numInputs = info.inputPlaceholders.size();
size_t numInputs = graph_->inputs().size();
const auto inputs = torch::jit::last(stack, numInputs);

std::unique_ptr<ExecutionContext> ctx = llvm::make_unique<ExecutionContext>();
auto *bindings = ctx->getPlaceholderBindings();

for (size_t i = 0; i < numInputs; ++i) {
glow::Placeholder *ph = info.inputPlaceholders[i];
glow::TypeRef ty = ph->getType();
glow::Tensor t(inputs[i].toTensor().data_ptr(), ty);
bindings->insert(ph, std::move(t));
// We only hold placeholders for tensor inputs, so indexing them is
// different from indexing all inputs.
size_t placeholderI = 0;
for (const auto &input : inputs) {
if (input.isTensor()) {
glow::Placeholder *ph = info.inputPlaceholders[placeholderI++];
glow::TypeRef ty = ph->getType();
glow::Tensor t(input.toTensor().data_ptr(), ty);
bindings->insert(ph, std::move(t));
} else if (input.isObject()) {
// Objects are only used for loading attributes at compile time.
continue;
} else {
return MAKE_ERR("Only Tensor and Object IValue inputs are accepted");
}
}

std::vector<at::IValue> outputs;
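For illustration, the indexing scheme described in the comment above can be isolated into a self-contained sketch (plain ATen, no Glow types; the function name is hypothetical):

#include <ATen/core/ivalue.h>

#include <stdexcept>
#include <vector>

// Mirrors the loop above: only Tensor IValues consume a placeholder slot,
// Object IValues are skipped, and anything else is rejected.
size_t countBoundPlaceholders(const std::vector<c10::IValue> &inputs) {
  size_t placeholderI = 0;
  for (const auto &input : inputs) {
    if (input.isTensor()) {
      ++placeholderI; // would bind info.inputPlaceholders[placeholderI]
    } else if (!input.isObject()) {
      throw std::runtime_error(
          "Only Tensor and Object IValue inputs are accepted");
    }
  }
  return placeholderI;
}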
@@ -129,93 +134,22 @@ Error CachingGraphRunner::run(torch::jit::Stack &stack) {
return runImpl(*DCHECK_NOTNULL(info), stack);
}

Error CachingGraphRunner::run(const std::string &key,
torch::jit::Stack &stack) {
std::shared_ptr<PerGlowGraphInfo> info;
{
std::lock_guard<std::mutex> guard(graphCacheMutex);
auto it = glowGraphInfoMap.find(key);
if (it == glowGraphInfoMap.end()) {
return MAKE_ERR(
strFormat("Key: %s not found in glowGraphInfoMap!", key.c_str()));
}
info = it->second;
}
return runImpl(*DCHECK_NOTNULL(info.get()), stack);
}

Error CachingGraphRunner::CompileModule(const torch::jit::script::Module &m,
const std::vector<InputMeta> &inputMeta,
const std::string &opname) {
if (hostManager_ == nullptr) {
return MAKE_ERR("Host manager is null!");
}
const std::string name = "glow::" + m.name().qualifiedName();

std::lock_guard<std::mutex> guard(graphCacheMutex);
// Currently only support one method per module
auto it = glowGraphInfoMap.find(name);
if (it != glowGraphInfoMap.end()) {
// Already compiled
return Error::success();
}

auto info = std::make_shared<PerGlowGraphInfo>();
info->functionName = strFormat("PTFunction%s", name.c_str());

auto methods = m.get_methods();
if (methods.size() != 1) {
return MAKE_ERR("Currently only support one method each module!");
}

std::shared_ptr<torch::jit::Graph> g = nullptr;
// XXX: Here we assume the fusion node is a top level node. This constraint
// can be relaxed if needed.
for (auto node : methods[0].function().graph()->nodes()) {
if (node->kind().toQualString() == opname) {
if (!node->hasAttribute(torch::jit::attr::Subgraph)) {
return MAKE_ERR("Fusion node should have a subgraph!");
}
g = node->g(torch::jit::attr::Subgraph);
}
}
if (!g) {
return MAKE_ERR("No fusion node found");
}
std::unique_ptr<Module> glowModule = llvm::make_unique<Module>();
Function *f = glowModule->createFunction(info->functionName);

RETURN_IF_ERR(PyTorchModelLoader::loadJITGraph(
*f, *g, info->inputPlaceholders, info->outputPlaceholders,
getPyTorchLoaderSettings(), {}, inputMeta));

glow::CompilationContext cctx;

RETURN_IF_ERR(hostManager_->addNetwork(std::move(glowModule), cctx));
glowGraphInfoMap[name] = std::move(info);
Error CachingGraphRunner::warmCache(const std::vector<InputMeta> &inputMeta) {
// TODO: implement caching based on input metas. For now this is a no-op,
// since warming the cache is only an optimization.
return Error::success();
}

CachingGraphRunner *CachingGraphRunner::getCachingGraphRunner() {
static CachingGraphRunner runner = CachingGraphRunner();
return &runner;
}

CachingGraphRunner::CachingGraphRunner()
: hostManager_(glow::getHostManager()) {}

CachingGraphRunner::CachingGraphRunner(torch::jit::Graph *graph,
runtime::HostManager *hostManager)
CachingGraphRunner::CachingGraphRunner(
std::shared_ptr<torch::jit::Graph> graph,
std::shared_ptr<runtime::HostManager> hostManager)
: graph_(graph), hostManager_(hostManager) {}

CachingGraphRunner::~CachingGraphRunner() {
// Remove Glow functions saved in HostManager when being destroyed.
for (auto &kv : perGlowGraphInfoMap_) {
ERR_TO_BOOL(hostManager_->removeNetwork(kv.second->functionName));
}
for (auto &kv : glowGraphInfoMap) {
ERR_TO_BOOL(hostManager_->removeNetwork(kv.second->functionName));
}
}

} // namespace glow
24 changes: 7 additions & 17 deletions torch_glow/src/CachingGraphRunner.h
@@ -20,7 +20,6 @@
#include "PyTorchModelLoader.h"
#include "glow/Runtime/HostManager/HostManager.h"

#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/ir.h>

#include <torch/csrc/jit/import.h>
@@ -44,15 +43,10 @@ class CachingGraphRunner {

/// The PyTorch JIT Graph that this CachingGraphRunner caches Glow functions
/// for.
torch::jit::Graph *graph_ = nullptr;
std::shared_ptr<torch::jit::Graph> graph_;

/// The HostManager used to store and run Glow graphs.
runtime::HostManager *hostManager_ = nullptr;

// Mapping from module name to PerGlowGraphInfo. Here we assume one method
// each module, which should be the common case for accelerator modules.
std::unordered_map<std::string, std::shared_ptr<PerGlowGraphInfo>>
glowGraphInfoMap;
std::shared_ptr<runtime::HostManager> hostManager_;

/// Mapping from hash of PyTorch inputs to PerGlowGraphInfo for the Glow
/// function that will run inputs matching that hash.
@@ -76,23 +70,19 @@
size_t computeGraphHash(const c10::ArrayRef<c10::IValue> inputs) const;

public:
CachingGraphRunner(torch::jit::Graph *graph,
runtime::HostManager *hostManager);
CachingGraphRunner();
CachingGraphRunner(std::shared_ptr<torch::jit::Graph> graph,
std::shared_ptr<runtime::HostManager> hostManager);

~CachingGraphRunner();

static CachingGraphRunner *getCachingGraphRunner();

/// Given a PyTorch Stack \p stack of inputs, run the stored PyTorch graph on
/// those inputs. If this is the first time this PyTorch graph has been run
/// with inputs matching the hash of those on the stack, then this first loads
/// it as a Glow Function and compiles it. \returns an Error on failure.
Error run(torch::jit::Stack &stack);
Error run(const std::string &key, torch::jit::Stack &stack);
Error CompileModule(const torch::jit::script::Module &module,
const std::vector<InputMeta> &inputMeta,
const std::string &opname);

/// Warm up the cache by compiling a Glow function for the inputs described
/// by \p inputMeta.
Error warmCache(const std::vector<InputMeta> &inputMeta);
};

} // namespace glow
torch_glow/src/{FusePrepack.cpp → FuseKnownPatterns.cpp}
@@ -14,13 +14,19 @@
* limitations under the License.
*/

#include "FusePrepack.h"
#include "FuseKnownPatterns.h"

#include <glog/logging.h>

#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace glow {

namespace {
/// This pass fuses the quantized::conv_prepack + quantized::conv2d generated
/// by JIT back to quantized::unpacked_conv2d since we don't have
/// quantized::conv_prepack in glow. However regular packed conv's
@@ -61,4 +67,52 @@ graph(%input, %weights, %bias, %scale, %zero_point):
rewriter.RegisterRewritePattern(beforePattern, afterPattern);
rewriter.runOnGraph(graph);
}
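The pattern strings themselves are collapsed in this view. As a hedged reconstruction (not the verbatim source), the linear variant plausibly looks like the following, reusing the SubgraphRewriter API shown above and matching the visible graph signature and the glow::unpacked_quantized_linear dummy operator registered below:

// Reconstruction under stated assumptions, not verbatim source.
void fuseLinearPrepackSketch(std::shared_ptr<torch::jit::Graph> &graph) {
  std::string beforePattern = R"IR(
graph(%input, %weights, %bias, %scale, %zero_point):
  %packed_params = quantized::linear_prepack(%weights, %bias)
  %res = quantized::linear(%input, %packed_params, %scale, %zero_point)
  return (%res))IR";

  std::string afterPattern = R"IR(
graph(%input, %weights, %bias, %scale, %zero_point):
  %res = glow::unpacked_quantized_linear(%input, %weights, %bias, %scale, %zero_point)
  return (%res))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(beforePattern, afterPattern);
  rewriter.runOnGraph(graph);
}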

void fuseNumToTensorToNum(std::shared_ptr<torch::jit::Graph> &graph) {
std::string originalPat = R"IR(
graph(%input):
%res1 = prim::NumToTensor(%input)
%res2 = aten::Int(%res1)
return (%res2))IR";

std::string replacementPat = R"IR(
graph(%input):
return (%input))IR";

torch::jit::SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(originalPat, replacementPat);
rewriter.runOnGraph(graph);
}

/// Registers an operator with symbol \p opName but with no implementation.
/// Dummy operators can be used by glow-specific fusion passes prior to loading
/// a glow graph in order to eliminate intermediate values that are unnecessary
/// to Glow, such as those created by quantization packing nodes.
void registerDummyOperator(const char *opName) {
auto options = c10::OperatorOptions();
options.setAliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION);

torch::jit::RegisterOperators op({torch::jit::Operator(
at::Symbol::fromQualString(opName),
[](const torch::jit::Node *node) -> torch::jit::Operation {
LOG(FATAL) << "Operator \"" << (*node)
<< "\" has no implementation and is meant only as a "
"placeholder while fusing ops to run with Glow";
},
options)});
}
} // namespace

void fuseKnownPatterns(std::shared_ptr<torch::jit::Graph> &graph) {
// Register dummy nodes used by custom fusers.
static std::once_flag onceFlag;
std::call_once(onceFlag, []() {
registerDummyOperator("glow::unpacked_quantized_linear");
registerDummyOperator("glow::unpacked_quantized_conv2d");
});

fuseConvPrepack(graph);
fuseLinearPrepack(graph);
fuseNumToTensorToNum(graph);
}
} // namespace glow
torch_glow/src/{FusePrepack.h → FuseKnownPatterns.h}
@@ -14,19 +14,14 @@
* limitations under the License.
*/

#ifndef GLOW_TORCH_GLOW_SRC_FUSE_PREPACK_H
#define GLOW_TORCH_GLOW_SRC_FUSE_PREPACK_H
#ifndef GLOW_TORCH_GLOW_SRC_FUSE_KNOWN_PATERNS_H
#define GLOW_TORCH_GLOW_SRC_FUSE_KNOWN_PATERNS_H

#include <torch/csrc/jit/ir.h>

namespace glow {
/// Fuse weight packing operation into quantized convolution op thus skipping
/// weight packing.
void fuseConvPrepack(std::shared_ptr<torch::jit::Graph> &graph);

/// Fuse weight packing operation into quantized linear op thus skipping
/// weight packing.
void fuseLinearPrepack(std::shared_ptr<torch::jit::Graph> &graph);
/// Fuse known node patterns in \p graph to assist the PyTorchModelLoader.
void fuseKnownPatterns(std::shared_ptr<torch::jit::Graph> &graph);
} // namespace glow

#endif // GLOW_TORCH_GLOW_SRC_FUSE_PREPACK_H
#endif // GLOW_TORCH_GLOW_SRC_FUSE_KNOWN_PATERNS_H
35 changes: 32 additions & 3 deletions torch_glow/src/GlowFuser.cpp
@@ -16,14 +16,24 @@

#include "GlowFuser.h"

#include "FuseKnownPatterns.h"
#include "PyTorchCommon.h"
#include "PyTorchModelLoader.h"
#include "Registration.h"

#include <glog/logging.h>

#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace glow {
namespace {
using isSupportFunc = std::function<bool(torch::jit::Node *)>;

torch::jit::value_list
sortReverseTopological(at::ArrayRef<torch::jit::Value *> inputs,
torch::jit::Block *block) {
@@ -126,10 +136,22 @@ getNewNode(torch::jit::Node *node, torch::jit::AliasDb &aliasDb,
}
return {++node->reverseIterator(), false};
}
} // namespace

void GlowCustomFuse(std::shared_ptr<torch::jit::Graph> graph, isSupportFunc fn,
at::Symbol kind) {
void fuseJITNodesToGlow(std::shared_ptr<torch::jit::Graph> graph,
isSupportFunc fn, at::Symbol kind) {
torch::jit::AliasDb aliasDb(graph);
auto block = graph->block();

@@ -147,4 +156,24 @@ void GlowCustomFuse(std::shared_ptr<torch::jit::Graph> graph, isSupportFunc fn,
EliminateDeadCode(graph);
}

} // namespace

void glowCustomFuse(std::shared_ptr<torch::jit::Graph> graph) {
auto symbol = getGlowSymbol();

static std::once_flag onceFlag;
std::call_once(onceFlag, [&symbol]() { registerGlowOp(symbol); });

glowCustomFuse(graph, symbol);
}

void glowCustomFuse(std::shared_ptr<torch::jit::Graph> graph, at::Symbol kind) {
// Prepare the graph by fusing known patterns for the model loader.
// TODO: this should be done only on Glow subgraphs to avoid modifying parts
// of the graph that Glow will not be running.
fuseKnownPatterns(graph);

fuseJITNodesToGlow(graph, PyTorchModelLoader::isNodeSupported, kind);
}

} // namespace glow
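A hedged caller-side sketch of the new entry point (glowCustomFuse and getGlowSymbol are from this diff; the wrapper function and its name are illustrative):

#include "GlowFuser.h"

// The one-argument overload registers the Glow fusion operator exactly once
// (via std::call_once), fuses known patterns, then groups all nodes that
// PyTorchModelLoader::isNodeSupported accepts under getGlowSymbol().
void prepareGraphForGlow(std::shared_ptr<torch::jit::Graph> graph) {
  glow::glowCustomFuse(graph);
}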