Skip to content

Commit

Permalink
Add aliasAnalysis to torch::RegisterOperators() (pytorch#21084)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#21084

- Now AliasAnalysisKind can be set using the torch::RegisterOperators() API
- This also allows us to remove the last place in torch::jit::RegisterOperators that didn't use c10 yet.

Reviewed By: dzhulgakov

Differential Revision: D15542097

fbshipit-source-id: ea127ecf051a5c1e567e035692deed44e04faa9e
  • Loading branch information
smessmer authored and facebook-github-bot committed May 31, 2019
1 parent 8055676 commit 384d828
Show file tree
Hide file tree
Showing 15 changed files with 108 additions and 97 deletions.
11 changes: 7 additions & 4 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,29 @@ c10::optional<OperatorHandle> Dispatcher::findSchema(const char* operator_name,
return OperatorHandle(found);
}

// Looks up an operator by name/overload and returns a handle to it,
// registering a new entry if none exists yet.
// NOTE(review): callers appear responsible for holding mutex_ (registerSchema()
// takes the lock before calling this) — confirm before adding new call sites.
//
// If an operator with the same name and overload name is already registered,
// both its schema and its OperatorOptions must match exactly; otherwise the
// registration is rejected with an error.
OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options) {
  const auto found = findSchema(schema.name().c_str(), schema.overload_name().c_str());
  if (found != c10::nullopt) {
    if (found->schema() != schema) {
      std::ostringstream str;
      str << schema << " vs " << found->schema();
      TORCH_CHECK(false, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", str.str());
    }
    if (found->options() != options) {
      TORCH_CHECK(false, "Tried to register multiple operators with the same schema but different options: ", toString(schema));
    }
    return *found;
  }

  operators_.emplace_back(std::move(schema), std::move(options));
  // operators_ is a std::list, so the iterator to the new element stays valid.
  return OperatorHandle(--operators_.end());
}

SchemaRegistrationHandleRAII Dispatcher::registerSchema(FunctionSchema schema) {
SchemaRegistrationHandleRAII Dispatcher::registerSchema(FunctionSchema schema, OperatorOptions options) {
// we need a lock to avoid concurrent writes
std::lock_guard<std::mutex> lock(mutex_);

auto op = findOrRegisterSchema_(std::move(schema));
auto op = findOrRegisterSchema_(std::move(schema), std::move(options));

++op.operatorIterator_->refcount;
if (1 == op.operatorIterator_->refcount) {
Expand Down
12 changes: 8 additions & 4 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class SchemaRegistrationHandleRAII;
class CAFFE2_API Dispatcher final {
private:
struct OperatorDef final {
explicit OperatorDef(FunctionSchema&& schema)
: op(std::move(schema)), refcount(0) {}
explicit OperatorDef(FunctionSchema&& schema, OperatorOptions&& options)
: op(std::move(schema), std::move(options)), refcount(0) {}

impl::OperatorEntry op;
size_t refcount;
Expand All @@ -100,7 +100,7 @@ class CAFFE2_API Dispatcher final {
* object that manages the lifetime of the registration. Once that
* object is destructed, the kernel will be deregistered.
*/
SchemaRegistrationHandleRAII registerSchema(FunctionSchema schema);
SchemaRegistrationHandleRAII registerSchema(FunctionSchema schema, OperatorOptions options);

/**
* Looks for an operator schema with the given name and overload name
Expand Down Expand Up @@ -144,7 +144,7 @@ class CAFFE2_API Dispatcher final {
private:
Dispatcher();

OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options);

void deregisterSchema_(const OperatorHandle& op);

Expand All @@ -169,6 +169,10 @@ class CAFFE2_API OperatorHandle final {
return operatorIterator_->op.schema();
}

// Read-only access to the OperatorOptions (e.g. the alias analysis kind)
// that this operator was registered with.
const OperatorOptions& options() const {
return operatorIterator_->op.options();
}

private:
explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: operatorIterator_(std::move(operatorIterator)) {}
Expand Down
5 changes: 3 additions & 2 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ namespace {
}
}

// Takes ownership of the schema and options for one registered operator.
// NOTE: dispatchTable_ is initialized from the schema_ *member* (not the
// moved-from parameter), so the member declaration order in the class must
// keep schema_ before dispatchTable_.
OperatorEntry::OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options)
: schema_(std::move(schema))
, dispatchTable_(schema_)
, kernels_(make_left<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>())
, options_(std::move(options)) {
}

void OperatorEntry::prepareForDeregistration() {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace impl {
// and its dispatch table. This is not part of the public API.
class OperatorEntry final {
public:
explicit OperatorEntry(FunctionSchema&& schema);
explicit OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options);

OperatorEntry(const OperatorEntry&) = delete;
OperatorEntry(OperatorEntry&&) noexcept = delete;
Expand All @@ -34,7 +34,7 @@ class OperatorEntry final {
RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel);
RegistrationHandleRAII registerCatchallKernel(DispatchTableEntry kernel);

// Read-only access to the options this operator was registered with.
// const-qualified so it is callable on a const OperatorEntry; returning a
// const reference already prevented mutation, so this is backward-compatible.
const OperatorOptions& options() const {
  return options_;
}

Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/core/dispatch/OperatorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ struct OperatorOptions final {
aliasAnalysisKind_ = v;
}

// OperatorOptions currently only carries the alias analysis kind, so two
// instances compare equal exactly when those kinds match.
friend bool operator==(const OperatorOptions& a, const OperatorOptions& b) {
  return a.aliasAnalysisKind_ == b.aliasAnalysisKind_;
}

// Inequality is defined via operator== so the two can never disagree.
friend bool operator!=(const OperatorOptions& a, const OperatorOptions& b) {
  return !(a == b);
}

private:
AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
};
Expand Down
31 changes: 23 additions & 8 deletions aten/src/ATen/core/op_registration/op_registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ static_assert(std::is_nothrow_move_assignable<c10::optional<RegistrationHandleRA
// table deregisters it in the destructor.
class RegisterOperators::OperatorRegistrar final {
public:
explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
: op_(Dispatcher::singleton().registerSchema(std::move(schema))), kernel_registration_handle_(c10::nullopt) {
explicit OperatorRegistrar(FunctionSchema&& schema, OperatorOptions&& operatorOptions, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
: op_(Dispatcher::singleton().registerSchema(std::move(schema), std::move(operatorOptions))), kernel_registration_handle_(c10::nullopt) {
// either both, kernel and cache_creator, or none must be set.
TORCH_INTERNAL_ASSERT((kernel != nullptr) == static_cast<bool>(cache_creator));

Expand Down Expand Up @@ -124,23 +124,38 @@ void RegisterOperators::checkNoDuplicateKernels_(const FunctionSchema& schema, c
}

// Registers the given schema (and any kernels collected in `options`) with
// the c10 dispatcher.
void RegisterOperators::registerOp_(FunctionSchema&& schema, Options&& options) {
  // Capture the name/overload now: `schema` is moved away below.
  std::string op_name = schema.name();
  std::string overload_name = schema.overload_name();

  auto operatorOptions = makeOperatorOptions_(options);

  if (options.kernels.empty()) {
    registerSchemaOnly_(std::move(schema), std::move(operatorOptions));
  } else {
    for (auto& kernel : options.kernels) {
      // Pass a copy of operatorOptions for each kernel: moving it here would
      // hand later loop iterations a moved-from object. (`schema` is likewise
      // copied by the by-value parameter on each iteration.)
      registerSchemaAndKernel_(schema, std::move(kernel), OperatorOptions(operatorOptions));
    }
  }

  // Sanity check: the operator must be findable now; .value() throws if the
  // registration silently failed.
  auto op_handle = c10::Dispatcher::singleton().findSchema(op_name.c_str(), overload_name.c_str()).value();
  (void)op_handle; // only used as an assertion above
}

// Translates the registration-time Options into the dispatcher-level
// OperatorOptions. Only an explicitly requested alias analysis kind is
// carried over; otherwise the OperatorOptions default stays in effect.
OperatorOptions RegisterOperators::makeOperatorOptions_(const RegisterOperators::Options& options) {
  OperatorOptions result;
  const auto& aliasKind = options.aliasAnalysisKind_;
  if (aliasKind.has_value()) {
    result.setAliasAnalysis(aliasKind.value());
  }
  return result;
}

void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel) {
void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions) {
TORCH_INTERNAL_ASSERT(kernel.kernel_func != nullptr && static_cast<bool>(kernel.cache_creator_func), "Kernel must be set");

registrars_.emplace_back(std::move(schema), kernel.dispatch_key, kernel.kernel_func, std::move(kernel.cache_creator_func));
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, kernel.kernel_func, std::move(kernel.cache_creator_func));
}

void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema) {
registrars_.emplace_back(std::move(schema), c10::nullopt, nullptr, nullptr);
void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions) {
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, nullptr, nullptr);
}

RegisterOperators::RegisterOperators() = default;
Expand Down
12 changes: 10 additions & 2 deletions aten/src/ATen/core/op_registration/op_registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ class CAFFE2_API RegisterOperators final {
return std::move(*this).kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(c10::nullopt, std::forward<Lambda>(functor));
}

// Sets the alias analysis kind for this operator registration.
// Rvalue-ref-qualified (`&&`) so it chains in the builder pattern
// (options().aliasAnalysis(...)); may be called at most once per registration.
Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && {
TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration.");
aliasAnalysisKind_ = aliasAnalysisKind;
return std::move(*this);
}

private:
Options&& kernel(c10::optional<TensorTypeId>&& dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction&& cache_creator, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
KernelRegistrationConfig config;
Expand Down Expand Up @@ -292,6 +298,7 @@ class CAFFE2_API RegisterOperators final {
};

std::vector<KernelRegistrationConfig> kernels;
optional<AliasAnalysisKind> aliasAnalysisKind_;
friend class RegisterOperators;
};

Expand Down Expand Up @@ -398,8 +405,9 @@ class CAFFE2_API RegisterOperators final {
static c10::FunctionSchema inferSchemaFromKernels_(const std::string& opNameStr, const Options& options);
void checkNoDuplicateKernels_(const FunctionSchema& schema, const Options& options);
void registerOp_(FunctionSchema&& schema, Options&& options);
void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config);
void registerSchemaOnly_(FunctionSchema&& schema);
void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options);
void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options);
static OperatorOptions makeOperatorOptions_(const Options& options);

class OperatorRegistrar;

Expand Down
18 changes: 8 additions & 10 deletions test/cpp/jit/test_alias_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -884,13 +884,10 @@ void testMemoryDAG() {

void testAliasRegistration() {
{
auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT);
RegisterOperators reg({createOperator(
"foo::rand",
[](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
},
opts)});
auto registry = torch::RegisterOperators()
.op("foo::rand", torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor { return at::rand({2, 2}); })
.aliasAnalysis(AliasAnalysisKind::DEFAULT));
const auto rand_op = Symbol::fromQualString("foo::rand");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
Expand All @@ -900,9 +897,10 @@ void testAliasRegistration() {
ASSERT_TRUE(aliasDb.mayAlias(a, b));
}
{
auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE);
RegisterOperators reg({createOperator(
"foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)});
auto registry = torch::RegisterOperators()
.op("foo::pure", torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
.aliasAnalysis(AliasAnalysisKind::PURE));
const auto rand_op = Symbol::fromQualString("foo::pure");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
Expand Down
20 changes: 2 additions & 18 deletions torch/csrc/jit/custom_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
template <typename Implementation>
Operator createOperator(
const std::string& schemaOrName,
Implementation&& implementation,
OperatorOptions options = OperatorOptions()) {
Implementation&& implementation) {
using Traits = c10::guts::infer_function_traits_t<Implementation>;
using ArgumentTypes =
c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
Expand Down Expand Up @@ -214,8 +213,7 @@ Operator createOperator(
tuple,
typename MakeIndices<kNumberOfArguments>::indices{});
return 0;
},
std::move(options));
});
}

/// Registration class for new operators. Effectively calls
Expand All @@ -239,20 +237,6 @@ struct TORCH_API RegisterOperators {
op(name, std::forward<Implementation>(implementation));
}

/// Creates a new operator from a name and implementation function (function
/// pointer or function object/lambda) using `torch::jit::createOperator`, and
/// then registers the operator.
template <typename Implementation>
RegisterOperators& op(
const std::string& name,
Implementation&& implementation,
OperatorOptions options) {

registerOperator(createOperator(
name, std::forward<Implementation>(implementation), options));
return *this;
}

template <typename Implementation>
RegisterOperators& op(
const std::string& name,
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ void registerOperator(Operator&& op) {
op.schema().name(),
". File a bug to add a case for this operator.\n");
}
if (!aliasAnalysisHasSpecialCaseFor(s) &&
op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) {
if (op.isC10Op() && !aliasAnalysisHasSpecialCaseFor(s) &&
op.aliasAnalysisKind() == AliasAnalysisKind::DEFAULT) {
AT_ERROR(
"Missing special case in alias analysis for non-schematized"
" operator ",
Expand Down
Loading

0 comments on commit 384d828

Please sign in to comment.