Skip to content

Commit

Permalink
Added at::Dimname (#21280)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#21280
ghimport-source-id: 921848326e4828ffd422868be26c409c6490e1ab

Differential Revision: D15698516

Pulled By: zou3519

fbshipit-source-id: 502b9b019d51dd46327e6caf2af69aa520c70cb6
  • Loading branch information
zou3519 authored and facebook-github-bot committed Jun 7, 2019
1 parent e27c2f1 commit 4727685
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 0 deletions.
61 changes: 61 additions & 0 deletions aten/src/ATen/Dimname.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/Dimname.h>
#include <c10/util/Exception.h>

namespace at {

bool is_valid_identifier(const std::string& name) {
std::locale loc;
if (name.length() == 0) {
return false;
}
for (auto it = name.begin(); it != name.end(); ++it) {
if (std::isalpha(*it, loc) || *it == '_') {
continue;
}
return false;
}
return true;
}

static void check_valid_identifier(const std::string& name) {
TORCH_CHECK(
is_valid_identifier(name),
"A valid identifier must contain alphabetical characters and/or underscore, got: '",
name, "'.");
}

Dimname Dimname::fromSymbol(Symbol name) {
TORCH_INTERNAL_ASSERT(name.is_dimname());
if (name == kWildcard) {
return Dimname::wildcard();
}
const std::string delimiter = ".";
const std::string str(name.toUnqualString());
auto it = str.find(delimiter);

// Check for normal name
if (it == std::string::npos) {
check_valid_identifier(str);
return Dimname(name);
}

// Check for tagged name
auto second_dot = str.find(delimiter, it + 1);
TORCH_CHECK(
second_dot == std::string::npos,
"Invalid name '", str, "': A tagged name can only contain one '.'");
auto untagged_name = str.substr(0, it);
auto tag = str.substr(it + 1);
check_valid_identifier(untagged_name);
check_valid_identifier(tag);
return Dimname(NameType::TAGGED, name, Symbol::dimname(untagged_name));
}

Dimname Dimname::wildcard() {
static Dimname result(NameType::WILDCARD, kWildcard, kWildcard);
return result;
}

} // namespace at
#endif
33 changes: 33 additions & 0 deletions aten/src/ATen/Dimname.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once
#ifdef NAMEDTENSOR_ENABLED

#include <ATen/core/interned_strings.h>

namespace at {

enum class NameType: uint8_t { NORMAL, WILDCARD, TAGGED };

struct CAFFE2_API Dimname {
static Dimname fromSymbol(Symbol name);
static Dimname wildcard();

NameType type() const { return type_; }
Symbol name() const { return name_; }
Symbol untagged_name() const { return untagged_name_; }

private:
Dimname(Symbol name)
: untagged_name_(name), name_(name), type_(NameType::NORMAL) {}
Dimname(NameType type, Symbol name, Symbol untagged_name)
: untagged_name_(untagged_name), name_(name), type_(type) {}
Symbol untagged_name_;
Symbol name_;
NameType type_;
// Will need more fields for other special name types.
};

static Symbol kWildcard = Symbol::dimname("*");
bool CAFFE2_API is_valid_identifier(const std::string& name);

} // namespace at
#endif
14 changes: 14 additions & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace c10 {
_(namespaces, scope) \
_(namespaces, user) \
_(namespaces, _caffe2) \
_(namespaces, dimname) \
_(namespaces, namespaces) \
_(prim, Assign) \
_(prim, BroadcastingChunk) \
Expand Down Expand Up @@ -204,6 +205,7 @@ namespace c10 {
_(namespaces, scope) \
_(namespaces, user) \
_(namespaces, _caffe2) \
_(namespaces, dimname) \
_(namespaces, namespaces)
#endif

Expand Down Expand Up @@ -272,6 +274,9 @@ struct CAFFE2_API Symbol {
static Symbol prim(const std::string & s);
static Symbol user(const std::string & s);
static Symbol caffe2(const std::string & s);
#ifdef NAMEDTENSOR_ENABLED
static Symbol dimname(const std::string & s);
#endif
// TODO: eliminate me
static Symbol scope(const std::string & s);

Expand All @@ -281,6 +286,9 @@ struct CAFFE2_API Symbol {
bool is_onnx() const;
bool is_user() const;
bool is_caffe2() const;
#ifdef NAMEDTENSOR_ENABLED
bool is_dimname() const;
#endif

// So we can switch on this
constexpr operator unique_t() const {
Expand Down Expand Up @@ -341,12 +349,18 @@ inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualStri
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); }
#ifdef NAMEDTENSOR_ENABLED
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
#endif
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
inline bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; }
#ifdef NAMEDTENSOR_ENABLED
inline bool Symbol::is_dimname() const { return ns() == namespaces::dimname; }
#endif

} // namespace c10

Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/apply_utils_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/basic.cpp
${CMAKE_CURRENT_SOURCE_DIR}/atest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Dimname_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/half_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/wrapdim_test.cpp
Expand Down
57 changes: 57 additions & 0 deletions aten/src/ATen/test/Dimname_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifdef NAMEDTENSOR_ENABLED
#include <gtest/gtest.h>

#include <ATen/Dimname.h>
#include <c10/util/Exception.h>

using at::is_valid_identifier;
using at::NameType;
using at::Symbol;
using at::Dimname;

TEST(DimnameTest, isValidIdentifier) {
ASSERT_TRUE(is_valid_identifier("a"));
ASSERT_TRUE(is_valid_identifier("batch"));
ASSERT_TRUE(is_valid_identifier("N"));
ASSERT_TRUE(is_valid_identifier("CHANNELS"));
ASSERT_TRUE(is_valid_identifier("foo_bar_baz"));

ASSERT_FALSE(is_valid_identifier(""));
ASSERT_FALSE(is_valid_identifier(" "));
ASSERT_FALSE(is_valid_identifier(" a "));
ASSERT_FALSE(is_valid_identifier("batch1"));
ASSERT_FALSE(is_valid_identifier("foo_bar_1"));
ASSERT_FALSE(is_valid_identifier("?"));
ASSERT_FALSE(is_valid_identifier("-"));
}

TEST(DimnameTest, wildcardName) {
Dimname wildcard = Dimname::wildcard();
ASSERT_EQ(wildcard.type(), NameType::WILDCARD);
ASSERT_EQ(wildcard.name(), Symbol::dimname("*"));
ASSERT_EQ(wildcard.untagged_name(), Symbol::dimname("*"));
}

TEST(DimnameTest, createNormalName) {
auto foo = Symbol::dimname("foo");
auto dimname = Dimname::fromSymbol(foo);
ASSERT_EQ(dimname.type(), NameType::NORMAL);
ASSERT_EQ(dimname.name(), foo);
ASSERT_EQ(dimname.untagged_name(), foo);

ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error);
}

TEST(DimnameTest, createTaggedName) {
auto foo_bar = Symbol::dimname("foo.bar");
auto foo = Symbol::dimname("foo");
auto dimname = Dimname::fromSymbol(foo_bar);
ASSERT_EQ(dimname.type(), NameType::TAGGED);
ASSERT_EQ(dimname.name(), foo_bar);
ASSERT_EQ(dimname.untagged_name(), foo);

ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname(".bar")), c10::Error);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.")), c10::Error);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.bar.baz")), c10::Error);
}
#endif
1 change: 1 addition & 0 deletions aten/tools/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ VALGRIND=${VALGRIND:=ON}
./extension_backend_test
./xla_tensor_test
./tensor_iterator_test
./Dimname_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
Expand Down

0 comments on commit 4727685

Please sign in to comment.