Skip to content

Commit

Permalink
[tf.data service] Report infinite cardinality for coordinated read datasets.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 394772677
Change-Id: I4b0bee3b5067c8efaa59d9feee7065610566812c
  • Loading branch information
aaudiber authored and tensorflower-gardener committed Sep 3, 2021
1 parent 7f68c43 commit 42bd7df
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
protocol_(protocol),
data_transfer_protocol_(data_transfer_protocol),
job_name_(job_name),
is_coordinated_read_(consumer_index.has_value()),
consumer_index_(consumer_index),
num_consumers_(num_consumers),
max_outstanding_requests_(max_outstanding_requests),
Expand Down Expand Up @@ -161,6 +162,14 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}

// Reports this dataset's cardinality: infinite when the dataset is a
// coordinated read (which requires the dataset to be infinite), and unknown
// otherwise.
int64_t Cardinality() const override {
  return is_coordinated_read_ ? kInfiniteCardinality : kUnknownCardinality;
}

Status CheckExternalState() const override {
return Status(
error::FAILED_PRECONDITION,
Expand Down Expand Up @@ -1051,6 +1060,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
const tstring protocol_;
const tstring data_transfer_protocol_;
const tstring job_name_;
const bool is_coordinated_read_;
const absl::optional<int64_t> consumer_index_;
const absl::optional<int64_t> num_consumers_;
const int64_t max_outstanding_requests_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ def testFiniteV2(self):
"cardinality, but the dataset has cardinality " + str(num_elements)):
self.getDatasetOutput(ds)

# Only eager (TF2) combinations are exercised here: the `map` transformation
# that make_distributed_dataset applies for compression makes cardinality
# unknown in TF1.
@combinations.generate(test_base.v2_only_combinations())
def testCardinality(self):
  # A dataset read with a consumer index and consumer count should report
  # infinite cardinality.
  test_cluster = data_service_test_base.TestCluster(num_workers=1)
  repeated_range = dataset_ops.Dataset.range(10).repeat()
  distributed = self.make_distributed_dataset(
      repeated_range,
      test_cluster,
      job_name="test",
      consumer_index=0,
      num_consumers=2)
  reported_cardinality = self.evaluate(distributed.cardinality())
  self.assertEqual(reported_cardinality, dataset_ops.INFINITE)


if __name__ == "__main__":
test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,12 @@ def testNoShardingPolicy(self):
dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF)
self.assertDatasetProduces(dataset, list(range(20)))

@combinations.generate(test_base.default_test_combinations())
def testCardinality(self):
  # A distributed range dataset created with the default helper arguments is
  # expected to report unknown cardinality.
  test_cluster = data_service_test_base.TestCluster(num_workers=1)
  distributed = self.make_distributed_range_dataset(10, test_cluster)
  reported_cardinality = self.evaluate(distributed.cardinality())
  self.assertEqual(reported_cardinality, dataset_ops.UNKNOWN)


if __name__ == "__main__":
test.main()

0 comments on commit 42bd7df

Please sign in to comment.