Skip to content

Commit

Permalink
refactor: still fallback when a trt segment has tuple/list input/output
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Jul 28, 2022
1 parent d479c98 commit 418d1e5
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ inline bool isTensor(torch::jit::Value* val) {
return val->type()->isSubtypeOf(torch::jit::TensorType::get());
}

inline bool isListOrTuple(torch::jit::Value* val) {
return val->type()->kind() == torch::jit::TypeKind::TupleType || val->type()->kind() == torch::jit::TypeKind::ListType;
}

bool containNonTensorOutputs(torch::jit::Node* n) {
for (auto output : n->outputs()) {
if (!isTensor(output)) {
Expand Down Expand Up @@ -109,22 +105,19 @@ void find_all_fallback_nodes(
auto cur_node = q.front();
q.pop();
// for every node that produces this fallback node's NonTensor input, they should fallback too
// Even collection feature is supported, since TRT List/Tuple output is not supported yet, the nodes
// that produce List/Tuple still cannot be in TRT segment
for (auto input : cur_node->inputs()) {
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
q.push(input->node());
}
}
// for every node that consumes this fallback node's NonTensor output, they should fallback too
// Since collection feature is supported, we can have List/Tuple input for TRT segment, so we only
// fallback the nodes that take inputs which are not Tensor/List/Tuple
for (auto output : cur_node->outputs()) {
if (!isTensor(output) && !isListOrTuple(output)) {
if (!isTensor(output)) {
for (auto use : output->uses()) {
auto node = use.user;
if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
if (node->kind() != torch::jit::prim::Constant &&
global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
q.push(node);
}
}
Expand Down

0 comments on commit 418d1e5

Please sign in to comment.