Skip to content

Commit

Permalink
Create experimental type inference lattice.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 378547892
Change-Id: Iff2dc8baa5ed2349885d46623a1320ec2618109f
  • Loading branch information
aselle authored and tensorflower-gardener committed Jun 10, 2021
1 parent 84262cc commit 9486821
Show file tree
Hide file tree
Showing 11 changed files with 512 additions and 22 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,7 @@ tf_cuda_library(
"//tensorflow/core/framework:shape_inference",
"//tensorflow/core/framework:tensor",
"//tensorflow/core/framework:tensor_shape",
"//tensorflow/core/framework/experimental:type_inference",
"//tensorflow/core/platform:env_impl",
"//tensorflow/core/platform/default/build_config:platformlib",
"//tensorflow/core/profiler/lib:annotated_traceme",
Expand Down
12 changes: 10 additions & 2 deletions tensorflow/core/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ filegroup(
"variant_op_registry.h",
"variant_tensor_data.h",
"versions.h",
"//tensorflow/core/framework/experimental:type_inference.h",
"//tensorflow/core/framework/registration:options.h",
"//tensorflow/core/framework/registration:registration.h",
],
Expand Down Expand Up @@ -429,6 +430,7 @@ filegroup(
"thread_factory.h",
"versions.cc",
"versions.h",
"//tensorflow/core/framework/experimental:type_inference.h",
"//tensorflow/core/framework/registration:options.h",
"//tensorflow/core/framework/registration:registration.h",
],
Expand Down Expand Up @@ -987,11 +989,16 @@ cc_library(

cc_library(
name = "full_type_util",
srcs = ["full_type_util.cc"],
hdrs = ["full_type_util.h"],
srcs = [
"full_type_util.cc",
],
hdrs = [
"full_type_util.h",
],
deps = [
":full_type_proto_cc",
":op_def_builder",
"@com_google_absl//absl/container:flat_hash_map",
],
)

Expand All @@ -1004,6 +1011,7 @@ cc_library(
":full_type_util",
":op_def_builder",
":op_def_util",
# "//tensorflow/core/framework/experimental:type_inference",
"//tensorflow/core/framework/registration",
"//tensorflow/core/lib/core:errors",
"//tensorflow/core/lib/core:status",
Expand Down
64 changes: 64 additions & 0 deletions tensorflow/core/framework/experimental/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Experimental features in the TF framework.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
load(
"//tensorflow/core/platform:rules_cc.bzl",
"cc_library",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_kernel_tests_linkstatic",
)

package(
default_visibility = [
"//tensorflow/core:__subpackages__",
],
licenses = ["notice"],
)

# Experimental type-promotion ("lattice") library. Lives under
# framework/experimental so it can evolve without touching the core
# framework target.
cc_library(
    name = "type_inference",
    srcs = [
        "type_inference.cc",
    ],
    hdrs = ["type_inference.h"],
    deps = [
        "//tensorflow/core/framework:full_type_proto_cc",
        "//tensorflow/core/framework:full_type_util",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

tf_cc_test(
    name = "type_inference_test",
    srcs = [
        # NOTE(review): the header is compiled into the test directly instead
        # of depending on :type_inference — presumably because the library is
        # also folded into //tensorflow/core:framework below; confirm this is
        # intentional and does not cause ODR issues.
        "type_inference.h",
        "type_inference_test.cc",
    ],
    linkstatic = tf_kernel_tests_linkstatic(),
    visibility = [
        "//tensorflow:internal",
        "//tensorflow/core:__pkg__",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

# Allows //tensorflow/core/framework to include the header in its filegroups
# (see the framework_headers/mobile filegroup additions in that package).
exports_files(
    srcs = [
        "type_inference.h",
    ],
    visibility = [
        "//tensorflow/core/framework:__pkg__",
    ],
)
210 changes: 210 additions & 0 deletions tensorflow/core/framework/experimental/type_inference.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/experimental/type_inference.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"

namespace tensorflow {

namespace full_type {

// Adjacency map from a type to the (sorted) list of types it can be
// promoted to.
using Lattice = absl::flat_hash_map<Type, std::vector<Type>>;
// Forward declarations; definitions below.
Lattice MakeLattice();
Lattice MakeLatticeClosure(const Lattice& lattice);

// Returns the compact numpy-style name for `type`, e.g. TFT_FLOAT -> "f4".
// Unknown types map to "!!".
std::string ShortName(Type type) {
  struct Entry {
    Type id;
    const char* short_name;
  };
  // Table of every type that has a short name; scanned linearly (the table
  // is tiny, so a map buys nothing).
  static constexpr Entry kTable[] = {
      {TFT_BOOL, "b"},          {TFT_UINT8, "u1"},
      {TFT_UINT16, "u2"},       {TFT_UINT32, "u4"},
      {TFT_UINT64, "u8"},       {TFT_INT8, "i1"},
      {TFT_INT16, "i2"},        {TFT_INT32, "i4"},
      {TFT_INT64, "i8"},        {TFT_FLOAT, "f4"},
      {TFT_HALF, "f2"},         {TFT_DOUBLE, "f8"},
      {TFT_COMPLEX64, "c4"},    {TFT_COMPLEX128, "c8"},
      {TFT_COMPLEX_WEAK, "c*"}, {TFT_FLOAT_WEAK, "f*"},
      {TFT_INT_WEAK, "i*"},     {TFT_BOOL_WEAK, "b*"},
      {TFT_BFLOAT16, "bf"},
  };
  for (const Entry& entry : kTable) {
    if (entry.id == type) return entry.short_name;
  }
  return "!!";
}

// Maps each "weak" type id to its canonical concrete type; any other type
// is already canonical and returned unchanged.
Type Canonical(Type t) {
  if (t == TFT_COMPLEX_WEAK) return TFT_COMPLEX64;
  if (t == TFT_FLOAT_WEAK) return TFT_FLOAT;
  if (t == TFT_INT_WEAK) return TFT_INT32;
  if (t == TFT_BOOL_WEAK) return TFT_BOOL;
  return t;
}

// Returns the enum-style name of `type`, e.g. "TFT_FLOAT". The weak ids and
// TFT_BFLOAT16 are special-cased here (presumably not resolvable through the
// FullTypeId descriptor — confirm); anything else unknown yields
// "__ERROR_UNKNOWN__".
std::string Name(Type type) {
  switch (type) {
    case TFT_COMPLEX_WEAK:
      return "TFT_COMPLEX_WEAK";
    case TFT_FLOAT_WEAK:
      return "TFT_FLOAT_WEAK";
    case TFT_INT_WEAK:
      return "TFT_INT_WEAK";
    case TFT_BOOL_WEAK:
      return "TFT_BOOL_WEAK";
    case TFT_BFLOAT16:
      return "TFT_BFLOAT16";
  }
  // Fall back to the proto enum descriptor for real FullTypeId values.
  auto* descriptor = FullTypeId_descriptor();
  // Explicit cast: FindValueByNumber takes int; Type is size_t, and the
  // implicit narrowing conversion would warn (or misbehave for huge ids).
  if (const auto* value =
          descriptor->FindValueByNumber(static_cast<int>(type))) {
    return value->name();
  }
  return "__ERROR_UNKNOWN__";
}

// Builds the one-step promotion graph: each type maps to the types it may be
// promoted to directly. The transitive closure of this graph (see
// MakeLatticeClosure) defines the full promotion relation.
Lattice MakeLattice() {
  Lattice types;
  // Booleans feed into the integer tower via the weak int type.
  types[TFT_BOOL_WEAK] = {TFT_BOOL};
  types[TFT_BOOL] = {TFT_INT_WEAK};
  types[TFT_INT_WEAK] = {TFT_INT8, TFT_UINT8};
  // Weak floats may become any concrete small float, or continue to complex.
  types[TFT_FLOAT_WEAK] = {TFT_HALF, TFT_BFLOAT16, TFT_COMPLEX_WEAK};
  types[TFT_BFLOAT16] = {TFT_FLOAT};
  types[TFT_HALF] = {TFT_FLOAT};
  types[TFT_COMPLEX_WEAK] = {TFT_COMPLEX64};
  types[TFT_COMPLEX64] = {TFT_COMPLEX128};
  types[TFT_FLOAT] = {TFT_DOUBLE, TFT_COMPLEX64};
  // Signed ints widen; unsigned ints widen or step to the next signed type.
  types[TFT_INT8] = {TFT_INT16};
  types[TFT_INT16] = {TFT_INT32};
  types[TFT_INT32] = {TFT_INT64};
  types[TFT_UINT8] = {TFT_INT16, TFT_UINT16};
  types[TFT_UINT16] = {TFT_INT32, TFT_UINT32};
  types[TFT_UINT32] = {TFT_INT64, TFT_UINT64};
  // 64-bit integers only promote onward to (weak) floating point.
  types[TFT_UINT64] = {TFT_FLOAT_WEAK};
  types[TFT_INT64] = {TFT_FLOAT_WEAK};
  types[TFT_DOUBLE] = {TFT_COMPLEX128};
  // TFT_COMPLEX128 has no successors: it is the top of the lattice.
  types[TFT_COMPLEX128] = {};
  // Sorted successor lists let later code use sorted-range algorithms
  // (e.g. std::set_intersection in ReturnType).
  for (auto& it : types) std::sort(it.second.begin(), it.second.end());
  return types;
}

// Expands `lattice` to its transitive closure: each type maps to the sorted
// list of every type reachable from it, including itself. CHECK-fails on a
// cyclic input graph.
Lattice MakeLatticeClosure(const Lattice& lattice) {
  using Set = std::set<Type>;
  Lattice result;
  for (auto& l : lattice) {
    auto type = l.first;
    Set current;
    current.insert(type);

    // Fixed-point iteration: repeatedly add direct successors of everything
    // reached so far until the set stops growing.
    for (;;) {
      Set additions;
      for (const auto& i : current) {
        // Every reachable type must itself be a lattice key.
        const auto& lat = lattice.find(i)->second;
        additions.insert(lat.begin(), lat.end());
      }
      // Check for cycles, crash since the lattice is static data.
      CHECK(additions.find(l.first) == additions.end()); // Crash OK
      // Check if we actually got any new types.
      size_t old_length = current.size();
      current.insert(additions.begin(), additions.end());
      if (old_length == current.size()) break;
    }
    // std::set iterates in sorted order, so the resulting vector is sorted.
    result[type] = std::vector<Type>(current.begin(), current.end());
  }
  return result;
}

// Returns the transitive closure of the promotion lattice, built on first
// use. The heap allocation is intentionally leaked to sidestep static
// destruction-order problems at process exit.
Lattice& LatticeSingleton() {
  static Lattice* const closure =
      new Lattice(MakeLatticeClosure(MakeLattice()));
  return *closure;
}

// Returns the canonical promotion result of `t1` and `t2`: the unique least
// type reachable from both in the closure lattice. Returns TFT_ANY if either
// type is not in the lattice, if the promotion is ambiguous, or if no common
// supremum exists.
Type ReturnType(Type t1, Type t2) {
  auto& closure_lattice = LatticeSingleton();
  auto it1 = closure_lattice.find(t1);
  auto it2 = closure_lattice.find(t2);
  // Check if both types are supported by promotion lattices
  if (it1 == closure_lattice.end() || it2 == closure_lattice.end()) {
    return TFT_ANY;  // TODO(aselle): mdan, do we need an error type?
  }
  // Both reachable-sets are sorted vectors, so set_intersection applies.
  std::vector<Type> t1_t2_reachable;
  std::set_intersection(it1->second.begin(), it1->second.end(),
                        it2->second.begin(), it2->second.end(),
                        std::back_inserter(t1_t2_reachable));
  constexpr Type kNotFound = std::numeric_limits<Type>::max();
  Type final_type = kNotFound;
  // The least common type is the member of the intersection whose own
  // reachable set equals the whole intersection.
  for (auto t : t1_t2_reachable) {
    // This must exist, by construction.
    auto t_reachable_it = closure_lattice.find(t);
    if (t_reachable_it->second == t1_t2_reachable) {
      if (final_type != kNotFound) {
        LOG(ERROR) << "Ambiguous promotion type.";
        return TFT_ANY;
      }
      final_type = t;
    }
  }
  // Bug fix: previously fell through to Canonical(NOT_FOUND) here, returning
  // numeric_limits<Type>::max() as a bogus type id when no supremum existed.
  if (final_type == kNotFound) {
    LOG(ERROR) << "No common promotion type.";
    return TFT_ANY;
  }
  return Canonical(final_type);
}

// Promotes two tensor FullTypeDefs: returns TFT_TENSOR[promoted dtype] when
// both inputs are TFT_TENSOR with a dtype argument, and TFT_ANY otherwise.
FullTypeDef ReturnType(FullTypeDef t1, FullTypeDef t2) {
  auto ret = FullTypeDef();
  // Bug fix: the original condition used &&, so a tensor/non-tensor mix fell
  // into the else-branch and read args()[0] of the non-tensor operand, which
  // may have an empty args list. Require both operands to be tensors that
  // actually carry a dtype argument.
  if (t1.type_id() != TFT_TENSOR || t2.type_id() != TFT_TENSOR ||
      t1.args_size() < 1 || t2.args_size() < 1) {
    ret.set_type_id(TFT_ANY);
    return ret;
  }
  ret.set_type_id(TFT_TENSOR);
  auto* arg = ret.add_args();
  const auto id1 = t1.args(0).type_id();
  const auto id2 = t2.args(0).type_id();
  arg->set_type_id(static_cast<FullTypeId>(
      ReturnType(static_cast<Type>(id1), static_cast<Type>(id2))));
  return ret;
}

} // namespace full_type

} // namespace tensorflow
48 changes: 48 additions & 0 deletions tensorflow/core/framework/experimental/type_inference.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Fix: guard now matches the file's actual path (framework/experimental/...),
// avoiding collisions with a future framework/type_inference.h.
#ifndef TENSORFLOW_CORE_FRAMEWORK_EXPERIMENTAL_TYPE_INFERENCE_H_
#define TENSORFLOW_CORE_FRAMEWORK_EXPERIMENTAL_TYPE_INFERENCE_H_

#include <cstddef>
#include <string>
#include <unordered_map>

#include "tensorflow/core/framework/full_type.pb.h"

namespace tensorflow {

namespace full_type {

// A raw integer type for testing.
using Type = size_t;
// A short name for numeric types i.e. tf.float32 --> f4.
std::string ShortName(Type type);
// A long type name.
std::string Name(Type type);
// For testing. This allows passing in extra types that don't exist in FT.
Type ReturnType(Type t1, Type t2);
// Check what type `t1` and `t2` are promotable to, and return it.
FullTypeDef ReturnType(FullTypeDef t1, FullTypeDef t2);

// TODO(aselle): These shouldn't be necessary in the long run.
// Extra "weak" type ids used by the promotion lattice. Values start at 16000,
// presumably to stay clear of real FullTypeId values — confirm against the
// proto definition.
enum EXTRA_TYPES {
  TFT_BOOL_WEAK = 16000,
  TFT_FLOAT_WEAK = 16001,
  TFT_INT_WEAK = 16002,
  TFT_COMPLEX_WEAK = 16003,
};

}  // namespace full_type
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_EXPERIMENTAL_TYPE_INFERENCE_H_
Loading

0 comments on commit 9486821

Please sign in to comment.