forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
106 changed files
with
4,163 additions
and
544 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow/c/eager/tape.h" | ||
|
||
namespace tensorflow { | ||
namespace eager { | ||
|
||
bool GradientTape::ShouldRecord(gtl::ArraySlice<int64> tensor_ids) { | ||
for (int64 i : tensor_ids) { | ||
if (tensor_tape_.find(i) != tensor_tape_.end()) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
void GradientTape::Watch(int64 tensor_id) { | ||
tensor_tape_.emplace(tensor_id, -1); | ||
} | ||
|
||
void GradientTape::RecordOperation( | ||
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors, | ||
gtl::ArraySlice<int64> input_tensor_id, void* backward_function, | ||
const std::function<void()>& backward_function_deleter) { | ||
if (!ShouldRecord(input_tensor_id)) { | ||
backward_function_deleter(); | ||
return; | ||
} | ||
std::vector<int64> ids; | ||
ids.reserve(input_tensor_id.size()); | ||
for (int64 i : input_tensor_id) { | ||
tensor_usage_[i]++; | ||
ids.push_back(i); | ||
} | ||
const int64 op_id = next_op_id_++; | ||
std::vector<TapeTensor> tensors; | ||
tensors.reserve(output_tensors.size()); | ||
for (const TapeTensor& o : output_tensors) { | ||
// Note: the tensor can have already been watched and hence be in the tape, | ||
// so we cannot check that we're inserting it here. | ||
tensor_tape_[o.id] = op_id; | ||
tensor_usage_[o.id] = 1; | ||
tensors.push_back(o); | ||
} | ||
op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, | ||
backward_function_deleter}; | ||
} | ||
|
||
void GradientTape::DeleteTrace(int64 tensor_id) { | ||
auto it = tensor_usage_.find(tensor_id); | ||
if (it == tensor_usage_.end()) { | ||
return; | ||
} | ||
it->second--; | ||
if (it->second != 0) { | ||
return; | ||
} | ||
tensor_usage_.erase(it); | ||
auto tensor_op_it = tensor_tape_.find(tensor_id); | ||
if (tensor_op_it == tensor_tape_.end()) { | ||
return; | ||
} | ||
const int64 op_id = tensor_op_it->second; | ||
if (op_id == -1) { | ||
// Do not delete watched tensors. | ||
return; | ||
} | ||
tensor_tape_.erase(tensor_op_it); | ||
auto op_it = op_tape_.find(op_id); | ||
CHECK(op_it != op_tape_.end()); | ||
for (const auto& output : op_it->second.output_tensor_info) { | ||
if (tensor_usage_.find(output.id) != tensor_usage_.end()) { | ||
// Found a usage for an output, so cannot delete the op. | ||
return; | ||
} | ||
} | ||
for (int64 id : op_it->second.input_tensor_id) { | ||
DeleteTrace(id); | ||
} | ||
op_it->second.backward_function_deleter(); | ||
op_tape_.erase(op_it); | ||
} | ||
|
||
std::pair<TensorTape, OpTape> GradientTape::Export() { | ||
return {std::move(tensor_tape_), std::move(op_tape_)}; | ||
} | ||
|
||
} // namespace eager | ||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#ifndef TENSORFLOW_C_EAGER_TAPE_H_ | ||
#define TENSORFLOW_C_EAGER_TAPE_H_ | ||
|
||
// Language-agnostic gradient tape. Does not perform backpropagation, just | ||
// maintains the data structures required to do so. | ||
|
||
#include <unordered_map> | ||
#include <vector> | ||
#include "tensorflow/core/framework/tensor_shape.h" | ||
#include "tensorflow/core/framework/types.h" | ||
#include "tensorflow/core/lib/gtl/array_slice.h" | ||
#include "tensorflow/core/platform/types.h" | ||
|
||
namespace tensorflow { | ||
namespace eager { | ||
|
||
// Information about a tensor. | ||
struct TapeTensor { | ||
int64 id; // Expected to be unique in the lifetime of this process. | ||
DataType dtype; | ||
TensorShape shape; | ||
}; | ||
|
||
// Represents an entry in the tape. | ||
struct OpTapeEntry { | ||
string op_type; | ||
std::vector<TapeTensor> output_tensor_info; | ||
std::vector<int64> input_tensor_id; | ||
|
||
// TODO(apassos) consider narrowing down this interface. | ||
void* backward_function; | ||
|
||
// Should be called before deleting the backward function. TODO(apassos) use | ||
// unique_ptrs to ensure this happens. | ||
std::function<void()> backward_function_deleter; | ||
}; | ||
|
||
// Map from tensor_id to internally-defined operation-id of the operation which | ||
// produced this tensor. A value of -1 means that the tensor was directly | ||
// watched and not the result of any operation in the tape. | ||
using TensorTape = std::unordered_map<int64, int64>; | ||
|
||
// Map from operation-id to tape entry. | ||
using OpTape = std::unordered_map<int64, OpTapeEntry>; | ||
|
||
// Traces the execution of operations, doing eager garbage collection, and | ||
// exporting a full trace so other code can do backpropagation. Not thread-safe. | ||
class GradientTape { | ||
public: | ||
GradientTape() {} | ||
|
||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids); | ||
|
||
void Watch(int64 tensor_id); | ||
|
||
void RecordOperation(const string& op_type, | ||
gtl::ArraySlice<TapeTensor> output_tensors, | ||
gtl::ArraySlice<int64> input_tensor_id, | ||
void* backward_function, | ||
const std::function<void()>& backward_function_deleter); | ||
|
||
void DeleteTrace(int64 tensor_id); | ||
|
||
// Note: it is only valid to call Export once per tape, and after calling | ||
// export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, | ||
// Record, and Delete have undefined behavior). | ||
std::pair<TensorTape, OpTape> Export(); | ||
|
||
private: | ||
TensorTape tensor_tape_; | ||
OpTape op_tape_; | ||
int64 next_op_id_{0}; | ||
|
||
// Map from tensor id to number of remaining usages (i.e. how many entries in | ||
// the tape refer to it); to aid in tape garbage collection. | ||
std::unordered_map<int64, int64> tensor_usage_; | ||
}; | ||
|
||
} // namespace eager | ||
} // namespace tensorflow | ||
|
||
#endif // TENSORFLOW_C_EAGER_TAPE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.