Add shape validation for tag input of StatsDataset (tensorflow#18462)
* Add shape validation for tag input of StatsDataset

The tag input of the StatsDataset needs to be a scalar. However, there
was no such check in the shape function. This fix adds a check for the
tag shape (see the usage sketch below).

Signed-off-by: Yong Tang <[email protected]>

* Add test case for tag shape check with BytesProducedStatsDataset

Signed-off-by: Yong Tang <[email protected]>

* Add tag shape check for LatencyStatsDataset

Signed-off-by: Yong Tang <[email protected]>

* Add test case for tag shape check with LatencyStatsDataset

Signed-off-by: Yong Tang <[email protected]>

* Fix pylint issue

Signed-off-by: Yong Tang <[email protected]>
yongtang authored and mrry committed Apr 14, 2018
1 parent b1f83c9 commit d17de3d
Showing 2 changed files with 26 additions and 2 deletions.
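For context before the diffs, here is a minimal sketch of the user-visible effect. It is not part of the commit, and it assumes the TF 1.x contrib-era import paths of that release: a scalar string tag is accepted as before, while a non-scalar tag is now rejected during shape inference when the dataset op is built in graph mode.

from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops

# Scalar tag: accepted, exactly as before this change.
ok = dataset_ops.Dataset.range(100).apply(
    stats_ops.bytes_produced_stats("bytes_produced"))

# Rank-1 (list) tag: now fails at graph-construction time with the shape
# error that the new tests below assert on.
try:
  dataset_ops.Dataset.range(100).apply(
      stats_ops.bytes_produced_stats(["bytes_produced"]))
except ValueError as err:
  print(err)  # e.g. "Shape must be rank 0 but is rank 1 ..."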
@@ -218,6 +218,14 @@ def _build_dataset_bytes_stats(self, num_elements):
         lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
             stats_ops.bytes_produced_stats("bytes_produced"))
 
+  def test_bytes_produced_stats_invalid_tag_shape(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Shape must be rank 0 but is rank 1'):
+      self.run_core_tests(
+          lambda: dataset_ops.Dataset.range(100).apply(
+              stats_ops.bytes_produced_stats(["bytes_produced"])),
+          None, 100)
+
   def testBytesStatsDatasetSaveableCore(self):
     num_outputs = 100
     self.run_core_tests(
@@ -235,6 +243,14 @@ def _build_dataset_multiple_tags(self,
     return dataset_ops.Dataset.range(num_elements).apply(
         stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
 
+  def test_latency_stats_invalid_tag_shape(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Shape must be rank 0 but is rank 1'):
+      self.run_core_tests(
+          lambda: dataset_ops.Dataset.range(100).apply(
+              stats_ops.latency_stats(["record_latency", "record_latency_2"])),
+          None, 100)
+
   def testLatencyStatsDatasetSaveableCore(self):
     num_outputs = 100
 
tensorflow/core/ops/dataset_ops.cc (12 changes: 10 additions & 2 deletions)
@@ -141,15 +141,23 @@ REGISTER_OP("BytesProducedStatsDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle tag_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
return shape_inference::ScalarShape(c);
});

REGISTER_OP("LatencyStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle tag_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
return shape_inference::ScalarShape(c);
});

REGISTER_OP("MapDataset")
.Input("input_dataset: variant")
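A note on the pattern used in the new shape functions: InferenceContext::WithRank constrains an input to the given rank and returns a non-OK status if the known rank differs, and TF_RETURN_IF_ERROR propagates that status out of the shape function, which is what produces the "Shape must be rank 0 but is rank 1" error asserted above. Below is a standalone sketch of the same idea for a hypothetical op; ExampleStatsDataset is not a real TensorFlow op, and the exact headers may vary across TF versions.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// Hypothetical op, for illustration only; it mirrors the validation added
// to BytesProducedStatsDataset and LatencyStatsDataset above.
REGISTER_OP("ExampleStatsDataset")
    .Input("input_dataset: variant")
    .Input("tag: string")
    .Output("handle: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // Require the `tag` input (input index 1) to be rank 0, i.e. a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // The dataset handle output is itself a scalar variant.
      return shape_inference::ScalarShape(c);
    });

}  // namespace tensorflow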
