From b2795b799e3b7ced293c95bfec1d94f5324acb0f Mon Sep 17 00:00:00 2001 From: Lei Jin Date: Tue, 25 Feb 2014 17:47:37 -0800 Subject: [PATCH] thread local pointer storage Summary: This is not a generic thread local implementation in the sense that it only takes pointer. But it does support multiple instances per thread and lets user plugin function to perform cleanup when thread exits or an instance gets destroyed. Test Plan: unit test for now Reviewers: haobo, igor, sdong, dhruba Reviewed By: igor CC: leveldb, kailiu Differential Revision: https://reviews.facebook.net/D16131 --- HISTORY.md | 1 + Makefile | 8 +- hdfs/env_hdfs.h | 12 +- include/rocksdb/env.h | 4 + util/env_posix.cc | 9 + util/thread_local.cc | 236 ++++++++++++++++++++ util/thread_local.h | 158 +++++++++++++ util/thread_local_test.cc | 456 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 878 insertions(+), 6 deletions(-) create mode 100644 util/thread_local.cc create mode 100644 util/thread_local.h create mode 100644 util/thread_local_test.cc diff --git a/HISTORY.md b/HISTORY.md index d48591ac5cc..933c43e4a5b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -7,6 +7,7 @@ * Removed arena.h from public header files. * By default, checksums are verified on every read from database * Added is_manual_compaction to CompactionFilter::Context +* Added "virtual void WaitForJoin() = 0" in class Env ## 2.7.0 (01/28/2014) diff --git a/Makefile b/Makefile index 3be3c3d0815..6eef8bec0bc 100644 --- a/Makefile +++ b/Makefile @@ -87,7 +87,8 @@ TESTS = \ version_set_test \ write_batch_test\ deletefile_test \ - table_test + table_test \ + thread_local_test TOOLS = \ sst_dump \ @@ -147,7 +148,7 @@ all: $(LIBRARY) $(PROGRAMS) dbg: $(LIBRARY) $(PROGRAMS) -# Will also generate shared libraries. +# Will also generate shared libraries. release: $(MAKE) clean OPT="-DNDEBUG -O2" $(MAKE) all -j32 @@ -276,6 +277,9 @@ redis_test: utilities/redis/redis_lists_test.o $(LIBOBJECTS) $(TESTHARNESS) histogram_test: util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(CXX) util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o$@ $(LDFLAGS) $(COVERAGEFLAGS) +thread_local_test: util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS) + $(CXX) util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS) + corruption_test: db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(CXX) db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS) diff --git a/hdfs/env_hdfs.h b/hdfs/env_hdfs.h index 886ccdac37a..17d8fcb2bb0 100644 --- a/hdfs/env_hdfs.h +++ b/hdfs/env_hdfs.h @@ -47,7 +47,7 @@ class HdfsFatalException : public std::exception { class HdfsEnv : public Env { public: - HdfsEnv(const std::string& fsname) : fsname_(fsname) { + explicit HdfsEnv(const std::string& fsname) : fsname_(fsname) { posixEnv = Env::Default(); fileSys_ = connectToPath(fsname_); } @@ -108,6 +108,8 @@ class HdfsEnv : public Env { posixEnv->StartThread(function, arg); } + virtual void WaitForJoin() { posixEnv->WaitForJoin(); } + virtual Status GetTestDirectory(std::string* path) { return posixEnv->GetTestDirectory(path); } @@ -161,7 +163,7 @@ class HdfsEnv : public Env { */ hdfsFS connectToPath(const std::string& uri) { if (uri.empty()) { - return NULL; + return nullptr; } if (uri.find(kProto) != 0) { // uri doesn't start with hdfs:// -> use default:0, which is special @@ -218,10 +220,10 @@ static const Status notsup; class HdfsEnv : public Env { public: - HdfsEnv(const std::string& fsname) { + explicit HdfsEnv(const std::string& fsname) { fprintf(stderr, "You have not build rocksdb with HDFS support\n"); fprintf(stderr, "Please see hdfs/README for details\n"); - throw new std::exception(); + throw std::exception(); } virtual ~HdfsEnv() { @@ -288,6 +290,8 @@ class HdfsEnv : public Env { virtual void StartThread(void (*function)(void* arg), void* arg) {} + virtual void WaitForJoin() {} + virtual Status GetTestDirectory(std::string* path) {return notsup;} virtual uint64_t NowMicros() {return 0;} diff --git a/include/rocksdb/env.h b/include/rocksdb/env.h index 06e9b94aa6b..9324250278f 100644 --- a/include/rocksdb/env.h +++ b/include/rocksdb/env.h @@ -205,6 +205,9 @@ class Env { // When "function(arg)" returns, the thread will be destroyed. virtual void StartThread(void (*function)(void* arg), void* arg) = 0; + // Wait for all threads started by StartThread to terminate. + virtual void WaitForJoin() = 0; + // *path is set to a temporary directory that can be used for testing. It may // or many not have just been created. The directory may or may not differ // between runs of the same process, but subsequent calls will return the @@ -634,6 +637,7 @@ class EnvWrapper : public Env { void StartThread(void (*f)(void*), void* a) { return target_->StartThread(f, a); } + void WaitForJoin() { return target_->WaitForJoin(); } virtual Status GetTestDirectory(std::string* path) { return target_->GetTestDirectory(path); } diff --git a/util/env_posix.cc b/util/env_posix.cc index 1ccb32084ec..fcfea28ab43 100644 --- a/util/env_posix.cc +++ b/util/env_posix.cc @@ -1194,6 +1194,8 @@ class PosixEnv : public Env { virtual void StartThread(void (*function)(void* arg), void* arg); + virtual void WaitForJoin(); + virtual Status GetTestDirectory(std::string* result) { const char* env = getenv("TEST_TMPDIR"); if (env && env[0] != '\0') { @@ -1511,6 +1513,13 @@ void PosixEnv::StartThread(void (*function)(void* arg), void* arg) { PthreadCall("unlock", pthread_mutex_unlock(&mu_)); } +void PosixEnv::WaitForJoin() { + for (const auto tid : threads_to_join_) { + pthread_join(tid, nullptr); + } + threads_to_join_.clear(); +} + } // namespace std::string Env::GenerateUniqueId() { diff --git a/util/thread_local.cc b/util/thread_local.cc new file mode 100644 index 00000000000..90571b97e7e --- /dev/null +++ b/util/thread_local.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2013, Facebook, Inc. All rights reserved. +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. An additional grant +// of patent rights can be found in the PATENTS file in the same directory. +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "util/thread_local.h" +#include "util/mutexlock.h" + +#if defined(__GNUC__) && __GNUC__ >= 4 +#define UNLIKELY(x) (__builtin_expect((x), 0)) +#else +#define UNLIKELY(x) (x) +#endif + +namespace rocksdb { + +std::unique_ptr ThreadLocalPtr::StaticMeta::inst_; +port::Mutex ThreadLocalPtr::StaticMeta::mutex_; +#if !defined(OS_MACOSX) +__thread ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::tls_ = nullptr; +#endif + +ThreadLocalPtr::StaticMeta* ThreadLocalPtr::StaticMeta::Instance() { + if (UNLIKELY(inst_ == nullptr)) { + MutexLock l(&mutex_); + if (inst_ == nullptr) { + inst_.reset(new StaticMeta()); + } + } + return inst_.get(); +} + +void ThreadLocalPtr::StaticMeta::OnThreadExit(void* ptr) { + auto* tls = static_cast(ptr); + assert(tls != nullptr); + + auto* inst = Instance(); + pthread_setspecific(inst->pthread_key_, nullptr); + + MutexLock l(&mutex_); + inst->RemoveThreadData(tls); + // Unref stored pointers of current thread from all instances + uint32_t id = 0; + for (auto& e : tls->entries) { + void* raw = e.ptr.load(std::memory_order_relaxed); + if (raw != nullptr) { + auto unref = inst->GetHandler(id); + if (unref != nullptr) { + unref(raw); + } + } + ++id; + } + // Delete thread local structure no matter if it is Mac platform + delete tls; +} + +ThreadLocalPtr::StaticMeta::StaticMeta() : next_instance_id_(0) { + if (pthread_key_create(&pthread_key_, &OnThreadExit) != 0) { + throw std::runtime_error("pthread_key_create failed"); + } + head_.next = &head_; + head_.prev = &head_; +} + +void ThreadLocalPtr::StaticMeta::AddThreadData(ThreadLocalPtr::ThreadData* d) { + mutex_.AssertHeld(); + d->next = &head_; + d->prev = head_.prev; + head_.prev->next = d; + head_.prev = d; +} + +void ThreadLocalPtr::StaticMeta::RemoveThreadData( + ThreadLocalPtr::ThreadData* d) { + mutex_.AssertHeld(); + d->next->prev = d->prev; + d->prev->next = d->next; + d->next = d->prev = d; +} + +ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::GetThreadLocal() { +#if defined(OS_MACOSX) + // Make this local variable name look like a member variable so that we + // can share all the code below + ThreadData* tls_ = + static_cast(pthread_getspecific(Instance()->pthread_key_)); +#endif + + if (UNLIKELY(tls_ == nullptr)) { + auto* inst = Instance(); + tls_ = new ThreadData(); + { + // Register it in the global chain, needs to be done before thread exit + // handler registration + MutexLock l(&mutex_); + inst->AddThreadData(tls_); + } + // Even it is not OS_MACOSX, need to register value for pthread_key_ so that + // its exit handler will be triggered. + if (pthread_setspecific(inst->pthread_key_, tls_) != 0) { + { + MutexLock l(&mutex_); + inst->RemoveThreadData(tls_); + } + delete tls_; + throw std::runtime_error("pthread_setspecific failed"); + } + } + return tls_; +} + +void* ThreadLocalPtr::StaticMeta::Get(uint32_t id) const { + auto* tls = GetThreadLocal(); + if (UNLIKELY(id >= tls->entries.size())) { + return nullptr; + } + return tls->entries[id].ptr.load(std::memory_order_relaxed); +} + +void ThreadLocalPtr::StaticMeta::Reset(uint32_t id, void* ptr) { + auto* tls = GetThreadLocal(); + if (UNLIKELY(id >= tls->entries.size())) { + // Need mutex to protect entries access within ReclaimId + MutexLock l(&mutex_); + tls->entries.resize(id + 1); + } + tls->entries[id].ptr.store(ptr, std::memory_order_relaxed); +} + +void* ThreadLocalPtr::StaticMeta::Swap(uint32_t id, void* ptr) { + auto* tls = GetThreadLocal(); + if (UNLIKELY(id >= tls->entries.size())) { + // Need mutex to protect entries access within ReclaimId + MutexLock l(&mutex_); + tls->entries.resize(id + 1); + } + return tls->entries[id].ptr.exchange(ptr, std::memory_order_relaxed); +} + +void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector* ptrs) { + MutexLock l(&mutex_); + for (ThreadData* t = head_.next; t != &head_; t = t->next) { + if (id < t->entries.size()) { + void* ptr = + t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed); + if (ptr != nullptr) { + ptrs->push_back(ptr); + } + } + } +} + +void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) { + MutexLock l(&mutex_); + handler_map_[id] = handler; +} + +UnrefHandler ThreadLocalPtr::StaticMeta::GetHandler(uint32_t id) { + mutex_.AssertHeld(); + auto iter = handler_map_.find(id); + if (iter == handler_map_.end()) { + return nullptr; + } + return iter->second; +} + +uint32_t ThreadLocalPtr::StaticMeta::GetId() { + MutexLock l(&mutex_); + if (free_instance_ids_.empty()) { + return next_instance_id_++; + } + + uint32_t id = free_instance_ids_.back(); + free_instance_ids_.pop_back(); + return id; +} + +uint32_t ThreadLocalPtr::StaticMeta::PeekId() const { + MutexLock l(&mutex_); + if (!free_instance_ids_.empty()) { + return free_instance_ids_.back(); + } + return next_instance_id_; +} + +void ThreadLocalPtr::StaticMeta::ReclaimId(uint32_t id) { + // This id is not used, go through all thread local data and release + // corresponding value + MutexLock l(&mutex_); + auto unref = GetHandler(id); + for (ThreadData* t = head_.next; t != &head_; t = t->next) { + if (id < t->entries.size()) { + void* ptr = + t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed); + if (ptr != nullptr && unref != nullptr) { + unref(ptr); + } + } + } + handler_map_[id] = nullptr; + free_instance_ids_.push_back(id); +} + +ThreadLocalPtr::ThreadLocalPtr(UnrefHandler handler) + : id_(StaticMeta::Instance()->GetId()) { + if (handler != nullptr) { + StaticMeta::Instance()->SetHandler(id_, handler); + } +} + +ThreadLocalPtr::~ThreadLocalPtr() { + StaticMeta::Instance()->ReclaimId(id_); +} + +void* ThreadLocalPtr::Get() const { + return StaticMeta::Instance()->Get(id_); +} + +void ThreadLocalPtr::Reset(void* ptr) { + StaticMeta::Instance()->Reset(id_, ptr); +} + +void* ThreadLocalPtr::Swap(void* ptr) { + return StaticMeta::Instance()->Swap(id_, ptr); +} + +void ThreadLocalPtr::Scrape(autovector* ptrs) { + StaticMeta::Instance()->Scrape(id_, ptrs); +} + +} // namespace rocksdb diff --git a/util/thread_local.h b/util/thread_local.h new file mode 100644 index 00000000000..d6fc5f085d3 --- /dev/null +++ b/util/thread_local.h @@ -0,0 +1,158 @@ +// Copyright (c) 2013, Facebook, Inc. All rights reserved. +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. An additional grant +// of patent rights can be found in the PATENTS file in the same directory. +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#pragma once + +#include +#include +#include +#include + +#include "util/autovector.h" +#include "port/port_posix.h" + +namespace rocksdb { + +// Cleanup function that will be called for a stored thread local +// pointer (if not NULL) when one of the following happens: +// (1) a thread terminates +// (2) a ThreadLocalPtr is destroyed +typedef void (*UnrefHandler)(void* ptr); + +// Thread local storage that only stores value of pointer type. The storage +// distinguish data coming from different thread and different ThreadLocalPtr +// instances. For example, if a regular thread_local variable A is declared +// in DBImpl, two DBImpl objects would share the same A. ThreadLocalPtr avoids +// the confliction. The total storage size equals to # of threads * # of +// ThreadLocalPtr instances. It is not efficient in terms of space, but it +// should serve most of our use cases well and keep code simple. +class ThreadLocalPtr { + public: + explicit ThreadLocalPtr(UnrefHandler handler = nullptr); + + ~ThreadLocalPtr(); + + // Return the current pointer stored in thread local + void* Get() const; + + // Set a new pointer value to the thread local storage. + void Reset(void* ptr); + + // Atomically swap the supplied ptr and return the previous value + void* Swap(void* ptr); + + // Return non-nullptr data for all existing threads and reset them + // to nullptr + void Scrape(autovector* ptrs); + + protected: + struct Entry { + Entry() : ptr(nullptr) {} + Entry(const Entry& e) : ptr(e.ptr.load(std::memory_order_relaxed)) {} + std::atomic ptr; + }; + + // This is the structure that is declared as "thread_local" storage. + // The vector keep list of atomic pointer for all instances for "current" + // thread. The vector is indexed by an Id that is unique in process and + // associated with one ThreadLocalPtr instance. The Id is assigned by a + // global StaticMeta singleton. So if we instantiated 3 ThreadLocalPtr + // instances, each thread will have a ThreadData with a vector of size 3: + // --------------------------------------------------- + // | | instance 1 | instance 2 | instnace 3 | + // --------------------------------------------------- + // | thread 1 | void* | void* | void* | <- ThreadData + // --------------------------------------------------- + // | thread 2 | void* | void* | void* | <- ThreadData + // --------------------------------------------------- + // | thread 3 | void* | void* | void* | <- ThreadData + // --------------------------------------------------- + struct ThreadData { + ThreadData() : entries() {} + std::vector entries; + ThreadData* next; + ThreadData* prev; + }; + + class StaticMeta { + public: + static StaticMeta* Instance(); + + // Return the next available Id + uint32_t GetId(); + // Return the next availabe Id without claiming it + uint32_t PeekId() const; + // Return the given Id back to the free pool. This also triggers + // UnrefHandler for associated pointer value (if not NULL) for all threads. + void ReclaimId(uint32_t id); + + // Return the pointer value for the given id for the current thread. + void* Get(uint32_t id) const; + // Reset the pointer value for the given id for the current thread. + // It triggers UnrefHanlder if the id has existing pointer value. + void Reset(uint32_t id, void* ptr); + // Atomically swap the supplied ptr and return the previous value + void* Swap(uint32_t id, void* ptr); + // Return data for all existing threads and return them to nullptr + void Scrape(uint32_t id, autovector* ptrs); + + // Register the UnrefHandler for id + void SetHandler(uint32_t id, UnrefHandler handler); + + private: + StaticMeta(); + + // Get UnrefHandler for id with acquiring mutex + // REQUIRES: mutex locked + UnrefHandler GetHandler(uint32_t id); + + // Triggered before a thread terminates + static void OnThreadExit(void* ptr); + + // Add current thread's ThreadData to the global chain + // REQUIRES: mutex locked + void AddThreadData(ThreadData* d); + + // Remove current thread's ThreadData from the global chain + // REQUIRES: mutex locked + void RemoveThreadData(ThreadData* d); + + static ThreadData* GetThreadLocal(); + + // Singleton instance + static std::unique_ptr inst_; + + uint32_t next_instance_id_; + // Used to recycle Ids in case ThreadLocalPtr is instantiated and destroyed + // frequently. This also prevents it from blowing up the vector space. + autovector free_instance_ids_; + // Chain all thread local structure together. This is necessary since + // when one ThreadLocalPtr gets destroyed, we need to loop over each + // thread's version of pointer corresponding to that instance and + // call UnrefHandler for it. + ThreadData head_; + + std::unordered_map handler_map_; + + // protect inst, next_instance_id_, free_instance_ids_, head_, + // ThreadData.entries + static port::Mutex mutex_; +#if !defined(OS_MACOSX) + // Thread local storage + static __thread ThreadData* tls_; +#endif + // Used to make thread exit trigger possible if !defined(OS_MACOSX). + // Otherwise, used to retrieve thread data. + pthread_key_t pthread_key_; + }; + + const uint32_t id_; +}; + +} // namespace rocksdb diff --git a/util/thread_local_test.cc b/util/thread_local_test.cc new file mode 100644 index 00000000000..bc7aa5b5259 --- /dev/null +++ b/util/thread_local_test.cc @@ -0,0 +1,456 @@ +// Copyright (c) 2013, Facebook, Inc. All rights reserved. +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. An additional grant +// of patent rights can be found in the PATENTS file in the same directory. + +#include + +#include "rocksdb/env.h" +#include "port/port_posix.h" +#include "util/autovector.h" +#include "util/thread_local.h" +#include "util/testharness.h" +#include "util/testutil.h" + +namespace rocksdb { + +class ThreadLocalTest { + public: + ThreadLocalTest() : env_(Env::Default()) {} + + Env* env_; +}; + +namespace { + +struct Params { + Params(port::Mutex* m, port::CondVar* c, int* unref, int n, + UnrefHandler handler = nullptr) + : mu(m), + cv(c), + unref(unref), + total(n), + started(0), + completed(0), + doWrite(false), + tls1(handler), + tls2(nullptr) {} + + port::Mutex* mu; + port::CondVar* cv; + int* unref; + int total; + int started; + int completed; + bool doWrite; + ThreadLocalPtr tls1; + ThreadLocalPtr* tls2; +}; + +class IDChecker : public ThreadLocalPtr { + public: + static uint32_t PeekId() { return StaticMeta::Instance()->PeekId(); } +}; + +} // anonymous namespace + +TEST(ThreadLocalTest, UniqueIdTest) { + port::Mutex mu; + port::CondVar cv(&mu); + + ASSERT_EQ(IDChecker::PeekId(), 0); + // New ThreadLocal instance bumps id by 1 + { + // Id used 0 + Params p1(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 1); + // Id used 1 + Params p2(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 2); + // Id used 2 + Params p3(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 3); + // Id used 3 + Params p4(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 4); + } + // id 3, 2, 1, 0 are in the free queue in order + ASSERT_EQ(IDChecker::PeekId(), 0); + + // pick up 0 + Params p1(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 1); + // pick up 1 + Params* p2 = new Params(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 2); + // pick up 2 + Params p3(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 3); + // return up 1 + delete p2; + ASSERT_EQ(IDChecker::PeekId(), 1); + // Now we have 3, 1 in queue + // pick up 1 + Params p4(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 3); + // pick up 3 + Params p5(&mu, &cv, nullptr, 1); + // next new id + ASSERT_EQ(IDChecker::PeekId(), 4); + // After exit, id sequence in queue: + // 3, 1, 2, 0 +} + +TEST(ThreadLocalTest, SequentialReadWriteTest) { + // global id list carries over 3, 1, 2, 0 + ASSERT_EQ(IDChecker::PeekId(), 0); + + port::Mutex mu; + port::CondVar cv(&mu); + Params p(&mu, &cv, nullptr, 1); + ThreadLocalPtr tls2; + p.tls2 = &tls2; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + p.tls1.Reset(reinterpret_cast(1)); + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(1)); + p.tls1.Reset(reinterpret_cast(2)); + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(2)); + + ASSERT_TRUE(p.tls2->Get() == nullptr); + p.tls2->Reset(reinterpret_cast(1)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(1)); + p.tls2->Reset(reinterpret_cast(2)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(2)); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + p.mu->Unlock(); + }; + + for (int iter = 0; iter < 1024; ++iter) { + ASSERT_EQ(IDChecker::PeekId(), 1); + // Another new thread, read/write should not see value from previous thread + env_->StartThread(func, static_cast(&p)); + mu.Lock(); + while (p.completed != iter + 1) { + cv.Wait(); + } + mu.Unlock(); + ASSERT_EQ(IDChecker::PeekId(), 1); + } +} + +TEST(ThreadLocalTest, ConcurrentReadWriteTest) { + // global id list carries over 3, 1, 2, 0 + ASSERT_EQ(IDChecker::PeekId(), 0); + + ThreadLocalPtr tls2; + port::Mutex mu1; + port::CondVar cv1(&mu1); + Params p1(&mu1, &cv1, nullptr, 128); + p1.tls2 = &tls2; + + port::Mutex mu2; + port::CondVar cv2(&mu2); + Params p2(&mu2, &cv2, nullptr, 128); + p2.doWrite = true; + p2.tls2 = &tls2; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + int own = ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + // Let write threads write a different value from the read threads + if (p.doWrite) { + own += 8192; + } + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + auto* env = Env::Default(); + auto start = env->NowMicros(); + + p.tls1.Reset(reinterpret_cast(own)); + p.tls2->Reset(reinterpret_cast(own + 1)); + // Loop for 1 second + while (env->NowMicros() - start < 1000 * 1000) { + for (int iter = 0; iter < 100000; ++iter) { + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(own)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(own + 1)); + if (p.doWrite) { + p.tls1.Reset(reinterpret_cast(own)); + p.tls2->Reset(reinterpret_cast(own + 1)); + } + } + } + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + p.mu->Unlock(); + }; + + // Initiate 2 instnaces: one keeps writing and one keeps reading. + // The read instance should not see data from the write instance. + // Each thread local copy of the value are also different from each + // other. + for (int th = 0; th < p1.total; ++th) { + env_->StartThread(func, static_cast(&p1)); + } + for (int th = 0; th < p2.total; ++th) { + env_->StartThread(func, static_cast(&p2)); + } + + mu1.Lock(); + while (p1.completed != p1.total) { + cv1.Wait(); + } + mu1.Unlock(); + + mu2.Lock(); + while (p2.completed != p2.total) { + cv2.Wait(); + } + mu2.Unlock(); + + ASSERT_EQ(IDChecker::PeekId(), 3); +} + +TEST(ThreadLocalTest, Unref) { + ASSERT_EQ(IDChecker::PeekId(), 0); + + auto unref = [](void* ptr) { + auto& p = *static_cast(ptr); + p.mu->Lock(); + ++(*p.unref); + p.mu->Unlock(); + }; + + // Case 0: no unref triggered if ThreadLocalPtr is never accessed + auto func0 = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func0, static_cast(&p)); + } + env_->WaitForJoin(); + ASSERT_EQ(unref_count, 0); + } + + // Case 1: unref triggered by thread exit + auto func1 = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + ThreadLocalPtr tls2(unref); + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = &tls2; + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func1, static_cast(&p)); + } + + env_->WaitForJoin(); + + // N threads x 2 ThreadLocal instance cleanup on thread exit + ASSERT_EQ(unref_count, 2 * p.total); + } + + // Case 2: unref triggered by ThreadLocal instance destruction + auto func2 = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = new ThreadLocalPtr(unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func2, static_cast(&p)); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (p.completed != p.total) { + cv.Wait(); + } + mu.Unlock(); + + // Now destroy one ThreadLocal instance + delete p.tls2; + p.tls2 = nullptr; + // instance destroy for N threads + ASSERT_EQ(unref_count, p.total); + + // Signal to exit + mu.Lock(); + p.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); + // additional N threads exit unref for the left instance + ASSERT_EQ(unref_count, 2 * p.total); + } +} + +TEST(ThreadLocalTest, Swap) { + ThreadLocalPtr tls; + tls.Reset(reinterpret_cast(1)); + ASSERT_EQ(reinterpret_cast(tls.Swap(nullptr)), 1); + ASSERT_TRUE(tls.Swap(reinterpret_cast(2)) == nullptr); + ASSERT_EQ(reinterpret_cast(tls.Get()), 2); + ASSERT_EQ(reinterpret_cast(tls.Swap(reinterpret_cast(3))), 2); +} + +TEST(ThreadLocalTest, Scrape) { + auto unref = [](void* ptr) { + auto& p = *static_cast(ptr); + p.mu->Lock(); + ++(*p.unref); + p.mu->Unlock(); + }; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = new ThreadLocalPtr(unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func, static_cast(&p)); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (p.completed != p.total) { + cv.Wait(); + } + mu.Unlock(); + + ASSERT_EQ(unref_count, 0); + + // Scrape all thread local data. No unref at thread + // exit or ThreadLocalPtr destruction + autovector ptrs; + p.tls1.Scrape(&ptrs); + p.tls2->Scrape(&ptrs); + delete p.tls2; + // Signal to exit + mu.Lock(); + p.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); + + ASSERT_EQ(unref_count, 0); + } +} + +} // namespace rocksdb + +int main(int argc, char** argv) { + return rocksdb::test::RunAllTests(); +}