Skip to content

Commit

Permalink
Add bfloat16 support for SparseSegmentGrad*
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 367786761
Change-Id: I733728bafbfd9cd24c4235e041f50b099cd234ef
  • Loading branch information
tensorflower-gardener committed Apr 10, 2021
1 parent 78894ad commit 586aab6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/segment_reduction_ops_impl_5.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
SparseSegmentMeanGradOp<type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
Expand All @@ -111,6 +112,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
SparseSegmentSqrtNGradOp<type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS

#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
Expand Down
50 changes: 41 additions & 9 deletions tensorflow/core/kernels/segment_reduction_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ BM_Reduce_Arg(64, 32, 2);
BM_Reduce_Arg(4096, 32, 2);
BM_Reduce_Arg(4096, 128, 2);

template <DataType T>
static void SparseSegmentMeanGradHelper(::testing::benchmark::State& state,
float uniqueness, int size) {
typedef typename EnumToDataType<T>::Type DT;
Graph* g = new Graph(OpRegistry::Global());
CHECK_LE(uniqueness, 1.0);
CHECK_GT(uniqueness, 0.0);
Expand All @@ -136,36 +138,66 @@ static void SparseSegmentMeanGradHelper(::testing::benchmark::State& state,

const int kDim1 = segments_flat(kNumIndices - 1) + 1;
const int kDim2 = 128;
Tensor input(DT_FLOAT, TensorShape({kDim1, kDim2}));
input.flat<float>().setRandom();
Tensor input(T, TensorShape({kDim1, kDim2}));
input.flat<DT>().setRandom();

Node* node;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseSegmentMeanGrad")
.Input(test::graph::Constant(g, input))
.Input(test::graph::Constant(g, indices))
.Input(test::graph::Constant(g, segments))
.Input(test::graph::Constant(g, output_dim0))
.Attr("T", DT_FLOAT)
.Attr("T", T)
.Finalize(g, &node));

test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
state.SetBytesProcessed(static_cast<int64>(state.iterations()) *
(kDim1 * kDim2) * sizeof(float));
}

static void BM_SparseSegmentMeanGrad_Low(::testing::benchmark::State& state) {
static void BM_SparseSegmentMeanGrad_Low_FP32(
::testing::benchmark::State& state) {
const int size = state.range(0);

return SparseSegmentMeanGradHelper(state, 1.0, size);
return SparseSegmentMeanGradHelper<DT_FLOAT>(state, 1.0, size);
}

static void BM_SparseSegmentMeanGrad_High(::testing::benchmark::State& state) {
static void BM_SparseSegmentMeanGrad_High_FP32(
::testing::benchmark::State& state) {
const int size = state.range(0);

return SparseSegmentMeanGradHelper(state, 0.01, size);
return SparseSegmentMeanGradHelper<DT_FLOAT>(state, 0.01, size);
}

BENCHMARK(BM_SparseSegmentMeanGrad_Low)->UseRealTime()->Arg(1000)->Arg(100000);
BENCHMARK(BM_SparseSegmentMeanGrad_High)->UseRealTime()->Arg(1000)->Arg(100000);
// Benchmark entry point: runs the SparseSegmentMeanGrad helper instantiated
// for DT_BFLOAT16 with a uniqueness argument of 1.0.
//
// The first benchmark argument (state.range(0)) is forwarded as the helper's
// `size` parameter.
//
// NOTE(review): the helper's SetBytesProcessed() call scales by sizeof(float);
// for bfloat16 inputs that likely overstates the reported throughput --
// confirm, and consider sizeof(DT) in the helper instead.
static void BM_SparseSegmentMeanGrad_Low_BF16(
    ::testing::benchmark::State& state) {
  const int problem_size = state.range(0);
  SparseSegmentMeanGradHelper<DT_BFLOAT16>(state, /*uniqueness=*/1.0,
                                           problem_size);
}

// Benchmark entry point: runs the SparseSegmentMeanGrad helper instantiated
// for DT_BFLOAT16 with a uniqueness argument of 0.01.
//
// The first benchmark argument (state.range(0)) is forwarded as the helper's
// `size` parameter.
//
// NOTE(review): the helper's SetBytesProcessed() call scales by sizeof(float);
// for bfloat16 inputs that likely overstates the reported throughput --
// confirm, and consider sizeof(DT) in the helper instead.
static void BM_SparseSegmentMeanGrad_High_BF16(
    ::testing::benchmark::State& state) {
  const int problem_size = state.range(0);
  SparseSegmentMeanGradHelper<DT_BFLOAT16>(state, /*uniqueness=*/0.01,
                                           problem_size);
}

// Register the SparseSegmentMeanGrad benchmarks for both element types
// (FP32 and BF16) and both uniqueness settings (Low passes 1.0, High passes
// 0.01 to the helper). Each benchmark is measured with UseRealTime() and is
// run with two argument values, 1000 and 100000, which the benchmark
// functions read via state.range(0) and pass on as `size`.
BENCHMARK(BM_SparseSegmentMeanGrad_Low_FP32)
->UseRealTime()
->Arg(1000)
->Arg(100000);
BENCHMARK(BM_SparseSegmentMeanGrad_High_FP32)
->UseRealTime()
->Arg(1000)
->Arg(100000);
// bfloat16 counterparts of the two registrations above, with the same
// timing mode and argument values.
BENCHMARK(BM_SparseSegmentMeanGrad_Low_BF16)
->UseRealTime()
->Arg(1000)
->Arg(100000);
BENCHMARK(BM_SparseSegmentMeanGrad_High_BF16)
->UseRealTime()
->Arg(1000)
->Arg(100000);

} // namespace tensorflow
4 changes: 2 additions & 2 deletions tensorflow/core/ops/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ REGISTER_OP("SparseSegmentMeanGrad")
.Input("segment_ids: Tsegmentids")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn);
Expand Down Expand Up @@ -1417,7 +1417,7 @@ REGISTER_OP("SparseSegmentSqrtNGrad")
.Input("segment_ids: Tsegmentids")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn);
Expand Down

0 comments on commit 586aab6

Please sign in to comment.