forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2f0e7e3
commit c43c838
Showing
5 changed files
with
225 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" | ||
|
||
namespace tensorflow { | ||
|
||
EagerOpRewriteRegistry* EagerOpRewriteRegistry::Global() { | ||
static EagerOpRewriteRegistry* global_rewrite_registry = | ||
new EagerOpRewriteRegistry; | ||
return global_rewrite_registry; | ||
} | ||
|
||
void EagerOpRewriteRegistry::Register(Phase phase, | ||
std::unique_ptr<EagerOpRewrite> pass) { | ||
if (rewrites_.find(phase) == rewrites_.end()) { | ||
rewrites_[phase] = std::move(pass); | ||
} else { | ||
TF_CHECK_OK(errors::AlreadyExists( | ||
"An EagerOpRewrite is already registerd for this phase: ", | ||
pass->name())); | ||
} | ||
} | ||
|
||
Status EagerOpRewriteRegistry::RunRewrite( | ||
Phase phase, EagerOperation* orig_op, | ||
std::unique_ptr<tensorflow::EagerOperation>& out_op) { | ||
auto rewrite = rewrites_.find(phase); | ||
if (rewrite != rewrites_.end()) { | ||
Status s = rewrite->second->Run(orig_op, out_op); | ||
if (!s.ok()) return s; | ||
} | ||
return Status::OK(); | ||
} | ||
|
||
} // namespace tensorflow |
95 changes: 95 additions & 0 deletions
95
tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ | ||
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ | ||
|
||
#include <map> | ||
#include <vector> | ||
#include "tensorflow/core/common_runtime/eager/eager_operation.h" | ||
#include "tensorflow/core/lib/core/status.h" | ||
|
||
namespace tensorflow { | ||
|
||
// Eager op rewrites should inherit from this class and | ||
// implement the Run method. | ||
class EagerOpRewrite { | ||
public: | ||
virtual ~EagerOpRewrite() {} | ||
|
||
// To be implemnted by an Eager op rewrite pass. | ||
virtual Status Run(EagerOperation* orig_op, | ||
std::unique_ptr<tensorflow::EagerOperation>& out_op) = 0; | ||
|
||
// Sets the name of the Eager op rewrite. | ||
void set_name(const string& name) { name_ = name; } | ||
|
||
// Returns the name of the Eager op rewrite. | ||
string name() const { return name_; } | ||
|
||
private: | ||
string name_; | ||
}; | ||
|
||
class EagerOpRewriteRegistry { | ||
public: | ||
// Phases at which the Eager op rewrite pass should run. | ||
// For now we only added PRE_EXECUTION. Expand as needed. | ||
enum Phase { | ||
PRE_EXECUTION // right before executing an eager op | ||
}; | ||
|
||
// Add a rewrite pass to the registry. | ||
// Only one rewrite pass is allowed per phase. | ||
void Register(Phase phase, std::unique_ptr<EagerOpRewrite> pass); | ||
|
||
// Run the rewrite pass registered for a given phase. | ||
Status RunRewrite(Phase phase, EagerOperation* orig_op, | ||
std::unique_ptr<tensorflow::EagerOperation>& out_op); | ||
|
||
// Returns the global registry of rewrite passes. | ||
static EagerOpRewriteRegistry* Global(); | ||
|
||
private: | ||
// Holds all the registered Eager op rewrites. | ||
std::map<Phase, std::unique_ptr<EagerOpRewrite>> rewrites_; | ||
}; | ||
|
||
namespace eager_rewrite_registration { | ||
|
||
// This class is used to register a new Eager Op rewrite. | ||
class EagerRewriteRegistration { | ||
public: | ||
EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase, | ||
std::unique_ptr<EagerOpRewrite> pass, | ||
string rewrite_pass_name) { | ||
pass->set_name(rewrite_pass_name); | ||
EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass)); | ||
} | ||
}; | ||
|
||
} // namespace eager_rewrite_registration | ||
|
||
#define REGISTER_REWRITE(phase, rewrite) \ | ||
REGISTER_REWRITE_UNIQ(__COUNTER__, phase, rewrite) | ||
|
||
#define REGISTER_REWRITE_UNIQ(ctr, phase, rewrite) \ | ||
static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \ | ||
register_rewrite_##ctr( \ | ||
phase, \ | ||
::std::unique_ptr<::tensorflow::EagerOpRewrite>(new rewrite()), \ | ||
#rewrite) | ||
|
||
} // namespace tensorflow | ||
#endif |
54 changes: 54 additions & 0 deletions
54
tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" | ||
#include "tensorflow/core/platform/test.h" | ||
|
||
namespace tensorflow { | ||
|
||
class TestEagerOpRewrite : public EagerOpRewrite { | ||
public: | ||
static int count_; | ||
Status Run(EagerOperation* orig_op, | ||
std::unique_ptr<tensorflow::EagerOperation>& out_op) override { | ||
++count_; | ||
const tensorflow::AttrTypeMap* types; | ||
bool is_function = false; | ||
string kNewOp = "NoOp"; | ||
TF_RETURN_IF_ERROR( | ||
tensorflow::AttrTypeMapForOp(kNewOp.c_str(), &types, &is_function)); | ||
// Create a new NoOp Eager operation. | ||
out_op.reset(new tensorflow::EagerOperation(nullptr, kNewOp.c_str(), | ||
is_function, types)); | ||
return Status::OK(); | ||
} | ||
}; | ||
|
||
int TestEagerOpRewrite::count_ = 0; | ||
|
||
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, TestEagerOpRewrite); | ||
|
||
TEST(EagerOpRewriteRegistryTest, RegisterRewritePass) { | ||
EXPECT_EQ(0, TestEagerOpRewrite::count_); | ||
EagerOperation* orig_op = nullptr; | ||
std::unique_ptr<tensorflow::EagerOperation> out_op; | ||
EXPECT_EQ(Status::OK(), | ||
EagerOpRewriteRegistry::Global()->RunRewrite( | ||
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, out_op)); | ||
EXPECT_EQ(1, TestEagerOpRewrite::count_); | ||
EXPECT_EQ("NoOp", out_op->Name()); | ||
} | ||
|
||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters