Skip to content

Commit

Permalink
use framework::RWLock
Browse files Browse the repository at this point in the history
  • Loading branch information
JiayiFeng committed Aug 31, 2018
1 parent 1f36a4c commit c501826
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 9 deletions.
65 changes: 65 additions & 0 deletions paddle/fluid/framework/rw_lock.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,70 @@ struct RWLock {
};
#endif

// RAII guard over a framework::RWLock. Acquires the lock in the mode given
// at construction, supports explicit UnLock()/re-lock (e.g. temporarily
// upgrading a read section to a write section), and releases whatever is
// still held in the destructor.
class RWLockGuard {
 public:
  enum Status { kUnLock, kWRLock, kRDLock };

  // `rw_lock` must outlive the guard (not owned). `init_status` selects the
  // initial mode: kRDLock/kWRLock acquire immediately, kUnLock starts
  // without holding the lock.
  RWLockGuard(RWLock* rw_lock, Status init_status)
      : lock_(rw_lock), status_(Status::kUnLock) {
    switch (init_status) {
      case Status::kRDLock: {
        RDLock();
        break;
      }
      case Status::kWRLock: {
        WRLock();
        break;
      }
      case Status::kUnLock: {
        // Explicitly start unlocked; listed so the switch covers every
        // enumerator (avoids -Wswitch warnings).
        break;
      }
    }
  }

  // Acquires the write lock. No-op if this guard already holds the write
  // lock; throws if it holds the read lock (in-place upgrade is not
  // supported — call UnLock() first).
  void WRLock() {
    switch (status_) {
      case Status::kUnLock: {
        lock_->WRLock();
        // Record ownership so UnLock()/~RWLockGuard() actually release it.
        status_ = Status::kWRLock;
        break;
      }
      case Status::kWRLock: {
        break;
      }
      case Status::kRDLock: {
        PADDLE_THROW(
            "Please unlock read lock first before invoking write lock.");
        break;
      }
    }
  }

  // Acquires the read lock. No-op if this guard already holds the read
  // lock; throws if it holds the write lock (in-place downgrade is not
  // supported — call UnLock() first).
  void RDLock() {
    switch (status_) {
      case Status::kUnLock: {
        lock_->RDLock();
        // Record ownership so UnLock()/~RWLockGuard() actually release it.
        status_ = Status::kRDLock;
        break;
      }
      case Status::kRDLock: {
        break;
      }
      case Status::kWRLock: {
        PADDLE_THROW(
            "Please unlock write lock first before invoking read lock.");
        break;
      }
    }
  }

  // Releases the lock if held. Safe to call multiple times; resets the
  // recorded state so the guard can re-lock via WRLock()/RDLock().
  void UnLock() {
    if (status_ != Status::kUnLock) {
      lock_->UNLock();
      status_ = Status::kUnLock;
    }
  }

  ~RWLockGuard() { UnLock(); }

 private:
  RWLock* lock_;   // underlying lock, not owned
  Status status_;  // mode currently held by this guard
};

} // namespace framework
} // namespace paddle
20 changes: 11 additions & 9 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ limitations under the License. */
#include <unordered_set>
#include <vector>

#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include <boost/thread/thread.hpp>
#include "paddle/fluid/framework/rw_lock.h"
#endif

#include "paddle/fluid/memory/memory.h"

namespace paddle {
namespace platform {

Expand Down Expand Up @@ -158,19 +157,22 @@ class CudnnHolder {

void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
boost::upgrade_lock<boost::shared_mutex> shared_lock(mtx_);
framework::RWLockGuard lock_guard(&rw_lock_,
framework::RWLockGuard::Status::kRDLock);
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len, &shared_lock);
lock_guard.UnLock();
lock_guard.WRLock();
ReallocateWorkspace(required_workspace_len);
lock_guard.UnLock();
lock_guard.RDLock();
}
cudnn_func(workspace_);
}

~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }

private:
void ReallocateWorkspace(size_t required_workspace_len,
boost::upgrade_lock<boost::shared_mutex>* lock) {
boost::upgrade_to_unique_lock<boost::shared_mutex> unique_lock(*lock);
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
Expand All @@ -192,7 +194,7 @@ class CudnnHolder {
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;

boost::shared_mutex mtx_;
framework::RWLock rw_lock_;
};

CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
Expand Down

0 comments on commit c501826

Please sign in to comment.