Skip to content

Commit

Permalink
[tf.data] Memory-safe implementation of sharing access to the memory …
Browse files Browse the repository at this point in the history
…cache.

PiperOrigin-RevId: 307736215
Change-Id: If10ef65e6706a106e6bb4fc2d6fe4542bbe056cc
  • Loading branch information
jsimsa authored and tensorflower-gardener committed Apr 22, 2020
1 parent df7e7b1 commit b546b46
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 75 deletions.
125 changes: 64 additions & 61 deletions tensorflow/core/kernels/data/cache_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,10 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
public:
explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache)
: DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
std::shared_ptr<MemoryCache> cache)
: DatasetBase(DatasetContext(ctx)),
input_(input),
cache_(std::move(cache)) {
input_->Ref();
}

Expand All @@ -708,7 +710,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
return absl::make_unique<MemoryIterator>(
MemoryIterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
cache_);
cache_.get());
}

const DataTypeVector& output_dtypes() const override {
Expand Down Expand Up @@ -964,7 +966,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
}; // MemoryIterator

const DatasetBase* const input_;
MemoryCache* const cache_;
const std::shared_ptr<MemoryCache> cache_;
}; // MemoryDatasetBase

// This version of memory dataset has an exclusive ownership of the memory cache
Expand All @@ -973,22 +975,19 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache, const ResourceHandle& resource_handle)
: MemoryDatasetBase(ctx, input, cache),
resource_handle_(resource_handle) {
cleanup_ = [this, mgr = ctx->resource_manager()]() {
DCHECK(cache_->RefCountIsOne());
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
};
}
MemoryCacheManager* manager, ResourceHandle&& resource_handle)
: MemoryDatasetBase(ctx, input, manager->get()),
manager_(manager),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()) {}

~MemoryDataset() override {
cache_->Unref();
cleanup_();
manager_->Unref();
Status s = resource_mgr_->Delete<MemoryCacheManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}

protected:
Expand All @@ -1005,8 +1004,9 @@ class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
}

private:
std::function<void()> cleanup_;
MemoryCacheManager* const manager_; // Owned.
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
};

// This version of memory dataset has a shared ownership of the memory cache
Expand All @@ -1016,28 +1016,23 @@ class CacheDatasetOp::MemoryDatasetV2
: public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
MemoryCache* cache, const ResourceHandle& resource_handle)
: MemoryDatasetBase(ctx, input, cache),
resource_handle_(std::move(resource_handle)) {
cleanup_ = [this, mgr = ctx->resource_manager()]() {
if (cache_->RefCountIsOne()) {
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
resource_handle_.name());
if (!s.ok()) {
if (errors::IsNotFound(s)) {
// This is a bening race resulting from concurrent deletion.
VLOG(1) << "Failed to delete cache resource: " << s.ToString();
} else {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}
}
};
}
MemoryCacheManager* manager, ResourceHandle&& resource_handle,
bool owns_resource)
: MemoryDatasetBase(ctx, input, manager->get()),
manager_(manager),
owns_resource_(owns_resource),
resource_handle_(std::move(resource_handle)),
resource_mgr_(ctx->resource_manager()) {}

~MemoryDatasetV2() override {
cache_->Unref();
cleanup_();
manager_->Unref();
if (owns_resource_) {
Status s = resource_mgr_->Delete<MemoryCacheManager>(
resource_handle_.container(), resource_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
}
}
}

protected:
Expand All @@ -1058,8 +1053,10 @@ class CacheDatasetOp::MemoryDatasetV2
}

private:
std::function<void()> cleanup_;
MemoryCacheManager* const manager_; // Owned.
const bool owns_resource_;
const ResourceHandle resource_handle_;
ResourceMgr* const resource_mgr_; // Not owned.
};

CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
Expand All @@ -1077,33 +1074,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
resource_id_counter.fetch_add(1));
if (op_version_ == 2) {
MemoryCache* cache = nullptr;
bool owns_resource = false;
MemoryCacheManager* manager = nullptr;
auto handle = HandleFromInput(ctx, 2);
Status s = ctx->resource_manager()->Lookup<MemoryCache>(
handle.container(), handle.name(), &cache);
Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
handle.container(), handle.name(), &manager);
if (errors::IsNotFound(s)) {
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->LookupOrCreate<MemoryCache>(
container, name, &cache, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
owns_resource = true;
OP_REQUIRES_OK(
ctx,
ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
container, name, &manager, [](MemoryCacheManager** manager) {
*manager = new MemoryCacheManager();
return Status::OK();
}));
handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
} else {
OP_REQUIRES_OK(ctx, s);
}
// Ownership of cache is transferred onto `MemoryDatasetV2`.
*output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
// Ownership of manager is transferred onto `MemoryDatasetV2`.
*output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
owns_resource);
} else {
MemoryCache* cache;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
container, name, &cache, [](MemoryCache** cache) {
*cache = new MemoryCache();
return Status::OK();
}));
auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
// Ownership of cache is transferred onto `MemoryDataset`.
*output = new MemoryDataset(ctx, input, cache, handle);
MemoryCacheManager* manager;
OP_REQUIRES_OK(
ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
container, name, &manager, [](MemoryCacheManager** manager) {
*manager = new MemoryCacheManager();
return Status::OK();
}));
auto handle =
MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
// Ownership of manager is transferred onto `MemoryDataset`.
*output = new MemoryDataset(ctx, input, manager, std::move(handle));
}
} else {
if (op_version_ == 2) {
Expand Down
12 changes: 4 additions & 8 deletions tensorflow/core/kernels/data/cache_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ constexpr char kMemoryCache[] = "MemoryCache";

} // namespace

string MemoryCache::DebugString() const { return kMemoryCache; }
string MemoryCacheManager::DebugString() const { return kMemoryCache; }

void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
mutex_lock l(mu_);
Expand Down Expand Up @@ -65,19 +65,15 @@ size_t MemoryCache::size() {

AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp(
OpKernelConstruction* ctx)
: AnonymousResourceOp<MemoryCache>(ctx) {}

void AnonymousMemoryCacheHandleOp::Compute(OpKernelContext* ctx) {
AnonymousResourceOp<MemoryCache>::Compute(ctx);
}
: AnonymousResourceOp<MemoryCacheManager>(ctx) {}

string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; }

Status AnonymousMemoryCacheHandleOp::CreateResource(
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib, MemoryCache** resource) {
*resource = new MemoryCache();
FunctionLibraryRuntime* lib, MemoryCacheManager** manager) {
*manager = new MemoryCacheManager();
return Status::OK();
}

Expand Down
23 changes: 17 additions & 6 deletions tensorflow/core/kernels/data/cache_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ namespace data {
// The expected use is that a single `MemoryWriterIterator` populates the
// cache with dataset elements. Once all elements are cached, the cache can
// be used by one or more `MemoryReaderIterator`s.
class MemoryCache : public ResourceBase {
class MemoryCache {
public:
MemoryCache() = default;

string DebugString() const override;

// Marks the cache as completed.
void Complete(std::vector<std::vector<Tensor>>&& cache);

Expand All @@ -55,19 +53,32 @@ class MemoryCache : public ResourceBase {
std::vector<std::vector<Tensor>> cache_ TF_GUARDED_BY(mu_);
};

// A resource wrapping a shared instance of a memory cache.
class MemoryCacheManager : public ResourceBase {
public:
MemoryCacheManager() : cache_(std::make_shared<MemoryCache>()) {}

string DebugString() const override;

std::shared_ptr<MemoryCache> get() { return cache_; }

private:
std::shared_ptr<MemoryCache> cache_;
};

// Creates an instance of cache resource and transfers ownership to the caller.
class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
class AnonymousMemoryCacheHandleOp
: public AnonymousResourceOp<MemoryCacheManager> {
public:
explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;

private:
string name() override;
Status CreateResource(OpKernelContext* ctx,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib,
MemoryCache** resource) override;
MemoryCacheManager** manager) override;
};

// Deletes an instance of cache resource.
Expand Down

0 comments on commit b546b46

Please sign in to comment.