diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index c23ac01b9d..94eae9b0af 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -75,6 +75,7 @@ enum class TypeID {
   SparseCUDALong,
   SparseCUDAShort,
   MSNPU,
+  XLA,
   CPUComplexFloat,
   CPUComplexDouble,
   CUDAComplexFloat,
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index c8207710bb..9f5fad1571 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -174,7 +174,7 @@ def check_all_files_written(self):
 backends = ['CPU', 'CUDA']
 densities = ['Dense', 'Sparse']
 
-extension_backends = ['MSNPU']
+extension_backends = ['MSNPU', 'XLA']
 
 # scalar_name, c_type, accreal, th_scalar_type, is_floating_type
 scalar_types = [
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 72a98f7ace..4e8b955c5a 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -21,7 +21,8 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp)
 
 list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
diff --git a/aten/src/ATen/test/xla_tensor_test.cpp b/aten/src/ATen/test/xla_tensor_test.cpp
new file mode 100644
index 0000000000..030461aa9c
--- /dev/null
+++ b/aten/src/ATen/test/xla_tensor_test.cpp
@@ -0,0 +1,34 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+using namespace at;
+
+void XLAFree(void *ptr) {
+  free(ptr);
+}
+
+void* XLAMalloc(ptrdiff_t size) {
+  return malloc(size);
+}
+
+struct XLAAllocator final : public at::Allocator {
+  at::DataPtr allocate(size_t size) const override {
+    auto* ptr = XLAMalloc(size);
+    return {ptr, ptr, &XLAFree, at::DeviceType::XLA};
+  }
+  at::DeleterFnPtr raw_deleter() const override {
+    return &XLAFree;
+  }
+};
+
+TEST(XlaTensorTest, TestNoStorage) {
+  XLAAllocator allocator;
+  auto storage = Storage(caffe2::TypeMeta::Make<float>(), 0, &allocator, true);
+  auto tensor_impl = c10::make_intrusive<TensorImpl>(
+      std::move(storage),
+      XLATensorId(),
+      /*is_variable=*/false);
+  at::Tensor t(std::move(tensor_impl));
+  ASSERT_TRUE(t.device() == DeviceType::XLA);
+}
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index e2df276220..bb76979b6a 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -18,6 +18,7 @@ VALGRIND=${VALGRIND:=ON}
 ./tensor_interop_test
 ./undefined_tensor_test
 ./extension_backend_test
+./xla_tensor_test
 if [[ -x ./cudnn_test ]]; then
   ./cudnn_test
 fi
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index 54a22cc499..ddb1dc4f5b 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -20,7 +20,7 @@ namespace c10 {
  * would make sense in your use case. If it doesn't make sense, maybe
 * you want DeviceType.
 */
-enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, Undefined, NumOptions };
+enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, Undefined, NumOptions };
 
 static inline Backend toSparse(Backend b) {
   switch (b) {
@@ -51,6 +51,8 @@ static inline Backend toDense(Backend b) {
       return Backend::HIP;
     case Backend::MSNPU:
       return Backend::MSNPU;
+    case Backend::XLA:
+      return Backend::XLA;
     case Backend::SparseCPU:
       return Backend::CPU;
     case Backend::SparseCUDA:
@@ -71,6 +73,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
     return Backend::HIP;
   } else if (t == MSNPUTensorId()) {
     return Backend::MSNPU;
+  } else if (t == XLATensorId()) {
+    return Backend::XLA;
   } else if (t == SparseCPUTensorId()) {
     return Backend::SparseCPU;
   } else if (t == SparseCUDATensorId()) {
@@ -94,6 +98,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) {
       return HIPTensorId();
     case Backend::MSNPU:
       return MSNPUTensorId();
+    case Backend::XLA:
+      return XLATensorId();
     case Backend::SparseCPU:
       return SparseCPUTensorId();
     case Backend::SparseCUDA:
@@ -117,6 +123,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::HIP;
     case Backend::MSNPU:
       return DeviceType::MSNPU;
+    case Backend::XLA:
+      return DeviceType::XLA;
     case Backend::SparseCPU:
       return DeviceType::CPU;
     case Backend::SparseCUDA:
@@ -140,6 +148,8 @@ static inline Backend deviceTypeToBackend(DeviceType d) {
       return Backend::HIP;
     case DeviceType::MSNPU:
       return Backend::MSNPU;
+    case DeviceType::XLA:
+      return Backend::XLA;
     default:
       AT_ERROR("Unknown device type ", d);
   }
@@ -160,6 +170,7 @@ static inline Backend backendToCPU(Backend b) {
     case Backend::SparseHIP:
       return Backend::SparseCPU;
     case Backend::MSNPU:
+    case Backend::XLA:
       return Backend::CPU;
     case Backend::Undefined:
       return Backend::Undefined;
@@ -174,6 +185,7 @@ static inline Backend backendToCUDA(Backend b) {
     case Backend::CUDA:
     case Backend::HIP:
     case Backend::MSNPU:
+    case Backend::XLA:
       return Backend::CUDA;
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
@@ -192,6 +204,7 @@ static inline Backend backendToHIP(Backend b) {
     case Backend::CUDA:
     case Backend::HIP:
     case Backend::MSNPU:
+    case Backend::XLA:
       return Backend::HIP;
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
@@ -208,6 +221,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
 constexpr DeviceType kCUDA = DeviceType::CUDA;
 constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
+constexpr DeviceType kXLA = DeviceType::XLA;
 
 static inline const char* toString(Backend b) {
   switch (b) {
@@ -219,6 +233,8 @@ static inline const char* toString(Backend b) {
       return "HIP";
     case Backend::MSNPU:
       return "MSNPU";
+    case Backend::XLA:
+      return "XLA";
     case Backend::SparseCPU:
       return "SparseCPU";
     case Backend::SparseCUDA:
diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp
index fd24b70113..017267cd97 100644
--- a/c10/core/DeviceType.cpp
+++ b/c10/core/DeviceType.cpp
@@ -25,6 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
       return lower_case ? "fpga" : "FPGA";
     case DeviceType::MSNPU:
       return lower_case ? "msnpu" : "MSNPU";
+    case DeviceType::XLA:
+      return lower_case ? "xla" : "XLA";
     default:
       AT_ERROR(
           "Unknown device: ",
@@ -56,6 +58,7 @@ bool isValidDeviceType(DeviceType d) {
     case DeviceType::HIP:
     case DeviceType::FPGA:
     case DeviceType::MSNPU:
+    case DeviceType::XLA:
       return true;
     default:
       return false;
diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h
index de7f387f04..f9038c360c 100644
--- a/c10/core/DeviceType.h
+++ b/c10/core/DeviceType.h
@@ -22,11 +22,12 @@ enum class DeviceType : int16_t {
   HIP = 6, // AMD HIP
   FPGA = 7, // FPGA
   MSNPU = 8, // MSNPU
+  XLA = 9, // XLA / TPU
   // NB: If you add more devices:
   //  - Change the implementations of DeviceTypeName and isValidDeviceType
   //    in DeviceType.cpp
   //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 9,
+  COMPILE_TIME_MAX_DEVICE_TYPES = 10,
   ONLY_FOR_TEST = 20901, // This device type is only for test.
 };
 
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 3e258519b8..1ee1e794a5 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -357,7 +357,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   int64_t get_device() const {
     // NB: This method is not virtual and tries to avoid dispatches in the common case for perf.
     const auto tid = type_id();
-    if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
+    if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId() || tid == XLATensorId()) {
       // TODO: #12934 investigate caching device on TensorImpl to avoid this vdispatch.
       return storage().device().index();
     }
@@ -369,7 +369,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     // TODO: This is a little convoluted so it would be good to investigate
     // caching device on TensorImpl (#12934) to speed up device() calls in all cases.
     const auto tid = type_id();
-    if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
+    if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId() ||
+        tid == XLATensorId()) {
       // NB: storage(), not storage_, b/c of Variable.
       const auto& mystorage = storage();
       if (mystorage) {
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index cd2b464c91..a91712a74e 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -511,6 +511,8 @@ inline TensorTypeId computeTensorTypeId(TensorOptions options) {
           return HIPTensorId();
         case DeviceType::MSNPU:
           return MSNPUTensorId();
+        case DeviceType::XLA:
+          return XLATensorId();
         default:
           AT_ERROR("Unsupported device type for dense layout: ", options.device().type());
       }
@@ -549,6 +551,8 @@ inline DeviceType computeDeviceType(TensorTypeId tid) {
     return DeviceType::HIP;
   } else if (tid == MSNPUTensorId()) {
     return DeviceType::MSNPU;
+  } else if (tid == XLATensorId()) {
+    return DeviceType::XLA;
   } else if (tid == SparseCPUTensorId()) {
     return DeviceType::CPU;
   } else if (tid == SparseCUDATensorId()) {
diff --git a/c10/core/TensorTypeIdRegistration.cpp b/c10/core/TensorTypeIdRegistration.cpp
index 2333d04606..ed80ac0e8f 100644
--- a/c10/core/TensorTypeIdRegistration.cpp
+++ b/c10/core/TensorTypeIdRegistration.cpp
@@ -69,5 +69,6 @@ C10_DEFINE_TENSOR_TYPE(IDEEPTensorId);
 C10_DEFINE_TENSOR_TYPE(HIPTensorId);
 C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId);
 C10_DEFINE_TENSOR_TYPE(MSNPUTensorId);
+C10_DEFINE_TENSOR_TYPE(XLATensorId);
 
 } // namespace c10
diff --git a/c10/core/TensorTypeIdRegistration.h b/c10/core/TensorTypeIdRegistration.h
index 9d7d36eed8..1617ea28ad 100644
--- a/c10/core/TensorTypeIdRegistration.h
+++ b/c10/core/TensorTypeIdRegistration.h
@@ -108,6 +108,7 @@ C10_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only
 C10_DECLARE_TENSOR_TYPE(HIPTensorId); // PyTorch/Caffe2 supported
 C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only
 C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only
+C10_DECLARE_TENSOR_TYPE(XLATensorId); // PyTorch only
 
 } // namespace c10
 
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index 210e55f37f..ae15521401 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -179,8 +179,9 @@ enum DeviceTypeProto {
   PROTO_HIP = 6; // AMD HIP
   PROTO_FPGA = 7; // FPGA
   PROTO_MSNPU = 8; // MSNPU
+  PROTO_XLA = 9; // XLA / TPU
   // Change the following number if you add more devices in the code.
-  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 9;
+  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
   PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test.
 }
 
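
Note: the patch follows the same recipe as the existing MSNPU extension backend: one new enumerator in each of DeviceType, Backend, TensorTypeId, and DeviceTypeProto, plus entries in every conversion helper that maps between them, with gen.py's extension_backends list driving the generated Type glue. Beyond the unit test above, those round-trip mappings can be sanity-checked directly. The snippet below is a minimal sketch, not part of the change; it assumes an ATen build at this revision and uses only helpers that appear in this diff, plus AT_ASSERT from c10:

#include <c10/core/Backend.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

int main() {
  using namespace c10;

  // DeviceType <-> Backend <-> TensorTypeId round-trips added in Backend.h.
  AT_ASSERT(deviceTypeToBackend(DeviceType::XLA) == Backend::XLA);
  AT_ASSERT(backendToDeviceType(Backend::XLA) == DeviceType::XLA);
  AT_ASSERT(tensorTypeIdToBackend(XLATensorId()) == Backend::XLA);
  AT_ASSERT(backendToTensorTypeId(Backend::XLA) == XLATensorId());

  // XLA is a dense-only extension backend: toDense() is the identity,
  // and backendToCPU() falls back to plain CPU, as for MSNPU.
  AT_ASSERT(toDense(Backend::XLA) == Backend::XLA);
  AT_ASSERT(backendToCPU(Backend::XLA) == Backend::CPU);

  // String name added in DeviceTypeName() (DeviceType.cpp).
  AT_ASSERT(DeviceTypeName(DeviceType::XLA, /*lower_case=*/true) == "xla");
  return 0;
}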