Skip to content

Commit

Permalink
Adding support for Eager op rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
mahmoud-abuzaina committed Jun 13, 2019
1 parent 2f0e7e3 commit c43c838
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 6 deletions.
18 changes: 18 additions & 0 deletions tensorflow/core/common_runtime/eager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ cc_library(
":context",
":copy_to_device_node",
":eager_executor",
":eager_op_rewrite_registry",
":eager_operation",
":kernel_and_device",
":tensor_handle",
Expand All @@ -262,6 +263,23 @@ cc_library(
}),
)

cc_library(
name = "eager_op_rewrite_registry",
srcs = ["eager_op_rewrite_registry.cc"],
hdrs = ["eager_op_rewrite_registry.h"],
deps = [":eager_operation"],
)

tf_cc_test(
name = "eager_op_rewrite_registry_test",
srcs = ["eager_op_rewrite_registry_test.cc"],
deps = [
":eager_op_rewrite_registry",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

tf_cuda_library(
name = "attr_builder",
srcs = ["attr_builder.cc"],
Expand Down
47 changes: 47 additions & 0 deletions tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"

namespace tensorflow {

EagerOpRewriteRegistry* EagerOpRewriteRegistry::Global() {
static EagerOpRewriteRegistry* global_rewrite_registry =
new EagerOpRewriteRegistry;
return global_rewrite_registry;
}

void EagerOpRewriteRegistry::Register(Phase phase,
std::unique_ptr<EagerOpRewrite> pass) {
if (rewrites_.find(phase) == rewrites_.end()) {
rewrites_[phase] = std::move(pass);
} else {
TF_CHECK_OK(errors::AlreadyExists(
"An EagerOpRewrite is already registerd for this phase: ",
pass->name()));
}
}

Status EagerOpRewriteRegistry::RunRewrite(
Phase phase, EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>& out_op) {
auto rewrite = rewrites_.find(phase);
if (rewrite != rewrites_.end()) {
Status s = rewrite->second->Run(orig_op, out_op);
if (!s.ok()) return s;
}
return Status::OK();
}

} // namespace tensorflow
95 changes: 95 additions & 0 deletions tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_

#include <map>
#include <vector>
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Eager op rewrites should inherit from this class and
// implement the Run method.
class EagerOpRewrite {
public:
virtual ~EagerOpRewrite() {}

// To be implemnted by an Eager op rewrite pass.
virtual Status Run(EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>& out_op) = 0;

// Sets the name of the Eager op rewrite.
void set_name(const string& name) { name_ = name; }

// Returns the name of the Eager op rewrite.
string name() const { return name_; }

private:
string name_;
};

class EagerOpRewriteRegistry {
public:
// Phases at which the Eager op rewrite pass should run.
// For now we only added PRE_EXECUTION. Expand as needed.
enum Phase {
PRE_EXECUTION // right before executing an eager op
};

// Add a rewrite pass to the registry.
// Only one rewrite pass is allowed per phase.
void Register(Phase phase, std::unique_ptr<EagerOpRewrite> pass);

// Run the rewrite pass registered for a given phase.
Status RunRewrite(Phase phase, EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>& out_op);

// Returns the global registry of rewrite passes.
static EagerOpRewriteRegistry* Global();

private:
// Holds all the registered Eager op rewrites.
std::map<Phase, std::unique_ptr<EagerOpRewrite>> rewrites_;
};

namespace eager_rewrite_registration {

// This class is used to register a new Eager Op rewrite.
class EagerRewriteRegistration {
public:
EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,
std::unique_ptr<EagerOpRewrite> pass,
string rewrite_pass_name) {
pass->set_name(rewrite_pass_name);
EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass));
}
};

} // namespace eager_rewrite_registration

#define REGISTER_REWRITE(phase, rewrite) \
REGISTER_REWRITE_UNIQ(__COUNTER__, phase, rewrite)

#define REGISTER_REWRITE_UNIQ(ctr, phase, rewrite) \
static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \
register_rewrite_##ctr( \
phase, \
::std::unique_ptr<::tensorflow::EagerOpRewrite>(new rewrite()), \
#rewrite)

} // namespace tensorflow
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

class TestEagerOpRewrite : public EagerOpRewrite {
public:
static int count_;
Status Run(EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>& out_op) override {
++count_;
const tensorflow::AttrTypeMap* types;
bool is_function = false;
string kNewOp = "NoOp";
TF_RETURN_IF_ERROR(
tensorflow::AttrTypeMapForOp(kNewOp.c_str(), &types, &is_function));
// Create a new NoOp Eager operation.
out_op.reset(new tensorflow::EagerOperation(nullptr, kNewOp.c_str(),
is_function, types));
return Status::OK();
}
};

int TestEagerOpRewrite::count_ = 0;

REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, TestEagerOpRewrite);

TEST(EagerOpRewriteRegistryTest, RegisterRewritePass) {
EXPECT_EQ(0, TestEagerOpRewrite::count_);
EagerOperation* orig_op = nullptr;
std::unique_ptr<tensorflow::EagerOperation> out_op;
EXPECT_EQ(Status::OK(),
EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, out_op));
EXPECT_EQ(1, TestEagerOpRewrite::count_);
EXPECT_EQ("NoOp", out_op->Name());
}

} // namespace tensorflow
17 changes: 11 additions & 6 deletions tensorflow/core/common_runtime/eager/execute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"

namespace tensorflow {

Expand Down Expand Up @@ -885,11 +886,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// device directly.
bool IsPinnableOp(const string& op_type) {
static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
"RandomUniform",
"RandomUniformInt",
"RandomStandardNormal",
"StatelessRandomUniform",
"StatelessRandomUniformInt",
"RandomUniform", "RandomUniformInt", "RandomStandardNormal",
"StatelessRandomUniform", "StatelessRandomUniformInt",
"StatelessRandomNormal",
});

Expand Down Expand Up @@ -999,7 +997,14 @@ Status EagerExecute(EagerOperation* op,
bool op_is_local = op->EagerContext()->IsLocal(op->Device());

if (op_is_local) {
return EagerLocalExecute(op, retvals, num_retvals);
std::unique_ptr<tensorflow::EagerOperation> out_op;
TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, op, out_op));
if (out_op) {
return EagerLocalExecute(out_op.get(), retvals, num_retvals);
} else {
return EagerLocalExecute(op, retvals, num_retvals);
}
}

if (op->EagerContext()->LogDevicePlacement() || VLOG_IS_ON(1)) {
Expand Down

0 comments on commit c43c838

Please sign in to comment.