Skip to content

Commit

Permalink
Backed out changeset 4e1241fe65cd (revert a revert :) )
Browse files Browse the repository at this point in the history
Summary:
This fixes the issue but I haven't figured out yet why is it
happening.

Reviewed By: bwasti

Differential Revision: D6437378

fbshipit-source-id: bf983c9b6f57647423423ec6b22e0f9d2b170e74
  • Loading branch information
salexspb authored and facebook-github-bot committed Nov 29, 2017
1 parent 6f218ce commit 913a9a7
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 82 deletions.
55 changes: 34 additions & 21 deletions caffe2/core/observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#pragma once

#include <map>
#include <memory>
#include <unordered_set>
#include "caffe2/core/logging.h"

namespace caffe2 {
Expand All @@ -31,12 +31,8 @@ class ObserverBase {
public:
explicit ObserverBase(T* subject) : subject_(subject) {}

virtual bool Start() {
return false;
}
virtual bool Stop() {
return false;
}
virtual void Start() {}
virtual void Stop() {}

virtual std::unique_ptr<ObserverBase<T>> clone() {
LOG(WARNING) << "clone() is not implemented and nullptr will be returned.";
Expand Down Expand Up @@ -69,36 +65,53 @@ class Observable {
/* Returns a reference to the observer after addition. */
const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
const Observer* weak_observer = observer.get();
observers_[weak_observer] = std::move(observer);
return weak_observer;
std::unordered_set<const Observer*> observers;
for (auto& ob : observers_list_) {
observers.insert(ob.get());
}

const auto* observer_ptr = observer.get();
if (observers.count(observer_ptr)) {
return observer_ptr;
}
observers_list_.push_back(std::move(observer));

return observer_ptr;
}

/* Returns a unique_ptr to the observer. */
std::unique_ptr<Observer> DetachObserver(const Observer* observer) {
std::unique_ptr<Observer> strong_observer = std::move(observers_[observer]);
observers_.erase(observer);
return strong_observer;
/**
* Returns a unique_ptr to the removed observer. If not found, return a
* nullptr
*/
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
if (it->get() == observer_ptr) {
auto res = std::move(*it);
observers_list_.erase(it);
return res;
}
}
return nullptr;
}

virtual size_t NumObservers() {
return observers_.size();
return observers_list_.size();
}

void StartAllObservers() {
for (const auto& observer : observers_) {
observer.second->Start();
for (auto& observer : observers_list_) {
observer->Start();
}
}

void StopAllObservers() {
for (const auto& observer : observers_) {
observer.second->Stop();
for (auto& observer : observers_list_) {
observer->Stop();
}
}

protected:
std::map<const Observer*, std::unique_ptr<ObserverBase<T>>> observers_;
std::vector<std::unique_ptr<Observer>> observers_list_;
};

} // namespace caffe2
16 changes: 6 additions & 10 deletions caffe2/core/observer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,38 +35,34 @@ template <class T>
class DummyObserver final : public ObserverBase<T> {
public:
explicit DummyObserver<T>(T* subject_) : ObserverBase<T>(subject_) {}
bool Start() override;
bool Stop() override;
void Start() override;
void Stop() override;

~DummyObserver() {}
};

template <>
bool DummyObserver<NetBase>::Start() {
void DummyObserver<NetBase>::Start() {
vector<OperatorBase*> operators = subject_->GetOperators();
for (auto& op : operators) {
op->AttachObserver(caffe2::make_unique<DummyObserver<OperatorBase>>(op));
}
counter.fetch_add(1000);
return true;
}

template <>
bool DummyObserver<OperatorBase>::Start() {
void DummyObserver<OperatorBase>::Start() {
counter.fetch_add(100);
return true;
}

template <>
bool DummyObserver<NetBase>::Stop() {
void DummyObserver<NetBase>::Stop() {
counter.fetch_add(10);
return true;
}

template <>
bool DummyObserver<OperatorBase>::Stop() {
void DummyObserver<OperatorBase>::Stop() {
counter.fetch_add(1);
return true;
}

class ObsTestDummyOp final : public OperatorBase {
Expand Down
14 changes: 4 additions & 10 deletions caffe2/observers/runcnt_observer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,18 @@ std::string RunCountNetObserver::debugInfo() {
return "This operator runs " + caffe2::to_string(cnt_) + " times.";
}

bool RunCountNetObserver::Start() {
void RunCountNetObserver::Start() {
const auto& operators = subject_->GetOperators();
for (auto* op : operators) {
op->AttachObserver(caffe2::make_unique<RunCountOperatorObserver>(op, this));
}
return true;
}

bool RunCountNetObserver::Stop() {
return true;
}
void RunCountNetObserver::Stop() {}

bool RunCountOperatorObserver::Start() {
void RunCountOperatorObserver::Start() {
++netObserver_->cnt_;
return true;
}
bool RunCountOperatorObserver::Stop() {
return true;
}
void RunCountOperatorObserver::Stop() {}

} // namespace caffe2
8 changes: 4 additions & 4 deletions caffe2/observers/runcnt_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class RunCountOperatorObserver final : public ObserverBase<OperatorBase> {
std::unique_ptr<ObserverBase<OperatorBase>> clone() override;

private:
bool Start() override;
bool Stop() override;
void Start() override;
void Stop() override;

private:
RunCountNetObserver* netObserver_;
Expand All @@ -34,8 +34,8 @@ class RunCountNetObserver final : public ObserverBase<NetBase> {
friend class RunCountOperatorObserver;

private:
bool Start() override;
bool Stop() override;
void Start() override;
void Stop() override;

protected:
std::atomic<int> cnt_;
Expand Down
12 changes: 4 additions & 8 deletions caffe2/observers/time_observer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,31 @@
namespace caffe2 {

template <>
bool TimeObserverBase<NetBase>::Start() {
void TimeObserverBase<NetBase>::Start() {
CAFFE_THROW(
"This function is overridden by TimeObserver<NetBase>.\
If it was called there is an issue with compilation.");
return false;
}

template <>
bool TimeObserverBase<NetBase>::Stop() {
void TimeObserverBase<NetBase>::Stop() {
double current_run = timer_.MilliSeconds() - start_time_;
total_time_ += current_run;
VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n";
return true;
}

template <>
bool TimeObserverBase<OperatorBase>::Start() {
void TimeObserverBase<OperatorBase>::Start() {
start_time_ = timer_.MilliSeconds();
++iterations_;
return true;
}

template <>
bool TimeObserverBase<OperatorBase>::Stop() {
void TimeObserverBase<OperatorBase>::Stop() {
double current_run = timer_.MilliSeconds() - start_time_;
total_time_ += current_run;
VLOG(1) << "This operator iteration took " << current_run
<< " ms to complete.\n";
return true;
}

} // namespace caffe2
7 changes: 3 additions & 4 deletions caffe2/observers/time_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class TimeObserverBase : public ObserverBase<T> {
}
~TimeObserverBase() {}

bool Start() override;
bool Stop() override;
void Start() override;
void Stop() override;

protected:
Timer timer_;
Expand Down Expand Up @@ -77,7 +77,7 @@ class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> {
return sum / subject_->GetOperators().size();
}

bool Start() override {
void Start() override {
for (auto* op : subject_->GetOperators()) {
const auto* observer = op->AttachObserver(
caffe2::make_unique<TimeObserver<OperatorBase>>(op));
Expand All @@ -87,7 +87,6 @@ class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> {
}
start_time_ = timer_.MilliSeconds();
++iterations_;
return true;
}

private:
Expand Down
19 changes: 10 additions & 9 deletions caffe2/operators/recurrent_network_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ class RecurrentNetworkExecutorBase {
void EnsureTimestepInitialized(
int t,
Workspace* ws,
std::map<
const ObserverBase<OperatorBase>*,
std::unique_ptr<ObserverBase<OperatorBase>>>& observers) {
const std::vector<std::unique_ptr<ObserverBase<OperatorBase>>>&
observers_list) {
if (timestep_ops_template_.size() == 0) {
// Firsrt invocation -- compute dependencies
CalculateInternalDependencies();
Expand Down Expand Up @@ -151,12 +150,13 @@ class RecurrentNetworkExecutorBase {
}

rnn_op.op = CreateOperator(op_copy, ws);
for (const auto& observer : observers) {
for (const auto& observer : observers_list) {
std::unique_ptr<ObserverBase<OperatorBase>> observer_copy =
observer.second->clone();
observer->clone();
CAFFE_ENFORCE(
observer_copy,
"Observers without clone() implemented cannot be attached to RNN using RNNExecutor.");
"Observers without clone() implemented cannot be attached "
"to RNN using RNNExecutor.");
rnn_op.op->AttachObserver(std::move(observer_copy));
}
} else {
Expand All @@ -170,12 +170,13 @@ class RecurrentNetworkExecutorBase {
// Otherwise, we need to create a brand new op with the workspace
// owned by this timestep.
rnn_op.op = CreateOperator(step_net_def_.op(rnn_op.order), ws);
for (const auto& observer : observers) {
for (const auto& observer : observers_list) {
std::unique_ptr<ObserverBase<OperatorBase>> observer_copy =
observer.second->clone();
observer->clone();
CAFFE_ENFORCE(
observer_copy,
"Observers without clone() implemented cannot be attached to RNN using RNNExecutor.");
"Observers without clone() implemented cannot be attached "
"to RNN using RNNExecutor.");
rnn_op.op->AttachObserver(std::move(observer_copy));
}
}
Expand Down
6 changes: 3 additions & 3 deletions caffe2/operators/recurrent_network_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class RecurrentNetworkOp final : public Operator<Context> {
}

size_t NumObservers() override {
size_t num = this->observers_.size();
size_t num = this->observers_list_.size();
if (rnnExecutor_) {
num += rnnExecutor_->NumObserversStepNet();
}
Expand Down Expand Up @@ -370,7 +370,7 @@ class RecurrentNetworkOp final : public Operator<Context> {
rnnExecutor_->SetMaxParallelTimesteps(num_workspaces_on_fwd_only);
}
rnnExecutor_->EnsureTimestepInitialized(
t, currentStepWorkspace.get(), this->observers_);
t, currentStepWorkspace.get(), this->observers_list_);
} else {
// Use plain Caffe2 nets
detail::UpdateTimestepBlob(currentStepWorkspace.get(), timestep_, t);
Expand Down Expand Up @@ -763,7 +763,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
for (int32_t t = seqLen - 1; t >= 0; --t) {
if (rnnExecutor_) {
rnnExecutor_->EnsureTimestepInitialized(
t, stepWorkspaces[t].get(), this->observers_);
t, stepWorkspaces[t].get(), this->observers_list_);
} else {
auto* stepNet = stepWorkspaces[t].get()->GetNet(stepNetDef_.name());
if (stepNet == nullptr) {
Expand Down
14 changes: 5 additions & 9 deletions caffe2/share/contrib/observers/perf_observer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ PerfNetObserver::PerfNetObserver(NetBase* subject_)

PerfNetObserver::~PerfNetObserver() {}

bool PerfNetObserver::Start() {
void PerfNetObserver::Start() {
static int visitCount = 0;
// Select whether to log the operator or the net.
// We have one sample rate for the entire app.
Expand Down Expand Up @@ -66,12 +66,11 @@ bool PerfNetObserver::Start() {
/* Only start timer when we need to */
timer_.Start();
}
return true;
}

bool PerfNetObserver::Stop() {
void PerfNetObserver::Stop() {
if (logType_ == PerfNetObserver::NONE) {
return true;
return;
}
auto currentRunTime = timer_.MilliSeconds();
std::map<std::string, double> delays;
Expand All @@ -93,7 +92,6 @@ bool PerfNetObserver::Stop() {
observerMap_.clear();
}
ObserverConfig::getReporter()->reportDelay(subject_, delays, "ms");
return true;
}

caffe2::string PerfNetObserver::getObserverName(const OperatorBase* op, int idx)
Expand Down Expand Up @@ -121,21 +119,19 @@ PerfOperatorObserver::PerfOperatorObserver(

PerfOperatorObserver::~PerfOperatorObserver() {}

bool PerfOperatorObserver::Start() {
void PerfOperatorObserver::Start() {
/* Get the time from the start of the net minus the time spent
in previous invocations. It is the time spent on other operators.
This way, when the operator finishes, the time from the start of the net
minus the time spent in all other operators is the total time on this
operator. This is done to avoid saving a timer in each operator */
milliseconds_ = netObserver_->getTimer().MilliSeconds() - milliseconds_;
return true;
}

bool PerfOperatorObserver::Stop() {
void PerfOperatorObserver::Stop() {
/* Time from the start of the net minus the time spent on all other
operators is the time spent on this operator */
milliseconds_ = netObserver_->getTimer().MilliSeconds() - milliseconds_;
return true;
}

double PerfOperatorObserver::getMilliseconds() const {
Expand Down
Loading

0 comments on commit 913a9a7

Please sign in to comment.