Skip to content

Commit

Permalink
fuse onnx lstm, codeformat exclude pybind11, fix Tencent#2562
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jan 7, 2021
1 parent f1c19c1 commit 9b949d6
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 30 deletions.
10 changes: 6 additions & 4 deletions codeformat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

# we run clang-format and astyle twice to get stable format output

find src/ tools/ tests/ examples/ benchmark/ python/ -type f -name '*.c' -o -name '*.cpp' -o -name '*.cc' -o -name '*.h' | xargs -i clang-format -i {}
astyle -n -r "benchmark/*.h,*.cpp,*.cc" "src/*.h,*.cpp,*.cc" "tests/*.h,*.cpp,*.cc" "tools/*.h,*.cpp,*.cc" "examples/*.h,*.cpp,*.cc" "python/*.h,*.cpp,*.cc"
find src/ tools/ tests/ examples/ benchmark/ python/ -type f -name '*.c' -o -name '*.cpp' -o -name '*.cc' -o -name '*.h' | grep -v python/pybind11 | xargs -i clang-format -i {}
astyle -n -r "benchmark/*.h,*.cpp,*.cc" "src/*.h,*.cpp,*.cc" "tests/*.h,*.cpp,*.cc" "tools/*.h,*.cpp,*.cc" "examples/*.h,*.cpp,*.cc"
astyle -n -r "python/*.h,*.cpp,*.cc" --exclude=python/pybind11

find src/ tools/ tests/ examples/ benchmark/ python/ -type f -name '*.c' -o -name '*.cpp' -o -name '*.cc' -o -name '*.h' | xargs -i clang-format -i {}
astyle -n -r "benchmark/*.h,*.cpp,*.cc" "src/*.h,*.cpp,*.cc" "tests/*.h,*.cpp,*.cc" "tools/*.h,*.cpp,*.cc" "examples/*.h,*.cpp,*.cc" "python/*.h,*.cpp,*.cc"
find src/ tools/ tests/ examples/ benchmark/ python/ -type f -name '*.c' -o -name '*.cpp' -o -name '*.cc' -o -name '*.h' | grep -v python/pybind11 | xargs -i clang-format -i {}
astyle -n -r "benchmark/*.h,*.cpp,*.cc" "src/*.h,*.cpp,*.cc" "tests/*.h,*.cpp,*.cc" "tools/*.h,*.cpp,*.cc" "examples/*.h,*.cpp,*.cc"
astyle -n -r "python/*.h,*.cpp,*.cc" --exclude=python/pybind11
6 changes: 4 additions & 2 deletions src/gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2484,7 +2484,8 @@ VkQueue VulkanDevice::acquire_queue(uint32_t queue_family_index) const
MutexLockGuard lock(d->queue_lock);

std::vector<VkQueue>& queues = queue_family_index == info.compute_queue_family_index() ? d->compute_queues
: queue_family_index == info.graphics_queue_family_index() ? d->graphics_queues : d->transfer_queues;
: queue_family_index == info.graphics_queue_family_index() ? d->graphics_queues
: d->transfer_queues;
for (int i = 0; i < (int)queues.size(); i++)
{
VkQueue queue = queues[i];
Expand Down Expand Up @@ -2512,7 +2513,8 @@ void VulkanDevice::reclaim_queue(uint32_t queue_family_index, VkQueue queue) con
MutexLockGuard lock(d->queue_lock);

std::vector<VkQueue>& queues = queue_family_index == info.compute_queue_family_index() ? d->compute_queues
: queue_family_index == info.graphics_queue_family_index() ? d->graphics_queues : d->transfer_queues;
: queue_family_index == info.graphics_queue_family_index() ? d->graphics_queues
: d->transfer_queues;
for (int i = 0; i < (int)queues.size(); i++)
{
if (!queues[i])
Expand Down
165 changes: 141 additions & 24 deletions tools/onnx/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,129 @@ static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std:
}
}

// Fuse the 5-node pattern that ONNX exporters commonly emit for a
// bidirectional LSTM:
//
//     Transpose(perm=1,0,2) -> LSTM(direction=bidirectional)
//         -> Transpose(perm=0,2,1,3) -> Reshape(shape=0,0,-1) -> Transpose(perm=1,0,2)
//
// into the single LSTM node, rewiring it to consume the first Transpose's
// input and produce the last Transpose's output.  The surrounding transposes
// and reshape only shuffle layout (presumably batch-first <-> seq-first and
// merging the two direction slices of the LSTM output into one feature axis
// — layout interpretation per the ONNX LSTM spec; TODO confirm).
//
// Parameters:
//   mutable_graph      - graph rewritten in place
//   weights            - initializer tensors by name; consulted when the
//                        Reshape's target shape comes in as a second input
//   node_reference     - consumer count per blob name; kept in sync with the
//                        removed edges
//   blob_names         - set of live blob names; intermediate blobs of the
//                        fused chain are erased
//   reduced_node_count - incremented by the number of nodes eliminated (4)
static void fuse_bilstm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++)
{
onnx::NodeProto* node = mutable_graph->mutable_node(i);

// LSTM <= Transpose - LSTM - Transpose - Reshape - Transpose
if (node->op_type() == "Transpose")
{
// the leading Transpose must feed exactly one consumer, otherwise
// removing it would break another user of its output
if (node_reference[node->output(0)] != 1)
continue;

// 1 0 2
// swap of the first two axes only
std::vector<int> perm = get_node_attr_ai(*node, "perm");
if (perm.size() != 3)
continue;

if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
continue;

// need four more consecutive nodes to complete the pattern
if (i + 4 >= node_count)
continue;

onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);

if (node2->op_type() != "LSTM" || node3->op_type() != "Transpose" || node4->op_type() != "Reshape" || node5->op_type() != "Transpose")
continue;

// each intermediate blob must be consumed only by the next node in
// the chain (single-use), so the whole chain can be collapsed
if (node_reference[node2->output(0)] != 1)
continue;

if (node_reference[node3->output(0)] != 1)
continue;

if (node_reference[node4->output(0)] != 1)
continue;

// verify the five nodes are actually wired head-to-tail
if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
|| node5->input(0) != node4->output(0))
continue;

// only the bidirectional variant produces this layout-fixup pattern
std::string direction = get_node_attr_s(*node2, "direction");
if (direction != "bidirectional")
continue;

// 0 2 1 3
// on the 4-D LSTM output, swap the middle two axes
std::vector<int> perm3 = get_node_attr_ai(*node3, "perm");
if (perm3.size() != 4)
continue;

if (perm3[0] != 0 || perm3[1] != 2 || perm3[2] != 1 || perm3[3] != 3)
continue;

// the Reshape target may be an attribute (opset <= 4 style) or a
// constant second input; only a constant initializer is foldable
std::vector<int> shape;
if (node4->input_size() == 1)
{
shape = get_node_attr_ai(*node4, "shape");
}
else
{
// skip weight reshape
if (weights.find(node4->input(1)) == weights.end())
continue;

shape = get_node_attr_from_input_ai(weights[node4->input(1)]);
}

// 0 0 -1
// keep the first two dims as-is, flatten the rest into one axis
// (ONNX Reshape: 0 = copy input dim, -1 = infer)
if (shape.size() != 3)
continue;

if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
continue;

// 1 0 2
// trailing Transpose undoes the leading one's axis swap
std::vector<int> perm5 = get_node_attr_ai(*node5, "perm");
if (perm5.size() != 3)
continue;

if (perm5[0] != 1 || perm5[1] != 0 || perm5[2] != 2)
continue;

// reduce
// mark the four layout nodes with the sentinel op type so a later
// pass drops them (NOTE(review): sentinel handling lives elsewhere
// in this file)
node->set_op_type("noop_reducedncnn");
node3->set_op_type("noop_reducedncnn");
node4->set_op_type("noop_reducedncnn");
node5->set_op_type("noop_reducedncnn");

// release the references held by the now-dead intermediate edges
node_reference[node->output(0)] -= 1;
node_reference[node2->output(0)] -= 1;
node_reference[node3->output(0)] -= 1;
node_reference[node4->output(0)] -= 1;
if (node4->input_size() == 2)
{
// also release the Reshape's constant shape initializer
node_reference[node4->input(1)] -= 1;
}

// the intermediate blobs no longer exist in the fused graph
blob_names.erase(node->output(0));
blob_names.erase(node2->output(0));
blob_names.erase(node3->output(0));
blob_names.erase(node4->output(0));
if (node2->output_size() > 1)
{
// drop the LSTM's optional extra outputs as well
// (presumably Y_h / Y_c per the ONNX LSTM spec — unused here)
for (int j = 1; j < node2->output_size(); j++)
{
blob_names.erase(node2->output(j));
}
}

// rewire: LSTM consumes the original input and emits the final output
node2->set_input(0, node->input(0));
node2->clear_output();
node2->add_output(node5->output(0));

reduced_node_count += 4;
// skip past the four nodes just consumed by this fusion
i += 4;
}
}
}

int main(int argc, char** argv)
{
const char* onnxpb = argv[1];
Expand Down Expand Up @@ -1901,12 +2024,6 @@ int main(int argc, char** argv)
{
node_reference[input_name] = node_reference[input_name] + 1;
}

if (op == "LSTM")
{
// ignore all optional input blobs
break;
}
}

if (op == "Dropout")
Expand All @@ -1917,14 +2034,6 @@ int main(int argc, char** argv)
continue;
}

if (op == "LSTM")
{
const std::string& output_name = node.output(0);
blob_names.insert(output_name);
node_reference[output_name] = 0;
continue;
}

for (int j = 0; j < (int)node.output_size(); j++)
{
const std::string& output_name = node.output(j);
Expand Down Expand Up @@ -1972,6 +2081,7 @@ int main(int argc, char** argv)
fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_bilstm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);

// reduce common const weight node_reference
for (int i = 0; i < node_count; i++)
Expand Down Expand Up @@ -2049,9 +2159,10 @@ int main(int argc, char** argv)
}
else if (op == "LSTM")
{
node_reference[node.input(1)] -= 1;
node_reference[node.input(2)] -= 1;
node_reference[node.input(3)] -= 1;
for (int j = 1; j < node.input_size(); j++)
{
node_reference[node.input(j)] -= 1;
}
}
else if (op == "MatMul")
{
Expand Down Expand Up @@ -2118,10 +2229,10 @@ int main(int argc, char** argv)
}
}

// for (auto a: node_reference)
// {
// fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
// }
// for (auto a: node_reference)
// {
// fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
// }

// count all weight node with zero reference
int zero_reference_weight_node_count = 0;
Expand Down Expand Up @@ -2149,6 +2260,10 @@ int main(int argc, char** argv)
}
}

// some op may have anonymous input
// LSTM sequence_lens
blob_names.erase("");

// remove node_reference entry with reference equals to one
int split_layer_count = 0;
int splitncnn_blob_count = 0;
Expand Down Expand Up @@ -2293,6 +2408,11 @@ int main(int argc, char** argv)
input_size--;
}

if (input_name.empty())
{
input_size--;
}

// fprintf(stderr, " input = %s\n", input_name.c_str());
}
/*
Expand Down Expand Up @@ -2468,9 +2588,6 @@ int main(int argc, char** argv)
else if (op == "LSTM")
{
fprintf(pp, "%-16s", "LSTM");
// force no output hidden and cell blob
input_size = 1;
output_size = 1;
}
else if (op == "MatMul")
{
Expand Down

0 comments on commit 9b949d6

Please sign in to comment.