forked from snuspl/nimble
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: pytorch/pytorch#21280 ghimport-source-id: 921848326e4828ffd422868be26c409c6490e1ab Differential Revision: D15698516 Pulled By: zou3519 fbshipit-source-id: 502b9b019d51dd46327e6caf2af69aa520c70cb6
- Loading branch information
1 parent
e27c2f1
commit 4727685
Showing
6 changed files
with
167 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#ifdef NAMEDTENSOR_ENABLED | ||
#include <ATen/Dimname.h> | ||
#include <c10/util/Exception.h> | ||
|
||
namespace at { | ||
|
||
bool is_valid_identifier(const std::string& name) { | ||
std::locale loc; | ||
if (name.length() == 0) { | ||
return false; | ||
} | ||
for (auto it = name.begin(); it != name.end(); ++it) { | ||
if (std::isalpha(*it, loc) || *it == '_') { | ||
continue; | ||
} | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
static void check_valid_identifier(const std::string& name) { | ||
TORCH_CHECK( | ||
is_valid_identifier(name), | ||
"A valid identifier must contain alphabetical characters and/or underscore, got: '", | ||
name, "'."); | ||
} | ||
|
||
Dimname Dimname::fromSymbol(Symbol name) { | ||
TORCH_INTERNAL_ASSERT(name.is_dimname()); | ||
if (name == kWildcard) { | ||
return Dimname::wildcard(); | ||
} | ||
const std::string delimiter = "."; | ||
const std::string str(name.toUnqualString()); | ||
auto it = str.find(delimiter); | ||
|
||
// Check for normal name | ||
if (it == std::string::npos) { | ||
check_valid_identifier(str); | ||
return Dimname(name); | ||
} | ||
|
||
// Check for tagged name | ||
auto second_dot = str.find(delimiter, it + 1); | ||
TORCH_CHECK( | ||
second_dot == std::string::npos, | ||
"Invalid name '", str, "': A tagged name can only contain one '.'"); | ||
auto untagged_name = str.substr(0, it); | ||
auto tag = str.substr(it + 1); | ||
check_valid_identifier(untagged_name); | ||
check_valid_identifier(tag); | ||
return Dimname(NameType::TAGGED, name, Symbol::dimname(untagged_name)); | ||
} | ||
|
||
Dimname Dimname::wildcard() { | ||
static Dimname result(NameType::WILDCARD, kWildcard, kWildcard); | ||
return result; | ||
} | ||
|
||
} // namespace at | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#pragma once | ||
#ifdef NAMEDTENSOR_ENABLED | ||
|
||
#include <ATen/core/interned_strings.h> | ||
|
||
namespace at { | ||
|
||
enum class NameType: uint8_t { NORMAL, WILDCARD, TAGGED }; | ||
|
||
struct CAFFE2_API Dimname { | ||
static Dimname fromSymbol(Symbol name); | ||
static Dimname wildcard(); | ||
|
||
NameType type() const { return type_; } | ||
Symbol name() const { return name_; } | ||
Symbol untagged_name() const { return untagged_name_; } | ||
|
||
private: | ||
Dimname(Symbol name) | ||
: untagged_name_(name), name_(name), type_(NameType::NORMAL) {} | ||
Dimname(NameType type, Symbol name, Symbol untagged_name) | ||
: untagged_name_(untagged_name), name_(name), type_(type) {} | ||
Symbol untagged_name_; | ||
Symbol name_; | ||
NameType type_; | ||
// Will need more fields for other special name types. | ||
}; | ||
|
||
static Symbol kWildcard = Symbol::dimname("*"); | ||
bool CAFFE2_API is_valid_identifier(const std::string& name); | ||
|
||
} // namespace at | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#ifdef NAMEDTENSOR_ENABLED | ||
#include <gtest/gtest.h> | ||
|
||
#include <ATen/Dimname.h> | ||
#include <c10/util/Exception.h> | ||
|
||
using at::is_valid_identifier; | ||
using at::NameType; | ||
using at::Symbol; | ||
using at::Dimname; | ||
|
||
TEST(DimnameTest, isValidIdentifier) { | ||
ASSERT_TRUE(is_valid_identifier("a")); | ||
ASSERT_TRUE(is_valid_identifier("batch")); | ||
ASSERT_TRUE(is_valid_identifier("N")); | ||
ASSERT_TRUE(is_valid_identifier("CHANNELS")); | ||
ASSERT_TRUE(is_valid_identifier("foo_bar_baz")); | ||
|
||
ASSERT_FALSE(is_valid_identifier("")); | ||
ASSERT_FALSE(is_valid_identifier(" ")); | ||
ASSERT_FALSE(is_valid_identifier(" a ")); | ||
ASSERT_FALSE(is_valid_identifier("batch1")); | ||
ASSERT_FALSE(is_valid_identifier("foo_bar_1")); | ||
ASSERT_FALSE(is_valid_identifier("?")); | ||
ASSERT_FALSE(is_valid_identifier("-")); | ||
} | ||
|
||
TEST(DimnameTest, wildcardName) { | ||
Dimname wildcard = Dimname::wildcard(); | ||
ASSERT_EQ(wildcard.type(), NameType::WILDCARD); | ||
ASSERT_EQ(wildcard.name(), Symbol::dimname("*")); | ||
ASSERT_EQ(wildcard.untagged_name(), Symbol::dimname("*")); | ||
} | ||
|
||
TEST(DimnameTest, createNormalName) { | ||
auto foo = Symbol::dimname("foo"); | ||
auto dimname = Dimname::fromSymbol(foo); | ||
ASSERT_EQ(dimname.type(), NameType::NORMAL); | ||
ASSERT_EQ(dimname.name(), foo); | ||
ASSERT_EQ(dimname.untagged_name(), foo); | ||
|
||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error); | ||
} | ||
|
||
TEST(DimnameTest, createTaggedName) { | ||
auto foo_bar = Symbol::dimname("foo.bar"); | ||
auto foo = Symbol::dimname("foo"); | ||
auto dimname = Dimname::fromSymbol(foo_bar); | ||
ASSERT_EQ(dimname.type(), NameType::TAGGED); | ||
ASSERT_EQ(dimname.name(), foo_bar); | ||
ASSERT_EQ(dimname.untagged_name(), foo); | ||
|
||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname(".bar")), c10::Error); | ||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.")), c10::Error); | ||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.bar.baz")), c10::Error); | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters