Skip to content

Commit

Permalink
Fixing parallel_stack bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mahmoud-abuzaina committed Apr 5, 2019
1 parent a1e3d44 commit c3e2dd1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
5 changes: 1 addition & 4 deletions tensorflow/core/common_runtime/parallel_concat_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {

// Add all the inplace_updates.
std::vector<Node*> control_nodes;
int64 i = 0;
for (const Edge* input_edge : n->in_edges()) {
if (input_edge->IsControlEdge()) {
g->AddControlEdge(input_edge->src(), start);
Expand All @@ -89,13 +88,11 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
Node* update;
TF_RETURN_IF_ERROR(
make_node("_ParallelConcatUpdate")
.Attr("loc", i)
.Attr("loc", input_edge->dst_input())
.Input(start)
.Input(input_edge->src(), input_edge->src_output())
.Finalize(g, &update));
control_nodes.push_back(update);

++i;
}

// Add the final identity.
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/python/kernel_tests/stack_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def testSimple(self):
def testSimpleParallelCPU(self):
np.random.seed(7)
with self.session(use_gpu=False):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2), (100, 24, 24, 3):
data = np.random.randn(*shape).astype(np.float32)
xs = list(map(constant_op.constant, data))
c = array_ops.parallel_stack(xs)
Expand All @@ -70,7 +70,7 @@ def testSimpleParallelCPU(self):
def testSimpleParallelGPU(self):
np.random.seed(7)
with self.session(use_gpu=True):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2), (100, 24, 24, 3):
data = np.random.randn(*shape).astype(np.float32)
xs = list(map(constant_op.constant, data))
c = array_ops.parallel_stack(xs)
Expand Down

0 comments on commit c3e2dd1

Please sign in to comment.