Skip to content

Commit

Permalink
thread local pointer storage
Browse files Browse the repository at this point in the history
Summary:
This is not a generic thread local implementation in the sense that it
only takes pointer. But it does support multiple instances per thread
and lets user plugin function to perform cleanup when thread exits or an
instance gets destroyed.

Test Plan: unit test for now

Reviewers: haobo, igor, sdong, dhruba

Reviewed By: igor

CC: leveldb, kailiu

Differential Revision: https://reviews.facebook.net/D16131
  • Loading branch information
Lei Jin committed Feb 26, 2014
1 parent 4209516 commit b2795b7
Show file tree
Hide file tree
Showing 8 changed files with 878 additions and 6 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Removed arena.h from public header files.
* By default, checksums are verified on every read from database
* Added is_manual_compaction to CompactionFilter::Context
* Added "virtual void WaitForJoin() = 0" in class Env

## 2.7.0 (01/28/2014)

Expand Down
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ TESTS = \
version_set_test \
write_batch_test\
deletefile_test \
table_test
table_test \
thread_local_test

TOOLS = \
sst_dump \
Expand Down Expand Up @@ -147,7 +148,7 @@ all: $(LIBRARY) $(PROGRAMS)

dbg: $(LIBRARY) $(PROGRAMS)

# Will also generate shared libraries.
# Will also generate shared libraries.
release:
$(MAKE) clean
OPT="-DNDEBUG -O2" $(MAKE) all -j32
Expand Down Expand Up @@ -276,6 +277,9 @@ redis_test: utilities/redis/redis_lists_test.o $(LIBOBJECTS) $(TESTHARNESS)
histogram_test: util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS)
$(CXX) util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o$@ $(LDFLAGS) $(COVERAGEFLAGS)

thread_local_test: util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS)
$(CXX) util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS)

corruption_test: db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS)
$(CXX) db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS)

Expand Down
12 changes: 8 additions & 4 deletions hdfs/env_hdfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class HdfsFatalException : public std::exception {
class HdfsEnv : public Env {

public:
HdfsEnv(const std::string& fsname) : fsname_(fsname) {
explicit HdfsEnv(const std::string& fsname) : fsname_(fsname) {
posixEnv = Env::Default();
fileSys_ = connectToPath(fsname_);
}
Expand Down Expand Up @@ -108,6 +108,8 @@ class HdfsEnv : public Env {
posixEnv->StartThread(function, arg);
}

virtual void WaitForJoin() { posixEnv->WaitForJoin(); }

virtual Status GetTestDirectory(std::string* path) {
return posixEnv->GetTestDirectory(path);
}
Expand Down Expand Up @@ -161,7 +163,7 @@ class HdfsEnv : public Env {
*/
hdfsFS connectToPath(const std::string& uri) {
if (uri.empty()) {
return NULL;
return nullptr;
}
if (uri.find(kProto) != 0) {
// uri doesn't start with hdfs:// -> use default:0, which is special
Expand Down Expand Up @@ -218,10 +220,10 @@ static const Status notsup;
class HdfsEnv : public Env {

public:
HdfsEnv(const std::string& fsname) {
explicit HdfsEnv(const std::string& fsname) {
fprintf(stderr, "You have not build rocksdb with HDFS support\n");
fprintf(stderr, "Please see hdfs/README for details\n");
throw new std::exception();
throw std::exception();
}

virtual ~HdfsEnv() {
Expand Down Expand Up @@ -288,6 +290,8 @@ class HdfsEnv : public Env {

virtual void StartThread(void (*function)(void* arg), void* arg) {}

virtual void WaitForJoin() {}

virtual Status GetTestDirectory(std::string* path) {return notsup;}

virtual uint64_t NowMicros() {return 0;}
Expand Down
4 changes: 4 additions & 0 deletions include/rocksdb/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ class Env {
// When "function(arg)" returns, the thread will be destroyed.
virtual void StartThread(void (*function)(void* arg), void* arg) = 0;

// Wait for all threads started by StartThread to terminate.
virtual void WaitForJoin() = 0;

// *path is set to a temporary directory that can be used for testing. It may
// or many not have just been created. The directory may or may not differ
// between runs of the same process, but subsequent calls will return the
Expand Down Expand Up @@ -634,6 +637,7 @@ class EnvWrapper : public Env {
void StartThread(void (*f)(void*), void* a) {
return target_->StartThread(f, a);
}
void WaitForJoin() { return target_->WaitForJoin(); }
virtual Status GetTestDirectory(std::string* path) {
return target_->GetTestDirectory(path);
}
Expand Down
9 changes: 9 additions & 0 deletions util/env_posix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,8 @@ class PosixEnv : public Env {

virtual void StartThread(void (*function)(void* arg), void* arg);

virtual void WaitForJoin();

virtual Status GetTestDirectory(std::string* result) {
const char* env = getenv("TEST_TMPDIR");
if (env && env[0] != '\0') {
Expand Down Expand Up @@ -1511,6 +1513,13 @@ void PosixEnv::StartThread(void (*function)(void* arg), void* arg) {
PthreadCall("unlock", pthread_mutex_unlock(&mu_));
}

void PosixEnv::WaitForJoin() {
for (const auto tid : threads_to_join_) {
pthread_join(tid, nullptr);
}
threads_to_join_.clear();
}

} // namespace

std::string Env::GenerateUniqueId() {
Expand Down
236 changes: 236 additions & 0 deletions util/thread_local.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// Copyright (c) 2013, Facebook, Inc. All rights reserved.
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
//
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. See the AUTHORS file for names of contributors.

#include "util/thread_local.h"
#include "util/mutexlock.h"

#if defined(__GNUC__) && __GNUC__ >= 4
#define UNLIKELY(x) (__builtin_expect((x), 0))
#else
#define UNLIKELY(x) (x)
#endif

namespace rocksdb {

std::unique_ptr<ThreadLocalPtr::StaticMeta> ThreadLocalPtr::StaticMeta::inst_;
port::Mutex ThreadLocalPtr::StaticMeta::mutex_;
#if !defined(OS_MACOSX)
__thread ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::tls_ = nullptr;
#endif

ThreadLocalPtr::StaticMeta* ThreadLocalPtr::StaticMeta::Instance() {
if (UNLIKELY(inst_ == nullptr)) {
MutexLock l(&mutex_);
if (inst_ == nullptr) {
inst_.reset(new StaticMeta());
}
}
return inst_.get();
}

void ThreadLocalPtr::StaticMeta::OnThreadExit(void* ptr) {
auto* tls = static_cast<ThreadData*>(ptr);
assert(tls != nullptr);

auto* inst = Instance();
pthread_setspecific(inst->pthread_key_, nullptr);

MutexLock l(&mutex_);
inst->RemoveThreadData(tls);
// Unref stored pointers of current thread from all instances
uint32_t id = 0;
for (auto& e : tls->entries) {
void* raw = e.ptr.load(std::memory_order_relaxed);
if (raw != nullptr) {
auto unref = inst->GetHandler(id);
if (unref != nullptr) {
unref(raw);
}
}
++id;
}
// Delete thread local structure no matter if it is Mac platform
delete tls;
}

ThreadLocalPtr::StaticMeta::StaticMeta() : next_instance_id_(0) {
if (pthread_key_create(&pthread_key_, &OnThreadExit) != 0) {
throw std::runtime_error("pthread_key_create failed");
}
head_.next = &head_;
head_.prev = &head_;
}

void ThreadLocalPtr::StaticMeta::AddThreadData(ThreadLocalPtr::ThreadData* d) {
mutex_.AssertHeld();
d->next = &head_;
d->prev = head_.prev;
head_.prev->next = d;
head_.prev = d;
}

void ThreadLocalPtr::StaticMeta::RemoveThreadData(
ThreadLocalPtr::ThreadData* d) {
mutex_.AssertHeld();
d->next->prev = d->prev;
d->prev->next = d->next;
d->next = d->prev = d;
}

ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::GetThreadLocal() {
#if defined(OS_MACOSX)
// Make this local variable name look like a member variable so that we
// can share all the code below
ThreadData* tls_ =
static_cast<ThreadData*>(pthread_getspecific(Instance()->pthread_key_));
#endif

if (UNLIKELY(tls_ == nullptr)) {
auto* inst = Instance();
tls_ = new ThreadData();
{
// Register it in the global chain, needs to be done before thread exit
// handler registration
MutexLock l(&mutex_);
inst->AddThreadData(tls_);
}
// Even it is not OS_MACOSX, need to register value for pthread_key_ so that
// its exit handler will be triggered.
if (pthread_setspecific(inst->pthread_key_, tls_) != 0) {
{
MutexLock l(&mutex_);
inst->RemoveThreadData(tls_);
}
delete tls_;
throw std::runtime_error("pthread_setspecific failed");
}
}
return tls_;
}

void* ThreadLocalPtr::StaticMeta::Get(uint32_t id) const {
auto* tls = GetThreadLocal();
if (UNLIKELY(id >= tls->entries.size())) {
return nullptr;
}
return tls->entries[id].ptr.load(std::memory_order_relaxed);
}

void ThreadLocalPtr::StaticMeta::Reset(uint32_t id, void* ptr) {
auto* tls = GetThreadLocal();
if (UNLIKELY(id >= tls->entries.size())) {
// Need mutex to protect entries access within ReclaimId
MutexLock l(&mutex_);
tls->entries.resize(id + 1);
}
tls->entries[id].ptr.store(ptr, std::memory_order_relaxed);
}

void* ThreadLocalPtr::StaticMeta::Swap(uint32_t id, void* ptr) {
auto* tls = GetThreadLocal();
if (UNLIKELY(id >= tls->entries.size())) {
// Need mutex to protect entries access within ReclaimId
MutexLock l(&mutex_);
tls->entries.resize(id + 1);
}
return tls->entries[id].ptr.exchange(ptr, std::memory_order_relaxed);
}

void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector<void*>* ptrs) {
MutexLock l(&mutex_);
for (ThreadData* t = head_.next; t != &head_; t = t->next) {
if (id < t->entries.size()) {
void* ptr =
t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed);
if (ptr != nullptr) {
ptrs->push_back(ptr);
}
}
}
}

void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) {
MutexLock l(&mutex_);
handler_map_[id] = handler;
}

UnrefHandler ThreadLocalPtr::StaticMeta::GetHandler(uint32_t id) {
mutex_.AssertHeld();
auto iter = handler_map_.find(id);
if (iter == handler_map_.end()) {
return nullptr;
}
return iter->second;
}

uint32_t ThreadLocalPtr::StaticMeta::GetId() {
MutexLock l(&mutex_);
if (free_instance_ids_.empty()) {
return next_instance_id_++;
}

uint32_t id = free_instance_ids_.back();
free_instance_ids_.pop_back();
return id;
}

uint32_t ThreadLocalPtr::StaticMeta::PeekId() const {
MutexLock l(&mutex_);
if (!free_instance_ids_.empty()) {
return free_instance_ids_.back();
}
return next_instance_id_;
}

void ThreadLocalPtr::StaticMeta::ReclaimId(uint32_t id) {
// This id is not used, go through all thread local data and release
// corresponding value
MutexLock l(&mutex_);
auto unref = GetHandler(id);
for (ThreadData* t = head_.next; t != &head_; t = t->next) {
if (id < t->entries.size()) {
void* ptr =
t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed);
if (ptr != nullptr && unref != nullptr) {
unref(ptr);
}
}
}
handler_map_[id] = nullptr;
free_instance_ids_.push_back(id);
}

ThreadLocalPtr::ThreadLocalPtr(UnrefHandler handler)
: id_(StaticMeta::Instance()->GetId()) {
if (handler != nullptr) {
StaticMeta::Instance()->SetHandler(id_, handler);
}
}

ThreadLocalPtr::~ThreadLocalPtr() {
StaticMeta::Instance()->ReclaimId(id_);
}

void* ThreadLocalPtr::Get() const {
return StaticMeta::Instance()->Get(id_);
}

void ThreadLocalPtr::Reset(void* ptr) {
StaticMeta::Instance()->Reset(id_, ptr);
}

void* ThreadLocalPtr::Swap(void* ptr) {
return StaticMeta::Instance()->Swap(id_, ptr);
}

void ThreadLocalPtr::Scrape(autovector<void*>* ptrs) {
StaticMeta::Instance()->Scrape(id_, ptrs);
}

} // namespace rocksdb
Loading

0 comments on commit b2795b7

Please sign in to comment.