Skip to content

Commit

Permalink
[tf.data] Fix a device placement issue in prefetch_to_device(). (te…
Browse files Browse the repository at this point in the history
…nsorflow#18607)

* [tf.data] Fix a device placement issue in `prefetch_to_device()`.

Previously, the `iterator_get_device()` op was being infeasibly colocated with
both the iterator and placed on the prefetch target device. Move the
construction of that op outside the `with device():` block to fix this.

Also enable the relevant test to run as a CUDA test.

* Import the cuda_py_test rule.
  • Loading branch information
mrry authored Apr 17, 2018
1 parent 48f7e37 commit ee36693
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
7 changes: 3 additions & 4 deletions tensorflow/contrib/data/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test")

py_test(
name = "batch_dataset_op_test",
Expand Down Expand Up @@ -471,12 +471,11 @@ py_test(
],
)

py_test(
cuda_py_test(
name = "prefetching_ops_test",
size = "small",
srcs = ["prefetching_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
additional_deps = [
"//tensorflow/contrib/data/python/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/contrib/data/python/ops/prefetching_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ def _prefetch_fn(handle):
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))

iterator_device = gen_dataset_ops.iterator_get_device(
self._input_iterator._iterator_resource)

with ops.device(device):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
target_device=gen_dataset_ops.iterator_get_device(
self._input_iterator._iterator_resource),
target_device=iterator_device,
string_arg=input_iterator_handle,
buffer_size=buffer_size,
shared_name=shared_name)
Expand Down

0 comments on commit ee36693

Please sign in to comment.