Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… dev_mem_analyse
  • Loading branch information
JiayiFeng committed Aug 29, 2018
2 parents bbcf1ad + 9ae55dd commit d2d082d
Show file tree
Hide file tree
Showing 70 changed files with 5,060 additions and 541 deletions.
5 changes: 2 additions & 3 deletions cmake/external/anakin.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ ExternalProject_Add(
extern_anakin
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLML_PROJECT}
# Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code.
GIT_REPOSITORY "https://github.com/luotao1/Anakin"
GIT_TAG "211d1fc5d813d70c0c14072f9083cf25f40940ea"
GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin"
GIT_TAG "9424277cf9ae180a14aff09560d3cd60a49c76d2"
PREFIX ${ANAKIN_SOURCE_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DUSE_GPU_PLACE=YES
Expand Down
11 changes: 7 additions & 4 deletions paddle/fluid/API.spec
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
paddle.fluid.InferenceTranspiler.__init__
paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
Expand Down Expand Up @@ -299,6 +300,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', '
paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3))
paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
Expand Down Expand Up @@ -334,9 +336,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
paddle.fluid.transpiler.InferenceTranspiler.__init__
paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)

if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
endif()

if (NOT WIN32)
Expand Down
14 changes: 9 additions & 5 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)

cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)

cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
273 changes: 273 additions & 0 deletions paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"

namespace paddle {
namespace framework {
namespace ir {

struct Param {
std::string X = "concat_0.tmp_0";
std::string C0 = "cell_init";
std::string H0 = "hidden_init";
std::string AttentionWeight = "attention_fc.w_0";
std::string AttentionBias = "attention_fc.b_0";
std::string AttentionScalar = "attention_output.w_0";
std::string AttentionScalarBias = "attention_output.b_0";
std::string LSTMWeight = "attention_w.new";
std::string LSTMBias = "attention_b.new";
std::string Hidden = "array_to_lod_tensor_0.tmp_0";
std::string Cell = "at.cell.new";
std::string AttentionedX = "at.x.new";
std::string AttentionFCOut = "at.fc.new";
std::string LSTMX = "at.lstmx.new";
std::string LSTMOUT = "at.lstmout.new";
};

void PrepareParameters(Graph* graph, const Param& param);

void FindWhileOp(Graph* graph) {
GraphPatternDetector gpd;
std::unordered_set<int> fused_external_ops(
{35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});

gpd.mutable_pattern()->NewNode(
[&](Node* n) { return fused_external_ops.count(n->id()); }, "while");

if (!graph->Has(kGraphvizMarkedNodeAttr)) {
graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
}
auto& marked_nodes =
graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);

auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* while_pat_node = gpd.pattern().RetriveNode("while");
auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node);
};
gpd(graph, handle);

Param param;
// Add AttentionLSTM node
OpDesc op_desc;
op_desc.SetType("attention_lstm");

#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
OP_SET_IN(X);
OP_SET_IN(C0);
OP_SET_IN(H0);
OP_SET_IN(AttentionWeight);
OP_SET_IN(AttentionBias);
OP_SET_IN(AttentionScalar);
OP_SET_IN(AttentionScalarBias);
OP_SET_IN(LSTMWeight);
OP_SET_IN(LSTMBias);

OP_SET_OUT(Hidden);
OP_SET_OUT(Cell);
OP_SET_OUT(AttentionedX);
OP_SET_OUT(AttentionFCOut);
OP_SET_OUT(LSTMX);
OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT

auto* X = graph->RetriveNode(34);
auto* LSTMOUT = graph->RetriveNode(81);
auto* cell_init = graph->RetriveNode(6);
auto* hidden_init = graph->RetriveNode(8);

#define LINK_TO(node0, node1) \
node0->outputs.push_back(node1); \
node1->inputs.push_back(node0);

auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param);

LINK_TO(X, lstm_op);
LINK_TO(cell_init, lstm_op);
LINK_TO(hidden_init, lstm_op);
LINK_TO(lstm_op, LSTMOUT);

GraphSafeRemoveNodes(graph, marked_nodes);
}

#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
CHECK_P1(x0); \
CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
CHECK_P2(x0, x1); \
CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
CHECK_P3(x0, x1, x2); \
CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
CHECK_P4(x0, x1, x2, x3); \
CHECK_P1(x4);

void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out);

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out);

void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);

// Create new parameters.
scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();

#define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_w1" \
<< " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_b0" \
<< " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \
auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \
auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();

GATE_W(forget);
GATE_W(input);
GATE_W(output);
GATE_W(c);
#undef GATE_W

auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_b = scope->FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);

auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();

// reshape attention_bias
auto* attention_bias_t =
scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));

auto* attention_scalar_bias_t =
scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));

PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
lstm_weight_t);
PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
lstm_bias_t);
}

// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out) {
int D = W_forget_w0.dims()[0];
int M = W_forget_w1.dims()[0];
out->Resize(make_ddim({D + M, 4 * D}));
VLOG(3) << "LSTMWeight resized to " << out->dims();

float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors(
{W_forget_w0.data<float>(), W_input_w0.data<float>(),
W_output_w0.data<float>(), W_cell_w0.data<float>()});
std::array<const float*, 4> tensors1(
{W_forget_w1.data<float>(), W_input_w1.data<float>(),
W_output_w1.data<float>(), W_cell_w1.data<float>()});

for (int row = 0; row < D; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * row + D * col;
const float* src = tensors[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}

for (int row = 0; row < M; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * (D + row) + D * col;
const float* src = tensors1[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
}

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out) {
std::array<const float*, 4> tensors(
{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
B_cell.data<float>()});

PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
int D = B_forget.dims()[0];
out->Resize(make_ddim({1, 4 * D}));
auto* out_data = out->mutable_data<float>(platform::CPUPlace());
for (size_t i = 0; i < tensors.size(); i++) {
memcpy(out_data + D * i, tensors[i], D * sizeof(float));
}
}

// Parameters

std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PDPattern external_pattern, subblock_pattern;

FindWhileOp(graph.get());
return graph;
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(attention_lstm_fuse_pass,
paddle::framework::ir::AttentionLSTMFusePass);
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,12 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/analysis/dot.h"
#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace inference {
namespace analysis {
size_t Dot::counter = 0;
} // namespace analysis
} // namespace inference
namespace framework {
namespace ir {

class AttentionLSTMFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
Loading

0 comments on commit d2d082d

Please sign in to comment.