Skip to content

Commit

Permalink
Introduce backend extensions (overriding operators on custom backends)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch/pytorch#15153

Reviewed By: gchanan

Differential Revision: D13445571

fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708
  • Loading branch information
Roy Li authored and facebook-github-bot committed Feb 1, 2019
1 parent 64186e0 commit 7e642df
Show file tree
Hide file tree
Showing 17 changed files with 317 additions and 13 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ enum class TypeID {
SparseCUDAInt,
SparseCUDALong,
SparseCUDAShort,
MSNPU,
CPUComplexFloat,
CPUComplexDouble,
CUDAComplexFloat,
Expand Down
35 changes: 35 additions & 0 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ def TypedDict(name, attrs, total=True): # type: ignore
${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals});
""")

# Overrideable stubs to be used in user-extendable backends
# Template for a generated Type method body on an extension backend: the stub
# looks up the user-registered function pointer for this operator's schema in
# the backend's dispatch table and forwards the call to it.  The cast target
# type is rebuilt from the operator's formal types, so registration and call
# must agree on the signature.
TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals});
}
""")

# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
Expand Down Expand Up @@ -489,6 +496,7 @@ def __getitem__(self, x):
'formals_list': List[AtFormal],
'formals_with_defaults': List[str],
'formals': List[str],
'formals_types': List[str],
'inferred_type': str,
'inplace': bool,
'matches_jit_signature': bool,
Expand All @@ -513,6 +521,8 @@ def __getitem__(self, x):
'return': ReturnDecl,
'returns': List[ReturnType],
'scalar_check': str,
# schema used for extension backend operator registration
'schema': str,
'sparse': bool,
'type_definition_body': List[str],
'type_method_actuals': List[str],
Expand Down Expand Up @@ -1595,3 +1605,28 @@ def process_native(option):
except NYIError:
pass
return type_object_declarations, type_object_definitions


def create_extension_backend(backend_type_env, declarations):
    # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
    """Generate Type method declarations and definitions for an extension backend.

    For every non-skipped option this records a schema string on the option
    (used later for operator registration) and emits a stub definition that
    dispatches through the backend's registered function table.  Options that
    raise NYIError are silently omitted, mirroring the other codegen paths.

    Returns a pair (declarations, definitions) of generated code strings.
    """
    method_declarations = []
    method_definitions = []

    for declaration in declarations:
        for option in declaration['options']:
            if option.get('skip', False):
                continue
            try:
                formals = option['formals_list']
                option['formals_types'] = [formal['type'] for formal in formals]
                option['native_actuals'] = [formal['name'] for formal in formals]
                arg_list = ", ".join(
                    "{} {}".format(formal['dynamic_type'], formal['name']) for formal in formals)
                ret = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
                # Schema used for extension backend operator registration.
                option['schema'] = "{}({}) -> {}".format(option['api_name'], arg_list, ret)
                env = nested_dict(option, backend_type_env)
                method_declarations.append(TYPE_DERIVED_DECLARATION.substitute(env))
                method_definitions.append(TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
            except NYIError:
                pass
    return method_declarations, method_definitions
53 changes: 52 additions & 1 deletion aten/src/ATen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def check_all_files_written(self):
TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h")
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
TYPE_EXTENSION_BACKEND_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.h")
TYPE_EXTENSION_BACKEND_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.cpp")

LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h")
LEGACY_TH_DISPATCHER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.cpp")
Expand All @@ -141,10 +143,18 @@ def check_all_files_written(self):

NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h")

EXTENSION_BACKEND_REGISTRATION_H = CodeTemplate.from_file(TEMPLATE_PATH + "/ExtensionBackendRegistration.h")

TYPE_REGISTER = CodeTemplate("""\
context->registerType(Backend::${backend}, ScalarType::${scalar_type}, new ${type_name}());
""")

EXTENSION_BACKEND_REGISTER_SWITCH = CodeTemplate("""\
case Backend::${Backend}:
${Type}Dispatch::register_function(schema, fn);
break;
""")

core_file_manager = FileManager(core_install_dir)
file_manager = FileManager()
cuda_file_manager = FileManager()
Expand All @@ -164,6 +174,7 @@ def check_all_files_written(self):

backends = ['CPU', 'CUDA']
densities = ['Dense', 'Sparse']
extension_backends = ['MSNPU']

# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
Expand Down Expand Up @@ -193,6 +204,8 @@ def check_all_files_written(self):
'function_definitions': [],
'type_ids': [],
'native_function_declarations': [],
'extension_backend_headers': [],
'extension_backend_register_switches': [],
}


Expand Down Expand Up @@ -347,6 +360,37 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
return env


def generate_type_extension_backend(backend, declarations):
    """Emit the generated Type sources for one extension backend (e.g. MSNPU)
    and record its registration entries in the shared top_env.

    Writes <Backend>Type.h / <Backend>Type.cpp via the file manager and
    returns the environment dict used to instantiate the templates.
    """
    type_name = "{}Type".format(backend)
    env = {
        'Type': type_name,
        'Backend': backend,
        'DeviceType': backend,
        'is_extension_backend': True,
        'TypeID': 'TypeID::' + backend,
    }
    top_env['type_ids'].append(backend + ',')

    # Generate the overrideable operator stubs that dispatch through the
    # backend's registered function table.
    method_declarations, method_definitions = function_wrapper.create_extension_backend(
        env, declarations)
    env['type_method_declarations'] = method_declarations
    env['type_method_definitions'] = method_definitions

    file_manager.write(type_name + ".cpp", TYPE_EXTENSION_BACKEND_CPP, env)
    file_manager.write(type_name + ".h", TYPE_EXTENSION_BACKEND_H, env)

    # Register this Type for every scalar type; extension backends share the
    # CPU registration list and headers.
    for scalar_name, _, _, _, _ in scalar_types:
        top_env['cpu_type_registrations'].append(
            TYPE_REGISTER.substitute(backend=backend, scalar_type=scalar_name, type_name=type_name))
    top_env['extension_backend_register_switches'].append(
        EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env))
    top_env['extension_backend_headers'].append(
        '#include <ATen/{}.h>'.format(type_name))
    top_env['cpu_type_headers'].append(
        '#include "ATen/{}.h"'.format(type_name))

    return env


def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations):
assert density != 'Sparse'
scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
Expand Down Expand Up @@ -384,7 +428,7 @@ def declare_outputs():
core_file_manager.will_write(f)
files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h',
'LegacyTHDispatcher.h', 'LegacyTHDispatcher.cpp', 'LegacyTHFunctions.h',
'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h']
'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h', 'ExtensionBackendRegistration.h']
for f in files:
file_manager.will_write(f)
cuda_files = ['RegisterCUDA.cpp', 'RegisterCUDA.h']
Expand All @@ -411,6 +455,9 @@ def declare_outputs():
if density != 'Sparse':
fm.will_write("{}{}{}{}.h".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
fm.will_write("{}{}{}{}.cpp".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
for backend in extension_backends:
file_manager.will_write("{}Type.h".format(backend))
file_manager.will_write("{}Type.cpp".format(backend))


def filter_by_extension(files, *extensions):
Expand Down Expand Up @@ -472,6 +519,8 @@ def generate_outputs():
for backend, density, scalar_type in iterate_types():
all_types.append(generate_storage_type_and_tensor(
backend, density, scalar_type, declarations))
for backend in extension_backends:
all_types.append(generate_type_extension_backend(backend, declarations))

all_legacy_th_dispatchers = []
for backend, density, scalar_type in iterate_types():
Expand Down Expand Up @@ -506,6 +555,8 @@ def generate_outputs():

file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env)

file_manager.write('ExtensionBackendRegistration.h', EXTENSION_BACKEND_REGISTRATION_H, top_env)

file_manager.check_all_files_written()
cuda_file_manager.check_all_files_written()

Expand Down
19 changes: 19 additions & 0 deletions aten/src/ATen/templates/ExtensionBackendRegistration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once
#include <ATen/Backend.h>
${extension_backend_headers}

namespace at {

// Register `fn` as the implementation of the operator identified by `schema`
// on the given extension backend.  The switch is generated with one case per
// extension backend; each case forwards to that backend's
// <Backend>TypeDispatch::register_function.  Errors out for any backend that
// is not an extension backend.
template <typename FnPtr>
inline void register_extension_backend_op(
    Backend backend,
    const char * schema,
    FnPtr fn) {
  switch (backend) {
    ${extension_backend_register_switches}
    default:
      AT_ERROR("Invalid extension backend: ", toString(backend));
  }
}

} // namespace at
51 changes: 51 additions & 0 deletions aten/src/ATen/templates/TypeExtension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include <ATen/${Type}.h>

namespace at {

// Function-local static so the table is initialized on first use and a single
// instance is shared by all callers.
std::unordered_map<std::string, void *>& ${Type}Dispatch::get_fn_table() {
  static std::unordered_map<std::string, void *> fn_table;
  return fn_table;
}

${Type}::${Type}()
    : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}

// Allocation is backend-specific; registered operator overrides are expected
// to construct their own tensors, so no allocator is provided here.
Allocator* ${Type}::allocator() const {
  AT_ERROR("allocator is not implemented for ${Type}");
}

Device ${Type}::getDeviceFromPtr(void * data) const {
  // NOTE(review): the data pointer is ignored; every pointer is reported as
  // belonging to the backend's default device.
  return DeviceType::${DeviceType};
}

std::unique_ptr<Generator> ${Type}::generator() const {
  AT_ERROR("generator is not implemented for ${Type}");
}

// This Type is registered for every scalar type, so it has no single
// scalar type / type meta / element size to report.
ScalarType ${Type}::scalarType() const {
  AT_ERROR("scalarType is not implemented for ${Type}");
}

caffe2::TypeMeta ${Type}::typeMeta() const {
  AT_ERROR("typeMeta is not implemented for ${Type}");
}

Backend ${Type}::backend() const {
  return Backend::${Backend};
}

const char * ${Type}::toString() const {
  return "${Type}";
}

TypeID ${Type}::ID() const {
  return ${TypeID};
}

size_t ${Type}::elementSizeInBytes() const {
  AT_ERROR("elementSizeInBytes is not implemented for ${Type}");
}

// Generated operator stubs that dispatch through ${Type}Dispatch.
${type_method_definitions}

} // namespace at
49 changes: 49 additions & 0 deletions aten/src/ATen/templates/TypeExtension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once
#include <ATen/TypeDefault.h>

namespace at {

// This dispatch class holds static map in which function pointers are
// registered by schema.
// TODO: Check for invalid schemas prior to registration.
struct CAFFE2_API ${Type}Dispatch {
  // Look up the function registered for `schema` and cast it back to the
  // caller-supplied function-pointer type.  The cast is unchecked: callers
  // must request the same signature that was registered.
  template<typename FnPtr>
  static FnPtr get_function(const std::string& schema) {
    auto & fn_table = get_fn_table();
    auto it = fn_table.find(schema);
    if (it != fn_table.end()) {
      return reinterpret_cast<FnPtr>(it->second);
    }
    AT_ERROR("No function registered for schema: ", schema);
  }

  // Register `fn` under `schema`; at most one function per schema, and
  // re-registration is an error.
  template<typename FnPtr>
  static void register_function(const std::string& schema, FnPtr fn) {
    auto & fn_table = get_fn_table();
    if (fn_table.find(schema) != fn_table.end()) {
      AT_ERROR("Function already registered for schema: ", schema);
    }
    fn_table[schema] = reinterpret_cast<void *>(fn);
  }

  // Schema string -> type-erased function pointer (defined in the .cpp).
  static std::unordered_map<std::string, void *>& get_fn_table();
};

// Type implementation for an extension backend; its generated operator
// methods dispatch through ${Type}Dispatch to user-registered functions.
struct CAFFE2_API ${Type} : public TypeDefault {
  explicit ${Type}();

  Allocator* allocator() const override;
  Device getDeviceFromPtr(void * data) const override;
  std::unique_ptr<Generator> generator() const override;

  virtual ScalarType scalarType() const override;
  virtual caffe2::TypeMeta typeMeta() const override;
  virtual Backend backend() const override;
  virtual const char * toString() const override;
  virtual size_t elementSizeInBytes() const override;
  virtual TypeID ID() const override;

  ${type_method_declarations}
};

} // namespace at
3 changes: 2 additions & 1 deletion aten/src/ATen/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
Expand Down
66 changes: 66 additions & 0 deletions aten/src/ATen/test/extension_backend_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ExtensionBackendRegistration.h>

using namespace at;

// Records which override ran most recently so the test can verify that
// dispatch actually reached the registered function.
static int test_int;

// Minimal `empty` override: builds a float MSNPU tensor impl backed by a
// null data pointer — no real storage is needed for these dispatch tests.
Tensor empty_override(IntList size, const TensorOptions & options) {
  test_int = 1;
  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      Storage(
          caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
      MSNPUTensorId(),
      false);
  return Tensor(std::move(tensor_impl));
}

Tensor empty_like_override(const Tensor & self, const TensorOptions & options) {
  test_int = 2;
  return self;
}

Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  test_int = 3;
  return a;
}

TEST(BackendExtensionTest, TestRegisterOp) {
  // Calling an MSNPU op before anything is registered must throw.
  EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
  register_extension_backend_op(
      Backend::MSNPU,
      "empty(IntList size, TensorOptions options) -> Tensor", &empty_override);
  Tensor a = empty({5, 5}, at::kMSNPU);
  ASSERT_EQ(a.device().type(), at::kMSNPU);
  ASSERT_EQ(a.device().index(), 1);
  ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
  ASSERT_EQ(test_int, 1);

  EXPECT_ANY_THROW(empty_like(a, at::kMSNPU));
  register_extension_backend_op(
      Backend::MSNPU,
      "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override);
  Tensor b = empty_like(a, at::kMSNPU);
  ASSERT_EQ(test_int, 2);

  EXPECT_ANY_THROW(add(a, b));
  register_extension_backend_op(
      Backend::MSNPU,
      "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
  add(a, b);
  ASSERT_EQ(test_int, 3);

  // Ensure that non-MSNPU operator still works
  Tensor d = empty({5, 5}, at::kCPU);
  ASSERT_EQ(d.device().type(), at::kCPU);

  // Attempt to register on a schema that already has a function
  EXPECT_ANY_THROW(
      register_extension_backend_op(
          Backend::MSNPU,
          "empty(IntList size, TensorOptions options) -> Tensor", &empty_override)
  );
}
1 change: 1 addition & 0 deletions aten/tools/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ VALGRIND=${VALGRIND:=ON}
./scalar_tensor_test
./tensor_interop_test
./undefined_tensor_test
./extension_backend_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
Expand Down
Loading

0 comments on commit 7e642df

Please sign in to comment.