Skip to content

Commit

Permalink
Add XLA / TPU device type, backend type and type id (#16763)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#16763

Replicate the easy bits in pytorch/pytorch#15153 with TPU / XLA instead of MSNPU. Also don't initialize the storage for XLA tensors for now.
Pull Request resolved: pytorch/pytorch#16585

Reviewed By: ezyang

Differential Revision: D13912118

Pulled By: gchanan

fbshipit-source-id: 4889177e2478768fb281ed075b71146d1d850bd9
  • Loading branch information
asuhan authored and facebook-github-bot committed Feb 5, 2019
1 parent 6efa40e commit 9811a42
Show file tree
Hide file tree
Showing 13 changed files with 72 additions and 7 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ enum class TypeID {
SparseCUDALong,
SparseCUDAShort,
MSNPU,
XLA,
CPUComplexFloat,
CPUComplexDouble,
CUDAComplexFloat,
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def check_all_files_written(self):

backends = ['CPU', 'CUDA']
densities = ['Dense', 'Sparse']
extension_backends = ['MSNPU']
extension_backends = ['MSNPU', 'XLA']

# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
Expand Down
34 changes: 34 additions & 0 deletions aten/src/ATen/test/xla_tensor_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <gtest/gtest.h>

#include <cstddef>
#include <cstdlib>

#include <ATen/ATen.h>

using namespace at;

// Deleter for test XLA allocations. The signature matches
// at::DeleterFnPtr so it can be handed to at::DataPtr and returned from
// XLAAllocator::raw_deleter().
void XLAFree(void *ptr) {
  // std::free accepts nullptr, so no guard is needed here.
  std::free(ptr);
}

// Allocation routine for test XLA tensors. Keeps the legacy ptrdiff_t
// (signed) size parameter used by TH-style allocators, but refuses
// negative sizes instead of letting them wrap around to an enormous
// size_t inside malloc.
void* XLAMalloc(ptrdiff_t size) {
  if (size < 0) {
    return nullptr;
  }
  return std::malloc(static_cast<size_t>(size));
}

// Minimal host-backed allocator used by the XLA tensor test: it serves
// plain heap memory but tags every pointer with DeviceType::XLA.
struct XLAAllocator final : public at::Allocator {
  at::DataPtr allocate(size_t size) const override {
    void* raw = XLAMalloc(size);
    // data and context are the same pointer; XLAFree releases it.
    return at::DataPtr(raw, raw, &XLAFree, at::DeviceType::XLA);
  }
  at::DeleterFnPtr raw_deleter() const override {
    return &XLAFree;
  }
};

// Verifies that a TensorImpl built on a zero-element XLA storage still
// reports DeviceType::XLA. Per this commit, XLA tensor storage is not
// initialized with real device memory for now, so the device must come
// through even with an empty storage.
TEST(XlaTensorTest, TestNoStorage) {
XLAAllocator allocator;
// Zero-element float storage; the trailing `true` marks it resizable.
auto storage = Storage(caffe2::TypeMeta::Make<float>(), 0, &allocator, true);
auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
std::move(storage),
XLATensorId(),
/*is_variable=*/false);
at::Tensor t(std::move(tensor_impl));
ASSERT_TRUE(t.device() == DeviceType::XLA);
}
1 change: 1 addition & 0 deletions aten/tools/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ VALGRIND=${VALGRIND:=ON}
./tensor_interop_test
./undefined_tensor_test
./extension_backend_test
./xla_tensor_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
Expand Down
18 changes: 17 additions & 1 deletion c10/core/Backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace c10 {
* would make sense in your use case. If it doesn't make sense, maybe
* you want DeviceType.
*/
enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, Undefined, NumOptions };
enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, Undefined, NumOptions };

static inline Backend toSparse(Backend b) {
switch (b) {
Expand Down Expand Up @@ -51,6 +51,8 @@ static inline Backend toDense(Backend b) {
return Backend::HIP;
case Backend::MSNPU:
return Backend::MSNPU;
case Backend::XLA:
return Backend::XLA;
case Backend::SparseCPU:
return Backend::CPU;
case Backend::SparseCUDA:
Expand All @@ -71,6 +73,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
return Backend::HIP;
} else if (t == MSNPUTensorId()) {
return Backend::MSNPU;
} else if (t == XLATensorId()) {
return Backend::XLA;
} else if (t == SparseCPUTensorId()) {
return Backend::SparseCPU;
} else if (t == SparseCUDATensorId()) {
Expand All @@ -94,6 +98,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) {
return HIPTensorId();
case Backend::MSNPU:
return MSNPUTensorId();
case Backend::XLA:
return XLATensorId();
case Backend::SparseCPU:
return SparseCPUTensorId();
case Backend::SparseCUDA:
Expand All @@ -117,6 +123,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::HIP;
case Backend::MSNPU:
return DeviceType::MSNPU;
case Backend::XLA:
return DeviceType::XLA;
case Backend::SparseCPU:
return DeviceType::CPU;
case Backend::SparseCUDA:
Expand All @@ -140,6 +148,8 @@ static inline Backend deviceTypeToBackend(DeviceType d) {
return Backend::HIP;
case DeviceType::MSNPU:
return Backend::MSNPU;
case DeviceType::XLA:
return Backend::XLA;
default:
AT_ERROR("Unknown device type ", d);
}
Expand All @@ -160,6 +170,7 @@ static inline Backend backendToCPU(Backend b) {
case Backend::SparseHIP:
return Backend::SparseCPU;
case Backend::MSNPU:
case Backend::XLA:
return Backend::CPU;
case Backend::Undefined:
return Backend::Undefined;
Expand All @@ -174,6 +185,7 @@ static inline Backend backendToCUDA(Backend b) {
case Backend::CUDA:
case Backend::HIP:
case Backend::MSNPU:
case Backend::XLA:
return Backend::CUDA;
case Backend::SparseCPU:
case Backend::SparseCUDA:
Expand All @@ -192,6 +204,7 @@ static inline Backend backendToHIP(Backend b) {
case Backend::CUDA:
case Backend::HIP:
case Backend::MSNPU:
case Backend::XLA:
return Backend::HIP;
case Backend::SparseCPU:
case Backend::SparseCUDA:
Expand All @@ -208,6 +221,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kMSNPU = DeviceType::MSNPU;
constexpr DeviceType kXLA = DeviceType::XLA;

static inline const char* toString(Backend b) {
switch (b) {
Expand All @@ -219,6 +233,8 @@ static inline const char* toString(Backend b) {
return "HIP";
case Backend::MSNPU:
return "MSNPU";
case Backend::XLA:
return "XLA";
case Backend::SparseCPU:
return "SparseCPU";
case Backend::SparseCUDA:
Expand Down
3 changes: 3 additions & 0 deletions c10/core/DeviceType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
return lower_case ? "fpga" : "FPGA";
case DeviceType::MSNPU:
return lower_case ? "msnpu" : "MSNPU";
case DeviceType::XLA:
return lower_case ? "xla" : "XLA";
default:
AT_ERROR(
"Unknown device: ",
Expand Down Expand Up @@ -56,6 +58,7 @@ bool isValidDeviceType(DeviceType d) {
case DeviceType::HIP:
case DeviceType::FPGA:
case DeviceType::MSNPU:
case DeviceType::XLA:
return true;
default:
return false;
Expand Down
3 changes: 2 additions & 1 deletion c10/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ enum class DeviceType : int16_t {
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MSNPU = 8, // MSNPU
XLA = 9, // XLA / TPU
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 9,
COMPILE_TIME_MAX_DEVICE_TYPES = 10,
ONLY_FOR_TEST = 20901, // This device type is only for test.
};

Expand Down
5 changes: 3 additions & 2 deletions c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
int64_t get_device() const {
// NB: This method is not virtual and tries to avoid dispatches in the common case for perf.
const auto tid = type_id();
if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId() || tid == XLATensorId()) {
// TODO: #12934 investigate caching device on TensorImpl to avoid this vdispatch.
return storage().device().index();
}
Expand All @@ -369,7 +369,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// TODO: This is a little convoluted so it would be good to investigate
// caching device on TensorImpl (#12934) to speed up device() calls in all cases.
const auto tid = type_id();
if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId() ||
tid == XLATensorId()) {
// NB: storage(), not storage_, b/c of Variable.
const auto& mystorage = storage();
if (mystorage) {
Expand Down
4 changes: 4 additions & 0 deletions c10/core/TensorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ inline TensorTypeId computeTensorTypeId(TensorOptions options) {
return HIPTensorId();
case DeviceType::MSNPU:
return MSNPUTensorId();
case DeviceType::XLA:
return XLATensorId();
default:
AT_ERROR("Unsupported device type for dense layout: ", options.device().type());
}
Expand Down Expand Up @@ -549,6 +551,8 @@ inline DeviceType computeDeviceType(TensorTypeId tid) {
return DeviceType::HIP;
} else if (tid == MSNPUTensorId()) {
return DeviceType::MSNPU;
} else if (tid == XLATensorId()) {
return DeviceType::XLA;
} else if (tid == SparseCPUTensorId()) {
return DeviceType::CPU;
} else if (tid == SparseCUDATensorId()) {
Expand Down
1 change: 1 addition & 0 deletions c10/core/TensorTypeIdRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ C10_DEFINE_TENSOR_TYPE(IDEEPTensorId);
C10_DEFINE_TENSOR_TYPE(HIPTensorId);
C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId);
C10_DEFINE_TENSOR_TYPE(MSNPUTensorId);
C10_DEFINE_TENSOR_TYPE(XLATensorId);

} // namespace c10
1 change: 1 addition & 0 deletions c10/core/TensorTypeIdRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ C10_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only
C10_DECLARE_TENSOR_TYPE(HIPTensorId); // PyTorch/Caffe2 supported
C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only
C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only
C10_DECLARE_TENSOR_TYPE(XLATensorId); // PyTorch only

} // namespace c10

Expand Down
3 changes: 2 additions & 1 deletion caffe2/proto/caffe2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ enum DeviceTypeProto {
PROTO_HIP = 6; // AMD HIP
PROTO_FPGA = 7; // FPGA
PROTO_MSNPU = 8; // MSNPU
PROTO_XLA = 9; // XLA / TPU
// Change the following number if you add more devices in the code.
PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 9;
PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test.
}

Expand Down

0 comments on commit 9811a42

Please sign in to comment.