Skip to content

Commit

Permalink
Add additional concat test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 157844113
  • Loading branch information
Justin Lebar authored and tensorflower-gardener committed Jun 2, 2017
1 parent f661128 commit d5421cf
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tensorflow/compiler/xla/tests/concat_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}

XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
ComputationBuilder builder(client_, TestName());

Array3D<float> arr0(9, 17, 1);
arr0.Fill(1);

Array3D<float> arr1(9, 17, 256);
arr1.Fill(2);

Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
for (int64 i = 0; i < expected.n1(); ++i) {
for (int64 j = 0; j < expected.n2(); ++j) {
int64 kk = 0;
for (const Array3D<float>& arr : {arr0, arr1}) {
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
expected(i, j, kk) = arr(i, j, k);
}
}
}
}

ComputationDataHandle h0;
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
&builder, &h0);
ComputationDataHandle h1;
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
&builder, &h1);

auto concatenated = builder.ConcatInDim({h0, h1}, 2);

ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}

// Describes a binary rank-2 concatenation test.
struct R2BinarySpec {
int64 lhs_dim0;
Expand Down

0 comments on commit d5421cf

Please sign in to comment.