forked from snuspl/nimble
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce backend extensions (overriding operators on custom backends)
Summary: Pull Request resolved: pytorch/pytorch#15153 Reviewed By: gchanan Differential Revision: D13445571 fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708
- Loading branch information
1 parent
64186e0
commit 7e642df
Showing
17 changed files
with
317 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#pragma once | ||
#include <ATen/Backend.h> | ||
${extension_backend_headers} | ||
|
||
namespace at { | ||
|
||
template <typename FnPtr> | ||
inline void register_extension_backend_op( | ||
Backend backend, | ||
const char * schema, | ||
FnPtr fn) { | ||
switch (backend) { | ||
${extension_backend_register_switches} | ||
default: | ||
AT_ERROR("Invalid extension backend: ", toString(backend)); | ||
} | ||
} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#include <ATen/${Type}.h> | ||
|
||
namespace at { | ||
|
||
std::unordered_map<std::string, void *>& ${Type}Dispatch::get_fn_table() { | ||
static std::unordered_map<std::string, void *> fn_table; | ||
return fn_table; | ||
} | ||
|
||
${Type}::${Type}() | ||
: TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} | ||
|
||
Allocator* ${Type}::allocator() const { | ||
AT_ERROR("allocator is not implemented for ${Type}"); | ||
} | ||
|
||
Device ${Type}::getDeviceFromPtr(void * data) const { | ||
return DeviceType::${DeviceType}; | ||
} | ||
|
||
std::unique_ptr<Generator> ${Type}::generator() const { | ||
AT_ERROR("generator is not implemented for ${Type}"); | ||
} | ||
|
||
ScalarType ${Type}::scalarType() const { | ||
AT_ERROR("scalarType is not implemented for ${Type}"); | ||
} | ||
|
||
caffe2::TypeMeta ${Type}::typeMeta() const { | ||
AT_ERROR("typeMeta is not implemented for ${Type}"); | ||
} | ||
|
||
Backend ${Type}::backend() const { | ||
return Backend::${Backend}; | ||
} | ||
|
||
const char * ${Type}::toString() const { | ||
return "${Type}"; | ||
} | ||
|
||
TypeID ${Type}::ID() const { | ||
return ${TypeID}; | ||
} | ||
|
||
size_t ${Type}::elementSizeInBytes() const { | ||
AT_ERROR("elementSizeInBytes is not implemented for ${Type}"); | ||
} | ||
|
||
${type_method_definitions} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#pragma once | ||
#include <ATen/TypeDefault.h> | ||
|
||
namespace at { | ||
|
||
// This dispatch class holds static map in which function pointers are | ||
// registered by schema. | ||
// TODO: Check for invalid schemas prior to registration. | ||
struct CAFFE2_API ${Type}Dispatch { | ||
template<typename FnPtr> | ||
static FnPtr get_function(const std::string& schema) { | ||
auto & fn_table = get_fn_table(); | ||
auto it = fn_table.find(schema); | ||
if (it != fn_table.end()) { | ||
return reinterpret_cast<FnPtr>(it->second); | ||
} | ||
AT_ERROR("No function registered for schema: ", schema); | ||
} | ||
|
||
template<typename FnPtr> | ||
static void register_function(const std::string& schema, FnPtr fn) { | ||
auto & fn_table = get_fn_table(); | ||
if (fn_table.find(schema) != fn_table.end()) { | ||
AT_ERROR("Function already registered for schema: ", schema); | ||
} | ||
fn_table[schema] = reinterpret_cast<void *>(fn); | ||
} | ||
|
||
static std::unordered_map<std::string, void *>& get_fn_table(); | ||
}; | ||
|
||
struct CAFFE2_API ${Type} : public TypeDefault { | ||
explicit ${Type}(); | ||
|
||
Allocator* allocator() const override; | ||
Device getDeviceFromPtr(void * data) const override; | ||
std::unique_ptr<Generator> generator() const override; | ||
|
||
virtual ScalarType scalarType() const override; | ||
virtual caffe2::TypeMeta typeMeta() const override; | ||
virtual Backend backend() const override; | ||
virtual const char * toString() const override; | ||
virtual size_t elementSizeInBytes() const override; | ||
virtual TypeID ID() const override; | ||
|
||
${type_method_declarations} | ||
}; | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#include <gtest/gtest.h> | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/NativeFunctions.h> | ||
#include <ATen/ExtensionBackendRegistration.h> | ||
|
||
using namespace at; | ||
|
||
static int test_int; | ||
|
||
Tensor empty_override(IntList size, const TensorOptions & options) { | ||
test_int = 1; | ||
auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>( | ||
Storage( | ||
caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false), | ||
MSNPUTensorId(), | ||
false); | ||
return Tensor(std::move(tensor_impl)); | ||
} | ||
|
||
Tensor empty_like_override(const Tensor & self, const TensorOptions & options) { | ||
test_int = 2; | ||
return self; | ||
} | ||
|
||
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { | ||
test_int = 3; | ||
return a; | ||
} | ||
|
||
TEST(BackendExtensionTest, TestRegisterOp) { | ||
EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU)); | ||
register_extension_backend_op( | ||
Backend::MSNPU, | ||
"empty(IntList size, TensorOptions options) -> Tensor", &empty_override); | ||
Tensor a = empty({5, 5}, at::kMSNPU); | ||
ASSERT_EQ(a.device().type(), at::kMSNPU); | ||
ASSERT_EQ(a.device().index(), 1); | ||
ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>()); | ||
ASSERT_EQ(test_int, 1); | ||
|
||
EXPECT_ANY_THROW(empty_like(a, at::kMSNPU)); | ||
register_extension_backend_op( | ||
Backend::MSNPU, | ||
"empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override); | ||
Tensor b = empty_like(a, at::kMSNPU); | ||
ASSERT_EQ(test_int, 2); | ||
|
||
EXPECT_ANY_THROW(add(a, b)); | ||
register_extension_backend_op( | ||
Backend::MSNPU, | ||
"add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override); | ||
add(a, b); | ||
ASSERT_EQ(test_int, 3); | ||
|
||
// Ensure that non-MSNPU operator still works | ||
Tensor d = empty({5, 5}, at::kCPU); | ||
ASSERT_EQ(d.device().type(), at::kCPU); | ||
|
||
// Attempt to register on a schema that has already has a function | ||
EXPECT_ANY_THROW( | ||
register_extension_backend_op( | ||
Backend::MSNPU, | ||
"empty(IntList size, TensorOptions options) -> Tensor", &empty_override) | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.