Skip to content

Commit

Permalink
[StaticRuntime] Add TensorExpr fusion with dynamic shapes in SR (pytorch#69475)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#69475

This diff adds TensorExpr fusion with dynamic shapes in SR. This includes tracing the input graph with sample inputs, and then performing fusion with generalization to get fused graphs with dynamic shapes.
ghstack-source-id: 146059043

Test Plan:
```
buck run mode/opt //caffe2/caffe2/fb/predictor:pytorch_predictor_test
```

Reviewed By: d1jang

Differential Revision: D32320088

fbshipit-source-id: 397f498878ddfcee9dad7a839652f79f034fefe3
  • Loading branch information
navahgar authored and facebook-github-bot committed Dec 21, 2021
1 parent c6d1162 commit a6f9531
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 7 deletions.
16 changes: 16 additions & 0 deletions torch/csrc/jit/runtime/static/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/static/passes.h>
Expand Down Expand Up @@ -319,5 +321,19 @@ void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size) {
inlineSmallFusionGroups(block, min_size);
}

void performTensorExprFusion(
std::shared_ptr<Graph> graph,
std::vector<IValue> sample_inputs) {
// Enable TensorExpr fusion with dynamic shapes
setTensorExprDynamicShapeFusionEnabled(true);
GRAPH_DEBUG("Graph before tracing: ", graph);
auto traced_graph = TraceGraph(graph, sample_inputs);
GRAPH_DEBUG("Graph after tracing: ", traced_graph);
FuseTensorExprs(traced_graph);
graph->block()->clear();
graph->block()->cloneFrom(traced_graph->block(), nullptr);
GRAPH_DUMP("Graph after fusion: ", graph);
}

} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/runtime/static/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,9 @@ TORCH_API void fuseStaticSubgraphs(
std::shared_ptr<Graph> graph,
size_t min_size);

TORCH_API void performTensorExprFusion(
std::shared_ptr<Graph> graph,
std::vector<IValue> sample_inputs);

} // namespace jit
} // namespace torch
32 changes: 25 additions & 7 deletions torch/csrc/jit/runtime/static/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/static/passes.h>
Expand Down Expand Up @@ -96,8 +97,17 @@ namespace {

void OptimizeGraph(
std::shared_ptr<torch::jit::Graph>& graph,
const StaticModuleOptions& opts) {
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs) {
GRAPH_DUMP("Before optimizations: ", graph);
if (opts.enable_tensorexpr_fusion) {
if (sample_inputs.empty()) {
VLOG(1) << "Cannot perform TensorExpr fusion - sample_inputs is empty";
} else {
VLOG(1) << "Performing TensorExpr fusion";
performTensorExprFusion(graph, std::move(sample_inputs));
}
}
Inline(*graph);
ConstantPropagation(graph);
Canonicalize(graph);
Expand Down Expand Up @@ -135,6 +145,10 @@ void OptimizeGraph(
GRAPH_DUMP("Final graph after optimizations: ", graph);
}

// Returns true iff the graph has at least one input and the first
// input's type is a module — i.e. the graph still carries the `self`
// argument of a Module method as input 0.
bool IsSelfInGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
  const auto inputs = graph->inputs();
  if (inputs.empty()) {
    return false;
  }
  return inputs.at(0)->type()->is_module();
}

// remove unused input 0 from graph
bool removeSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
if (graph->inputs().at(0)->type()->is_module()) {
Expand Down Expand Up @@ -173,7 +187,7 @@ void PrepareGraphForStaticModule(
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs) {
TORCH_CHECK(canEnableStaticRuntime(graph));
OptimizeGraph(graph, opts);
OptimizeGraph(graph, opts, std::move(sample_inputs));
}

std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
Expand All @@ -185,7 +199,8 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
<< opts.cleanup_activations << ", enable_out_variant "
<< opts.enable_out_variant << ", optimize_memory "
<< opts.optimize_memory << ", manage_output_tensors "
<< opts.manage_output_tensors;
<< opts.manage_output_tensors << ", enable_tensorexpr_fusion "
<< opts.enable_tensorexpr_fusion;

Module module = m.copy();
if (!is_frozen) {
Expand All @@ -196,7 +211,10 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
Method method = module.get_method("forward");
auto graph = module.get_method("forward").graph();

PrepareGraphForStaticModule(graph, opts, sample_inputs);
if (!sample_inputs.empty() && IsSelfInGraphInput(graph)) {
sample_inputs.insert(sample_inputs.begin(), m._ivalue());
}
PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));

return std::make_pair(graph, module);
}
Expand All @@ -205,7 +223,7 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
std::shared_ptr<torch::jit::Graph> graph,
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs) {
PrepareGraphForStaticModule(graph, opts, sample_inputs);
PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
return std::make_pair(graph, c10::nullopt);
}

Expand Down Expand Up @@ -429,7 +447,7 @@ StaticModule::StaticModule(
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs)
: StaticModule(
PrepareForStaticModule(g->copy(), opts, sample_inputs),
PrepareForStaticModule(g->copy(), opts, std::move(sample_inputs)),
opts) {}

StaticModule::StaticModule(
Expand All @@ -438,7 +456,7 @@ StaticModule::StaticModule(
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs)
: StaticModule(
PrepareForStaticModule(m, is_frozen, opts, sample_inputs),
PrepareForStaticModule(m, is_frozen, opts, std::move(sample_inputs)),
opts) {}

StaticModule::StaticModule(
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/static/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ struct TORCH_API StaticModuleOptions {
// graph, where storage is deallocated outside static runtime
// (enable_out_variant must be true)
bool manage_output_tensors{false};
// enable TensorExpr fusion of ops at model loading time
bool enable_tensorexpr_fusion{false};
};

/// The static runtime supports two execution modes.
Expand Down

0 comments on commit a6f9531

Please sign in to comment.