Skip to content

Commit

Permalink
AliasAnalysisKind::CONSERVATIVE/FROM_SCHEMA (pytorch#22175)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#22175

- Rename AliasAnalysisKind::DEFAULT to AliasAnalysisKind::CONSERVATIVE
- Introduce AliasAnalysisKind::FROM_SCHEMA that means the alias annotations of the schema should be honored
- Introduce AliasAnalysisKind::INTERNAL_SPECIAL_CASE to be able to run assertions that internal special cased ops are treated correctly

- aten:: and prim:: ops are not treated as special cases anymore, but just use AliasAnalysisKind::FROM_SCHEMA
- There's a set of assertions to ensure that aten:: and prim:: ops are all correctly set up to use AliasAnalysisKind::FROM_SCHEMA. Once this PR lands and passes all tests, we will remove those assertions and open up for the possibility of different AliasAnalysisKind settings for aten:: and prim:: ops

Differential Revision: D15929595

fbshipit-source-id: 7c6a9d4d29e13b8c9a856062cd6fb3f8a46a2e0d
  • Loading branch information
smessmer authored and facebook-github-bot committed Jul 25, 2019
1 parent b9202d4 commit bbc53bf
Show file tree
Hide file tree
Showing 21 changed files with 1,784 additions and 987 deletions.
19 changes: 17 additions & 2 deletions aten/src/ATen/core/dispatch/OperatorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,25 @@ class OperatorEntry;
}

enum class AliasAnalysisKind : uint8_t {
DEFAULT, // The most conservative alias analysis type, assumes side-effects
INTERNAL_SPECIAL_CASE,
CONSERVATIVE, // The most conservative alias analysis type, assumes
// side-effects. This is the default analysis.
FROM_SCHEMA,
PURE
};

constexpr inline const char* toString(AliasAnalysisKind aliasAnalysisKind) {
return (aliasAnalysisKind == AliasAnalysisKind::CONSERVATIVE)
? "CONSERVATIVE"
: (aliasAnalysisKind == AliasAnalysisKind::FROM_SCHEMA)
? "FROM_SCHEMA"
: (aliasAnalysisKind == AliasAnalysisKind::PURE)
? "PURE"
: (aliasAnalysisKind == AliasAnalysisKind::INTERNAL_SPECIAL_CASE)
? "INTERNAL_SPECIAL_CASE"
: "UNKNOWN";
}

struct OperatorOptions final {
public:
AliasAnalysisKind aliasAnalysis() const {
Expand All @@ -31,7 +46,7 @@ struct OperatorOptions final {
}

private:
AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::CONSERVATIVE;
};

} // namespace c10
28 changes: 19 additions & 9 deletions aten/src/ATen/core/function_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,11 @@ struct FunctionSchema {
return is_varret_;
}
bool is_mutable() const {
// see [custom operator aliasing]
const auto kind = Symbol::fromQualString(name_.name);
const auto is_custom_op = !kind.is_aten() && !kind.is_prim();
return is_custom_op ||
std::any_of(
arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
const auto& aliasInfo = arg.alias_info();
return aliasInfo && aliasInfo.value().isWrite();
});
return std::any_of(
arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
const auto& aliasInfo = arg.alias_info();
return aliasInfo && aliasInfo.value().isWrite();
});
}

c10::optional<int> argumentIndexWithName(const std::string& name) const {
Expand Down Expand Up @@ -226,6 +222,20 @@ struct FunctionSchema {
const std::unordered_map<std::string, IValue>& kwargs) const;

void findErrorInKwargs(const std::vector<std::string>& kwargs) const;

bool hasAnyAliasInfo() const {
for (const auto& arg : arguments_) {
if (arg.alias_info().has_value()) {
return true;
}
}
for (const auto& ret : returns_) {
if (ret.alias_info().has_value()) {
return true;
}
}
return false;
}
};

inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
Expand Down
19 changes: 18 additions & 1 deletion aten/src/ATen/core/op_registration/op_registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,16 @@ void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNam
either<OperatorName, FunctionSchema> schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr);
if (schemaOrName.is_right()) {
// schema was explicitly specified. Check it matches the inferred one and register the op.
checkSchemaAndRegisterOp_(std::move(schemaOrName).right(), std::move(options));

auto schema = std::move(schemaOrName).right();
TORCH_CHECK(
options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA ||
!schema.hasAnyAliasInfo(),
"In operator registration: Tried to register operator ",
schemaOrNameStr,
" with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");

checkSchemaAndRegisterOp_(std::move(schema), std::move(options));
} else {
// schema wasn't explicitly specified. Take the inferred schema for registering the op.

Expand All @@ -62,6 +71,14 @@ void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNam

checkNoDuplicateKernels_(inferred_schema_with_name, options);

// This would have unexpected behavior since an inferred schema will not
// have aliasing annotations.
TORCH_CHECK(
options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
"In operator registration: Tried to register operator ",
schemaOrNameStr,
" with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");

// Register all kernels with the schema we inferred
registerOp_(std::move(inferred_schema_with_name), std::move(options));
}
Expand Down
205 changes: 188 additions & 17 deletions test/cpp/jit/test_alias_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
namespace torch {
namespace jit {

inline c10::OperatorOptions aliasAnalysisFromSchema() {
c10::OperatorOptions result;
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
return result;
}

// Fixture to set up a graph and make assertions clearer
struct TopoMoveTestFixture {
TopoMoveTestFixture() {
Expand Down Expand Up @@ -284,6 +290,28 @@ Node* insertIf(
}
return if_;
}

template <class Exception, class Functor>
inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
try {
std::forward<Functor>(functor)();
} catch (const Exception& e) {
if (std::string(e.what()).find(expectMessageContains) ==
std::string::npos) {
AT_ERROR(
"Expected error message to contain \"",
expectMessageContains,
"\" but error message was: ",
e.what());
}
return;
}
AT_ERROR(
"Expected to throw exception containing \"",
expectMessageContains,
"\" but didn't throw");
}

} // namespace

void testAliasAnalysis() {
Expand Down Expand Up @@ -477,10 +505,10 @@ void testAliasAnalysis() {
}

void testWriteTracking() {
RegisterOperators reg(
{Operator("prim::creates_alias(Tensor(a) x) -> Tensor(a)", [](Stack& s) {
return 0;
})});
RegisterOperators reg({Operator(
"prim::creates_alias(Tensor(a) x) -> Tensor(a)",
[](Stack& s) { return 0; },
aliasAnalysisFromSchema())});
const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
{
auto graph = std::make_shared<Graph>();
Expand Down Expand Up @@ -839,13 +867,14 @@ graph():
}

void testWildcards() {
RegisterOperators reg(
{Operator(
"prim::returns_wildcard(Tensor a) -> Tensor(*)",
[](Stack& stack) { return 0; }),
Operator("prim::writes(Tensor(z!) a) -> Tensor(a)", [](Stack& stack) {
return 0;
})});
RegisterOperators reg({Operator(
"prim::returns_wildcard(Tensor a) -> Tensor(*)",
[](Stack& stack) { return 0; },
aliasAnalysisFromSchema()),
Operator(
"prim::writes(Tensor(z!) a) -> Tensor(a)",
[](Stack& stack) { return 0; },
aliasAnalysisFromSchema())});
const auto returns_wildcard =
Symbol::fromQualString("prim::returns_wildcard");
const auto writes = Symbol::fromQualString("prim::writes");
Expand Down Expand Up @@ -997,13 +1026,13 @@ void testMemoryDAG() {
void testAliasRegistration() {
{
auto registry = torch::RegisterOperators().op(
"foo::rand",
"foo::rand1",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::DEFAULT));
const auto rand_op = Symbol::fromQualString("foo::rand");
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
const auto rand_op = Symbol::fromQualString("foo::rand1");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
Expand All @@ -1013,18 +1042,160 @@ void testAliasRegistration() {
}
{
auto registry = torch::RegisterOperators().op(
"foo::pure",
"foo::rand2(Tensor arg1) -> Tensor",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
const auto rand_op = Symbol::fromQualString("foo::rand2");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// Conservatively we assume there is a reference
ASSERT_TRUE(aliasDb.mayAlias(a, b));
}
{
expectThrows<c10::Error>(
[] {
torch::RegisterOperators().op(
"foo::rand3(Tensor(a) arg1) -> Tensor(b)",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
},
"Tried to register operator foo::rand3(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
{
expectThrows<c10::Error>(
[] {
torch::RegisterOperators().op(
"foo::rand4(Tensor(a) arg1) -> Tensor(a)",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));
},
"Tried to register operator foo::rand4(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
{
expectThrows<c10::Error>(
[] {
torch::RegisterOperators().op(
"foo::rand5",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
},
"Tried to register operator foo::rand5 with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred");
}
{
auto registry = torch::RegisterOperators().op(
"aten::rand6(Tensor arg1) -> Tensor",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
const auto rand_op = Symbol::fromQualString("aten::rand6");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// The schema doesn't contain alias information, which means it's pure
// (meh!)
ASSERT_FALSE(aliasDb.mayAlias(a, b));
}
{
auto registry = torch::RegisterOperators().op(
"aten::rand7(Tensor(a) arg1) -> Tensor(a)",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
const auto rand_op = Symbol::fromQualString("aten::rand7");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// The schema has an alias reference
ASSERT_TRUE(aliasDb.mayAlias(a, b));
}
{
auto registry = torch::RegisterOperators().op(
"aten::rand8(Tensor(a) arg1) -> Tensor(b)",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
.aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
const auto rand_op = Symbol::fromQualString("aten::rand8");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// The schema does not have an alias reference
ASSERT_FALSE(aliasDb.mayAlias(a, b));
}
{
auto registry = torch::RegisterOperators().op(
"foo::rand9",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::PURE));
const auto rand_op = Symbol::fromQualString("foo::rand9");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// The schema is pure, there cannot be any alias
ASSERT_FALSE(aliasDb.mayAlias(a, b));
}
{
auto registry = torch::RegisterOperators().op(
"foo::rand10(Tensor arg1) -> Tensor",
torch::RegisterOperators::options()
.catchAllKernel([](at::Tensor) -> at::Tensor {
return at::rand({2, 2});
})
.aliasAnalysis(AliasAnalysisKind::PURE));
const auto rand_op = Symbol::fromQualString("foo::pure");
const auto rand_op = Symbol::fromQualString("foo::rand10");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->insert(rand_op, {a});
AliasDb aliasDb(graph);
// PURE means there is no reference
// The schema is pure, there cannot be any alias
ASSERT_FALSE(aliasDb.mayAlias(a, b));
}
{
expectThrows<c10::Error>(
[] {
torch::RegisterOperators().op(
"foo::rand11(Tensor(a) arg1) -> Tensor(a)",
torch::RegisterOperators::options()
.catchAllKernel(
[](at::Tensor t) -> at::Tensor { return t * 2; })
.aliasAnalysis(AliasAnalysisKind::PURE));
},
"Tried to register operator foo::rand11(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
{
expectThrows<c10::Error>(
[] {
torch::RegisterOperators().op(
"foo::rand12(Tensor(a) arg1) -> Tensor(b)",
torch::RegisterOperators::options()
.catchAllKernel(
[](at::Tensor t) -> at::Tensor { return t * 2; })
.aliasAnalysis(AliasAnalysisKind::PURE));
},
"Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
}

} // namespace jit
Expand Down
Loading

0 comments on commit bbc53bf

Please sign in to comment.