From 60d1852c7bba771915ddc6d9dffd93ffaf63e2d5 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Nov 2016 22:38:16 +0100 Subject: [PATCH] Major improvements to master-worker mode * Fixed all undefined symbol errors * Implemented storage interface and THStorage class * RPC improvements * Code refactor --- setup.py | 2 + torch/csrc/Module.cpp | 39 ++ torch/csrc/distributed/Module.cpp | 44 +- torch/csrc/distributed/Storage.h | 2 +- torch/csrc/distributed/THDP.h | 2 + torch/csrc/distributed/Tensor.h | 6 +- torch/csrc/distributed/override_macros.h | 1 + torch/csrc/distributed/utils.cpp | 8 +- torch/distributed/__init__.py | 8 +- torch/distributed/remote_types.py | 16 +- torch/lib/THD/CMakeLists.txt | 25 +- torch/lib/THD/base/Storage.hpp | 42 ++ torch/lib/THD/base/TensorTraits.hpp | 127 ----- torch/lib/THD/base/Traits.hpp | 89 +++ torch/lib/THD/base/storages/THStorage.cpp | 11 + torch/lib/THD/base/storages/THStorage.hpp | 51 ++ .../THD/base/storages/generic/THStorage.cpp | 69 +++ .../THD/base/storages/generic/THStorage.hpp | 10 + torch/lib/THD/base/tensors/THTensor.cpp | 2 +- torch/lib/THD/base/tensors/THTensor.hpp | 10 +- .../lib/THD/base/tensors/generic/THTensor.cpp | 2 +- .../lib/THD/base/tensors/generic/THTensor.hpp | 2 +- .../master_worker/common/CommandChannel.cpp | 6 +- .../THD/master_worker/common/Functions.hpp | 31 +- .../lib/THD/master_worker/common/RPC-inl.hpp | 23 +- torch/lib/THD/master_worker/common/RPC.cpp | 15 +- torch/lib/THD/master_worker/common/RPC.hpp | 7 +- torch/lib/THD/master_worker/common/Traits.hpp | 60 ++ torch/lib/THD/master_worker/master/Master.cpp | 1 + .../THD/master_worker/master/THDStorage.cpp | 12 + .../THD/master_worker/master/THDTensor.cpp | 2 +- .../master/generic/THDStorage.cpp | 114 ++-- .../master_worker/master/generic/THDStorage.h | 8 +- .../master/generic/THDTensor.cpp | 522 +++++++++++++++++- .../master_worker/master/generic/THDTensor.h | 6 +- .../lib/THD/master_worker/worker/Dispatch.cpp | 73 ++- torch/lib/THD/master_worker/worker/Worker.cpp | 19 +- torch/lib/THD/master_worker/worker/Worker.hpp | 4 +- .../worker/dispatch/Communication.cpp | 14 + .../master_worker/worker/dispatch/Storage.cpp | 27 + .../master_worker/worker/dispatch/Tensor.cpp | 32 ++ 41 files changed, 1290 insertions(+), 254 deletions(-) create mode 100644 torch/lib/THD/base/Storage.hpp delete mode 100644 torch/lib/THD/base/TensorTraits.hpp create mode 100644 torch/lib/THD/base/Traits.hpp create mode 100644 torch/lib/THD/base/storages/THStorage.cpp create mode 100644 torch/lib/THD/base/storages/THStorage.hpp create mode 100644 torch/lib/THD/base/storages/generic/THStorage.cpp create mode 100644 torch/lib/THD/base/storages/generic/THStorage.hpp create mode 100644 torch/lib/THD/master_worker/common/Traits.hpp create mode 100644 torch/lib/THD/master_worker/master/THDStorage.cpp create mode 100644 torch/lib/THD/master_worker/worker/dispatch/Communication.cpp create mode 100644 torch/lib/THD/master_worker/worker/dispatch/Storage.cpp create mode 100644 torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp diff --git a/setup.py b/setup.py index 697ede054be413..95a64d4516be0c 100644 --- a/setup.py +++ b/setup.py @@ -235,6 +235,8 @@ def run(self): extra_compile_args += ['-DWITH_DISTRIBUTED'] main_sources += [ "torch/csrc/distributed/Module.cpp", + "torch/csrc/distributed/Tensor.cpp", + "torch/csrc/distributed/Storage.cpp", "torch/csrc/distributed/utils.cpp" ] include_dirs += [tmp_install_path + "/include/THD"] diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 
912d27fb361d17..e329202a78b56a 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -697,6 +697,24 @@ bool THCSPShortTensor_init(PyObject *module); bool THCSPCharTensor_init(PyObject *module); bool THCSPByteTensor_init(PyObject *module); +bool THDPDoubleStorage_init(PyObject *module); +bool THDPFloatStorage_init(PyObject *module); +//bool THDPHalfStorage_init(PyObject *module); +bool THDPLongStorage_init(PyObject *module); +bool THDPIntStorage_init(PyObject *module); +bool THDPShortStorage_init(PyObject *module); +bool THDPCharStorage_init(PyObject *module); +bool THDPByteStorage_init(PyObject *module); + +bool THDPDoubleTensor_init(PyObject *module); +bool THDPFloatTensor_init(PyObject *module); +//bool THDPHalfTensor_init(PyObject *module); +bool THDPLongTensor_init(PyObject *module); +bool THDPIntTensor_init(PyObject *module); +bool THDPShortTensor_init(PyObject *module); +bool THDPCharTensor_init(PyObject *module); +bool THDPByteTensor_init(PyObject *module); + static std::vector methods; #ifdef WITH_DISTRIBUTED @@ -811,6 +829,27 @@ PyMODINIT_FUNC PyInit__C() Py_INCREF(has_cudnn); ASSERT_TRUE(PyModule_AddObject(module, "has_cudnn", has_cudnn) == 0); +#ifdef WITH_DISTRIBUTED + // See comment on CUDA objects + ASSERT_TRUE(THDPDoubleStorage_init(module)); + ASSERT_TRUE(THDPFloatStorage_init(module)); + //ASSERT_TRUE(THDPHalfStorage_init(module)); + ASSERT_TRUE(THDPLongStorage_init(module)); + ASSERT_TRUE(THDPIntStorage_init(module)); + ASSERT_TRUE(THDPShortStorage_init(module)); + ASSERT_TRUE(THDPCharStorage_init(module)); + ASSERT_TRUE(THDPByteStorage_init(module)); + + ASSERT_TRUE(THDPDoubleTensor_init(module)); + ASSERT_TRUE(THDPFloatTensor_init(module)); + //ASSERT_TRUE(THDPHalfTensor_init(module)); + ASSERT_TRUE(THDPLongTensor_init(module)); + ASSERT_TRUE(THDPIntTensor_init(module)); + ASSERT_TRUE(THDPShortTensor_init(module)); + ASSERT_TRUE(THDPCharTensor_init(module)); + ASSERT_TRUE(THDPByteTensor_init(module)); +#endif + THPDefaultGenerator = (THPGenerator*)THPGenerator_New(); ASSERT_TRUE(THPDefaultGenerator != nullptr); ASSERT_TRUE(PyModule_AddObject(module, "default_generator", (PyObject*)THPDefaultGenerator) == 0); diff --git a/torch/csrc/distributed/Module.cpp b/torch/csrc/distributed/Module.cpp index 9fd51dd083b263..2dcb42fd3dc271 100644 --- a/torch/csrc/distributed/Module.cpp +++ b/torch/csrc/distributed/Module.cpp @@ -11,6 +11,31 @@ static std::unordered_map name2channel_type = { {"tcp", THDChannelTCP}, }; +static bool THDPModule_loadClasses(PyObject *module_dict) +{ +#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; } + ASSERT_NOT_NULL(THDPDoubleStorageClass = PyMapping_GetItemString(module_dict, (char*)"DoubleStorage")); + ASSERT_NOT_NULL(THDPFloatStorageClass = PyMapping_GetItemString(module_dict, (char*)"FloatStorage")); + //ASSERT_NOT_NULL(THDPHalfStorageClass = PyMapping_GetItemString(module_dict, (char*)"HalfStorage")); + ASSERT_NOT_NULL(THDPLongStorageClass = PyMapping_GetItemString(module_dict, (char*)"LongStorage")); + ASSERT_NOT_NULL(THDPIntStorageClass = PyMapping_GetItemString(module_dict, (char*)"IntStorage")); + ASSERT_NOT_NULL(THDPShortStorageClass = PyMapping_GetItemString(module_dict, (char*)"ShortStorage")); + ASSERT_NOT_NULL(THDPCharStorageClass = PyMapping_GetItemString(module_dict, (char*)"CharStorage")); + ASSERT_NOT_NULL(THDPByteStorageClass = PyMapping_GetItemString(module_dict, (char*)"ByteStorage")); + + ASSERT_NOT_NULL(THDPDoubleTensorClass = PyMapping_GetItemString(module_dict, 
(char*)"DoubleTensor")); + //ASSERT_NOT_NULL(THDPHalfTensorClass = PyMapping_GetItemString(module_dict, (char*)"HalfTensor")); + ASSERT_NOT_NULL(THDPFloatTensorClass = PyMapping_GetItemString(module_dict, (char*)"FloatTensor")); + ASSERT_NOT_NULL(THDPLongTensorClass = PyMapping_GetItemString(module_dict, (char*)"LongTensor")); + ASSERT_NOT_NULL(THDPIntTensorClass = PyMapping_GetItemString(module_dict, (char*)"IntTensor")); + ASSERT_NOT_NULL(THDPShortTensorClass = PyMapping_GetItemString(module_dict, (char*)"ShortTensor")); + ASSERT_NOT_NULL(THDPCharTensorClass = PyMapping_GetItemString(module_dict, (char*)"CharTensor")); + ASSERT_NOT_NULL(THDPByteTensorClass = PyMapping_GetItemString(module_dict, (char*)"ByteTensor")); + + return true; +#undef ASSERT_NOT_NULL +} + static std::unordered_map obj2reduceop; static std::unordered_map obj2group; @@ -230,13 +255,17 @@ PyObject* THDPModule_newGroup(PyObject *_unused, PyObject *args) } PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) { - if (PyTuple_GET_SIZE(args) != 2) { - THPUtils_invalidArguments(args, "initExtension", 1, "(reduce_op obj, group obj)"); + if (PyTuple_GET_SIZE(args) != 3) { + THPUtils_invalidArguments(args, "initExtension", 1, "(bool is_master_worker, reduce_op obj, group obj)"); return NULL; } - PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 0); - PyObject* group_obj = PyTuple_GET_ITEM(args, 1); + PyObject* is_master_worker_obj = PyTuple_GET_ITEM(args, 0); + PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 1); + PyObject* group_obj = PyTuple_GET_ITEM(args, 2); + + THPUtils_assert(PyBool_Check(is_master_worker_obj), "first argument should be a bool"); + bool is_master_worker = is_master_worker_obj == Py_True; THPObjectPtr reduce_op; #define REGISTER_REDUCE_OP(NAME) \ @@ -256,6 +285,13 @@ PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) { obj2group.emplace(group.get(), THDGroup##NAME); REGISTER_GROUP(WORLD); #undef REGISTER_GROUP + + if (is_master_worker) { + PyObject *module = PyImport_ImportModule("torch.distributed"); + THPUtils_assert(module, "class loader couldn't access torch.distributed module"); + PyObject* module_dict = PyModule_GetDict(module); + if (!THDPModule_loadClasses(module_dict)) return NULL; + } Py_RETURN_TRUE; } diff --git a/torch/csrc/distributed/Storage.h b/torch/csrc/distributed/Storage.h index 80e1fe835bb829..5639efba2e3e52 100644 --- a/torch/csrc/distributed/Storage.h +++ b/torch/csrc/distributed/Storage.h @@ -33,7 +33,7 @@ #ifdef _THP_CORE #define THDPStorageType TH_CONCAT_3(THDP,Real,StorageType) -#define THDPStorageBaseStr TH_CONCAT_STRING_3(Cuda,Real,StorageBase) +#define THDPStorageBaseStr TH_CONCAT_STRING_3(Distributed,Real,StorageBase) #endif #include "override_macros.h" diff --git a/torch/csrc/distributed/THDP.h b/torch/csrc/distributed/THDP.h index 1d19dcc430eec4..e3f18f20908315 100644 --- a/torch/csrc/distributed/THDP.h +++ b/torch/csrc/distributed/THDP.h @@ -7,6 +7,8 @@ #include "Module.h" #include "Storage.h" #include "Tensor.h" +#ifdef _THP_CORE #include "utils.h" +#endif #endif diff --git a/torch/csrc/distributed/Tensor.h b/torch/csrc/distributed/Tensor.h index fdc2238f69c076..4b3aa2218d9aed 100644 --- a/torch/csrc/distributed/Tensor.h +++ b/torch/csrc/distributed/Tensor.h @@ -25,7 +25,7 @@ #ifdef _THP_CORE #define THDPTensorType TH_CONCAT_3(THDP,Real,TensorType) -#define THDPTensorBaseStr TH_CONCAT_STRING_3(THDP,Real,TensorBase) +#define THDPTensorBaseStr TH_CONCAT_STRING_3(Distributed,Real,TensorBase) #define THDPTensor_stateless_(NAME) 
TH_CONCAT_4(THDP,Real,Tensor_stateless_,NAME) #define THDPTensorStatelessType TH_CONCAT_3(THDP,Real,TensorStatelessType) #define THDPTensorStateless TH_CONCAT_3(THDP,Real,TensorStateless) @@ -34,7 +34,7 @@ #include "override_macros.h" -#define TH_GENERIC_FILE "torch/csrc/generic/Tensor.h" -#include +#define THD_GENERIC_FILE "torch/csrc/generic/Tensor.h" +#include #endif diff --git a/torch/csrc/distributed/override_macros.h b/torch/csrc/distributed/override_macros.h index 7b2f13b8aac73b..bc94d056174189 100644 --- a/torch/csrc/distributed/override_macros.h +++ b/torch/csrc/distributed/override_macros.h @@ -30,6 +30,7 @@ #define LIBRARY_STATE_NOARGS #define LIBRARY_STATE +#define TH_GENERIC_FILE THD_GENERIC_FILE #define THHostTensor TH_CONCAT_3(TH,Real,Tensor) #define THHostTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME) diff --git a/torch/csrc/distributed/utils.cpp b/torch/csrc/distributed/utils.cpp index 7a67aab116226b..623ec69d44944e 100644 --- a/torch/csrc/distributed/utils.cpp +++ b/torch/csrc/distributed/utils.cpp @@ -1,10 +1,14 @@ #include -#include "utils.h" +#include "THDP.h" + +#include "override_macros.h" template<> void THPPointer::free() { if (ptr) THDTensorDescriptor_free(ptr); } - template class THPPointer; + +#define THD_GENERIC_FILE "torch/csrc/generic/utils.cpp" +#include diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 6a6282a2ac63de..8376592a8fa0f1 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -17,7 +17,7 @@ def init_process_group(backend): _initialized = True import torch.distributed.collectives as collectives extend_scope(collectives) - assert torch._C._dist_init_extension(reduce_op) + assert torch._C._dist_init_extension(False, reduce_op, group) def init_master_worker(backend): @@ -27,8 +27,8 @@ def init_master_worker(backend): torch._C._dist_init_master_worker(backend) _initialized = True import torch.distributed.collectives as collectives - # import torch.distributed.remote_types as remote_types + import torch.distributed.remote_types as remote_types extend_scope(collectives) - # extend_scope(remote_types) - assert torch._C._dist_init_extension(reduce_op) + extend_scope(remote_types) + assert torch._C._dist_init_extension(True, reduce_op, group) diff --git a/torch/distributed/remote_types.py b/torch/distributed/remote_types.py index 183544e300f4d8..ca5e3fef7976d9 100644 --- a/torch/distributed/remote_types.py +++ b/torch/distributed/remote_types.py @@ -32,8 +32,8 @@ class CharStorage(_DistributedBase, torch._C.DistributedCharStorageBase, _Storag pass class ByteStorage(_DistributedBase, torch._C.DistributedByteStorageBase, _StorageBase): pass -class HalfStorage(_DistributedBase, torch._C.DistributedHalfStorageBase, _StorageBase): - pass +# class HalfStorage(_DistributedBase, torch._C.DistributedHalfStorageBase, _StorageBase): + # pass class DoubleTensor(_DistributedBase, torch._C.DistributedDoubleTensorBase, _TensorBase): def is_signed(self): @@ -78,12 +78,12 @@ def is_signed(self): @classmethod def storage_type(cls): return ByteStorage -class HalfTensor(_DistributedBase, torch._C.DistributedHalfTensorBase, _TensorBase): - def is_signed(self): - return True - @classmethod - def storage_type(): - return HalfStorage +# class HalfTensor(_DistributedBase, torch._C.DistributedHalfTensorBase, _TensorBase): + # def is_signed(self): + # return True + # @classmethod + # def storage_type(): + # return HalfStorage torch._storage_classes.add(DoubleStorage) torch._storage_classes.add(FloatStorage) diff --git 
a/torch/lib/THD/CMakeLists.txt b/torch/lib/THD/CMakeLists.txt index 244702fa43a758..d444e7639d8ba0 100644 --- a/torch/lib/THD/CMakeLists.txt +++ b/torch/lib/THD/CMakeLists.txt @@ -1,6 +1,24 @@ -CMAKE_MINIMUM_REQUIRED(VERSION 2.6) +CMAKE_MINIMUM_REQUIRED(VERSION 3.0) SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) +################################################################################ +# Helper functions +################################################################################ + +FUNCTION(EXCLUDE_DIR list_name dir_name) + # A helper that excludes all files that contain dir_name in their file path + SET(local_list ${${list_name}}) + FOREACH(source ${local_list}) + IF(${source} MATCHES ${dir_name}) + MESSAGE(STATUS "Excluding " ${source} " from the build") + LIST(REMOVE_ITEM local_list ${source}) + ENDIF() + ENDFOREACH() + SET(${list_name} ${local_list} PARENT_SCOPE) +ENDFUNCTION() + +################################################################################ + FIND_PACKAGE(ZMQ REQUIRED) FIND_PACKAGE(CPPZMQ REQUIRED) @@ -34,9 +52,13 @@ IF(NOT MPI_FOUND) LIST(REMOVE_ITEM test_cpp "${CMAKE_CURRENT_SOURCE_DIR}/test/data_channel_mpi_smoke.cpp") ENDIF() +EXCLUDE_DIR(master_worker_cpp ".*/dispatch/.*\\.cpp$") + SET(all_cpp ${base_cpp} ${process_group_cpp} ${master_worker_cpp}) SET(all_h THD.h ${base_h} ${process_group_h} ${master_worker_h}) +EXCLUDE_DIR(all_cpp ".*/generic/.*\\.cpp$") + INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) ADD_LIBRARY(THD SHARED ${all_cpp}) SET_PROPERTY(TARGET THD PROPERTY CXX_STANDARD 11) @@ -80,6 +102,5 @@ INSTALL(TARGETS THD FOREACH(HEADER ${all_h}) STRING(REGEX MATCH "(.*)[/\\]" DIR ${HEADER}) - MESSAGE(STATUS ${DIR}) INSTALL(FILES ${HEADER} DESTINATION ${THD_INSTALL_INCLUDE_DIR}/THD/${DIR}) ENDFOREACH() diff --git a/torch/lib/THD/base/Storage.hpp b/torch/lib/THD/base/Storage.hpp new file mode 100644 index 00000000000000..58d043393c6649 --- /dev/null +++ b/torch/lib/THD/base/Storage.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "Type.hpp" + +#include +#include +#include +#include +#include +#include + +namespace thd { + +struct Storage { + Storage() {}; + Storage(const Storage& other) = delete; + Storage(Storage&& other) = delete; + virtual ~Storage() {}; + + virtual std::size_t elementSize() const = 0; + virtual std::size_t size() const = 0; + virtual void* data() = 0; + virtual const void* data() const = 0; + virtual Storage& retain() = 0; + virtual Storage& free() = 0; + + virtual Storage& resize(long new_size) = 0; + + virtual thd::Type type() const = 0; +}; + +template +struct StorageScalarInterface : public Storage { + using scalar_type = real; + virtual StorageScalarInterface& fill(scalar_type value) = 0; +}; + +using FloatStorage = StorageScalarInterface; +using IntStorage = StorageScalarInterface; + +} // namespace thd + diff --git a/torch/lib/THD/base/TensorTraits.hpp b/torch/lib/THD/base/TensorTraits.hpp deleted file mode 100644 index d97d47ff0ecf2c..00000000000000 --- a/torch/lib/THD/base/TensorTraits.hpp +++ /dev/null @@ -1,127 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include "master_worker/master/THDTensor.h" -#include "Tensor.hpp" -#include "Type.hpp" - -namespace thd { - -template -struct tensor_interface_traits { - using scalar_type = typename std::conditional< - std::is_floating_point::value, - double, - long long>::type; - using interface_type = TensorScalarInterface; -}; - -template -struct tensor_type_traits {}; - -template<> -struct tensor_type_traits { - 
static constexpr Type type = Type::CHAR; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::UCHAR; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::FLOAT; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::DOUBLE; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::SHORT; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::USHORT; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::INT; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::UINT; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::LONG; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::ULONG; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::LONG_LONG; -}; - -template<> -struct tensor_type_traits { - static constexpr Type type = Type::ULONG_LONG; -}; - -template -struct or_trait : std::false_type {}; - -template -struct or_trait : T {}; - -template -struct or_trait - : std::conditional>::type {}; - -template -struct is_any_of : std::false_type {}; - -template -struct is_any_of> : std::is_same {}; - -template -struct is_any_of> - : or_trait, is_any_of>> {}; - -using THDTensorTypes = std::tuple< - THDByteTensor, - THDCharTensor, - THDShortTensor, - THDIntTensor, - THDLongTensor, - THDFloatTensor, - THDDoubleTensor ->; - -template -struct map_to_ptr {}; - -template -struct map_to_ptr> { - using type = std::tuple::type...>; -}; - -using THDTensorPtrTypes = map_to_ptr::type; - -} // namespace thd diff --git a/torch/lib/THD/base/Traits.hpp b/torch/lib/THD/base/Traits.hpp new file mode 100644 index 00000000000000..826dcf00451c3e --- /dev/null +++ b/torch/lib/THD/base/Traits.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "Storage.hpp" +#include "Tensor.hpp" +#include "Type.hpp" + +namespace thd { + +template +struct interface_traits { + using scalar_type = typename std::conditional< + std::is_floating_point::value, + double, + long long>::type; + using tensor_interface_type = TensorScalarInterface; + using storage_interface_type = StorageScalarInterface; +}; + +template +struct type_traits {}; + +template<> +struct type_traits { + static constexpr Type type = Type::CHAR; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::UCHAR; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::FLOAT; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::DOUBLE; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::SHORT; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::USHORT; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::INT; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::UINT; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::LONG; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::ULONG; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::LONG_LONG; +}; + +template<> +struct type_traits { + static constexpr Type type = Type::ULONG_LONG; +}; + + +} // namespace thd diff --git a/torch/lib/THD/base/storages/THStorage.cpp b/torch/lib/THD/base/storages/THStorage.cpp new file mode 100644 index 
00000000000000..39aa29ed31da17 --- /dev/null +++ b/torch/lib/THD/base/storages/THStorage.cpp @@ -0,0 +1,11 @@ +#include "THStorage.hpp" +#include "../Traits.hpp" + + +namespace thd { + +#include "generic/THStorage.cpp" +#include + +} // namespace thd + diff --git a/torch/lib/THD/base/storages/THStorage.hpp b/torch/lib/THD/base/storages/THStorage.hpp new file mode 100644 index 00000000000000..0b6270d92b7f96 --- /dev/null +++ b/torch/lib/THD/base/storages/THStorage.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include + +// We're defining THStorage as a custom class +#undef THStorage +#define THRealStorage TH_CONCAT_3(TH,Real,Storage) + +#include "../Storage.hpp" +#include "../Traits.hpp" + +namespace thd { + +template +struct th_storage_traits {}; + +#include "base/storages/generic/THStorage.hpp" +#include + + +template +struct THStorage : public interface_traits::storage_interface_type { +private: + using interface_type = typename interface_traits::storage_interface_type; +public: + using storage_type = typename th_storage_traits::storage_type; + using scalar_type = typename interface_type::scalar_type; + + THStorage(); + THStorage(storage_type *wrapped); + THStorage(std::size_t size); + virtual ~THStorage(); + + virtual std::size_t elementSize() const override; + virtual std::size_t size() const override; + virtual void* data() override; + virtual const void* data() const override; + virtual THStorage& retain() override; + virtual THStorage& free() override; + + virtual THStorage& resize(long new_size) override; + virtual THStorage& fill(scalar_type value) override; + + virtual thd::Type type() const override; + +protected: + storage_type *storage; +}; + +} // namespace thd + diff --git a/torch/lib/THD/base/storages/generic/THStorage.cpp b/torch/lib/THD/base/storages/generic/THStorage.cpp new file mode 100644 index 00000000000000..95a55dbef14131 --- /dev/null +++ b/torch/lib/THD/base/storages/generic/THStorage.cpp @@ -0,0 +1,69 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "base/storages/generic/THStorage.cpp" +#else + +template<> +THStorage::THStorage(): storage(THStorage_(new)()) {} + +template<> +THStorage::THStorage(storage_type* storage): storage(storage) {} + +template<> +THStorage::THStorage(std::size_t storage_size) + : storage(THStorage_(newWithSize)(storage_size)) {} + +template<> +THStorage::~THStorage() { + THStorage_(free)(storage); +} + +template<> +std::size_t THStorage::elementSize() const { + return sizeof(real); +} + +template<> +std::size_t THStorage::size() const { + return storage->size; +} + +template<> +void* THStorage::data() { + return storage->data; +} + +template<> +const void* THStorage::data() const { + return storage->data; +} + +template<> +auto THStorage::retain() -> THStorage& { + THStorage_(retain)(storage); + return *this; +} + +template<> +auto THStorage::free() -> THStorage& { + THStorage_(free)(storage); + return *this; +} + +template<> +auto THStorage::resize(long new_size) -> THStorage& { + THStorage_(resize)(storage, new_size); + return *this; +} + +template<> +auto THStorage::fill(scalar_type value) -> THStorage& { + THStorage_(fill)(storage, value); + return *this; +} + +template<> +thd::Type THStorage::type() const { + return thd::type_traits::type; +} + +#endif diff --git a/torch/lib/THD/base/storages/generic/THStorage.hpp b/torch/lib/THD/base/storages/generic/THStorage.hpp new file mode 100644 index 00000000000000..39523a1b44177d --- /dev/null +++ b/torch/lib/THD/base/storages/generic/THStorage.hpp @@ -0,0 +1,10 @@ +#ifndef TH_GENERIC_FILE 
+#define TH_GENERIC_FILE "base/storages/generic/THStorage.hpp" +#else + +template<> +struct th_storage_traits<real> { + using storage_type = THRealStorage; +}; + +#endif
diff --git a/torch/lib/THD/base/tensors/THTensor.cpp b/torch/lib/THD/base/tensors/THTensor.cpp index 87382189d80ab5..7a7e4d29a294ea 100644 --- a/torch/lib/THD/base/tensors/THTensor.cpp +++ b/torch/lib/THD/base/tensors/THTensor.cpp @@ -1,5 +1,5 @@ #include "THTensor.hpp" -#include "../TensorTraits.hpp" +#include "../Traits.hpp" namespace thd {
diff --git a/torch/lib/THD/base/tensors/THTensor.hpp b/torch/lib/THD/base/tensors/THTensor.hpp index 218d92a7d1d5d1..7f651799b622c8 100644 --- a/torch/lib/THD/base/tensors/THTensor.hpp +++ b/torch/lib/THD/base/tensors/THTensor.hpp @@ -7,23 +7,23 @@ #define THRealTensor TH_CONCAT_3(TH,Real,Tensor) #include "../Tensor.hpp" -#include "../TensorTraits.hpp" +#include "../Traits.hpp" namespace thd { template<typename real> -struct th_traits {}; +struct th_tensor_traits {}; #include "base/tensors/generic/THTensor.hpp" #include <TH/THGenerateAllTypes.h> template<typename real> -struct THTensor : public tensor_interface_traits<real>::interface_type { +struct THTensor : public interface_traits<real>::tensor_interface_type { private: - using interface_type = typename tensor_interface_traits<real>::interface_type; + using interface_type = typename interface_traits<real>::tensor_interface_type; public: - using tensor_type = typename th_traits<real>::tensor_type; + using tensor_type = typename th_tensor_traits<real>::tensor_type; using scalar_type = typename interface_type::scalar_type; using long_range = Tensor::long_range;
diff --git a/torch/lib/THD/base/tensors/generic/THTensor.cpp b/torch/lib/THD/base/tensors/generic/THTensor.cpp index dc3438f9537995..ecbbe3359dc2ff 100644 --- a/torch/lib/THD/base/tensors/generic/THTensor.cpp +++ b/torch/lib/THD/base/tensors/generic/THTensor.cpp @@ -128,7 +128,7 @@ auto THTensor<real>::add(const Tensor &source, scalar_type value) -> THTensor<real>& { template<> thd::Type THTensor<real>::type() const { - return thd::tensor_type_traits<real>::type; + return thd::type_traits<real>::type; } #endif
diff --git a/torch/lib/THD/base/tensors/generic/THTensor.hpp b/torch/lib/THD/base/tensors/generic/THTensor.hpp index 8a8ca779fca339..3a3aa1fd51b89b 100644 --- a/torch/lib/THD/base/tensors/generic/THTensor.hpp +++ b/torch/lib/THD/base/tensors/generic/THTensor.hpp @@ -3,7 +3,7 @@ #else template<> -struct th_traits<real> { +struct th_tensor_traits<real> { using tensor_type = THRealTensor; };
diff --git a/torch/lib/THD/master_worker/common/CommandChannel.cpp b/torch/lib/THD/master_worker/common/CommandChannel.cpp index f0c36616c6b915..074c153cf97737 100644 --- a/torch/lib/THD/master_worker/common/CommandChannel.cpp +++ b/torch/lib/THD/master_worker/common/CommandChannel.cpp @@ -27,9 +27,9 @@ void sendMessage(std::unique_ptr<rpc::RPCMessage> msg, zmq::socket_t& socket) { msg.release())); } -std::unique_ptr<rpc::RPCMessage> recvMessage(zmq::socket_t& socket) { +std::unique_ptr<rpc::RPCMessage> recvMessage(zmq::socket_t& socket, bool block = true) { zmq::message_t zmsg; - if (socket.recv(&zmsg, ZMQ_DONTWAIT) == false) { + if (socket.recv(&zmsg, block ? 0 : ZMQ_DONTWAIT) == false) { return nullptr; } else { // XXX: Excessive copying here! I'm not sure how to avoid it. @@ -139,7 +139,7 @@ void MasterCommandChannel::sendMessage(std::unique_ptr<rpc::RPCMessage> msg, } std::unique_ptr<rpc::RPCMessage> MasterCommandChannel::recvMessage(int rank) { - return thd::recvMessage(_pull_sockets.at(rank)); + return thd::recvMessage(_pull_sockets.at(rank), false); } // TODO: Validate this environmental variable. 
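The recvMessage change above is the core of the new polling scheme: the worker blocks on its command socket (flags == 0), while the master scans its pull sockets without waiting (ZMQ_DONTWAIT). A minimal standalone sketch of that flag logic, assuming the same pre-4.x cppzmq recv(message_t*, int) overload the patched file itself calls; tryRecv is a hypothetical helper, not part of the patch:

#include <zmq.hpp>

// flags == 0 blocks until a message arrives; ZMQ_DONTWAIT makes recv()
// return false right away when the queue is empty (zmq reports EAGAIN).
bool tryRecv(zmq::socket_t& socket, zmq::message_t& zmsg, bool block) {
  return socket.recv(&zmsg, block ? 0 : ZMQ_DONTWAIT);
}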
diff --git a/torch/lib/THD/master_worker/common/Functions.hpp b/torch/lib/THD/master_worker/common/Functions.hpp index 45e513af018e85..4c4c894d64f229 100644 --- a/torch/lib/THD/master_worker/common/Functions.hpp +++ b/torch/lib/THD/master_worker/common/Functions.hpp @@ -6,9 +6,36 @@ namespace thd { enum Functions: std::uint16_t { construct, - constructWithSize, + constructWithSize, + free, + resize, + resizeAs, + resize1d, + resize2d, + resize3d, + resize4d, + resize5d, + set, + setStorage, + setStorage1d, + setStorage2d, + setStorage3d, + setStorage4d, + narrow, + select, add, - free + fill, + + // storage functions + storageConstruct, + storageConstructWithSize, + storageFree, + storageResize, + storageSet, + + // communication requests + sendTensor, + sendStorage, }; } // namespace thd diff --git a/torch/lib/THD/master_worker/common/RPC-inl.hpp b/torch/lib/THD/master_worker/common/RPC-inl.hpp index 2fd6d2319b8e6a..962a3e2203678b 100644 --- a/torch/lib/THD/master_worker/common/RPC-inl.hpp +++ b/torch/lib/THD/master_worker/common/RPC-inl.hpp @@ -1,6 +1,7 @@ #include #include "TH/THStorage.h" -#include "base/TensorTraits.hpp" +#include "base/Traits.hpp" +#include "Traits.hpp" namespace thd { namespace rpc { namespace detail { //////////////////////////////////////////////////////////////////////////////// @@ -19,15 +20,24 @@ inline void _appendType(ByteArray& str, Type _type) { } template -inline void __appendData(ByteArray& str, const T& arg, std::false_type is_tensor) { - _appendType(str, tensor_type_traits::type); +inline void __appendData(ByteArray& str, const T& arg, + std::false_type is_tensor, std::false_type is_storage) { + _appendType(str, type_traits::type); _appendScalar(str, arg); } template -inline void __appendData(ByteArray& str, const T& arg, std::true_type is_tensor) { +inline void __appendData(ByteArray& str, const T& arg, + std::true_type is_tensor, std::false_type is_storage) { _appendType(str, Type::TENSOR); - _appendScalar(str, arg->tensor_id); + _appendScalar(str, arg->tensor_id); +} + +template +inline void __appendData(ByteArray& str, const T& arg, + std::false_type is_tensor, std::true_type is_storage) { + _appendType(str, Type::STORAGE); + _appendScalar(str, arg->storage_id); } template @@ -35,7 +45,8 @@ inline void _appendData(ByteArray& str, const T& arg) { __appendData( str, arg, - is_any_of() + is_any_of(), + is_any_of() ); } diff --git a/torch/lib/THD/master_worker/common/RPC.cpp b/torch/lib/THD/master_worker/common/RPC.cpp index 29069ad487c35b..d815e12ebf8d2d 100644 --- a/torch/lib/THD/master_worker/common/RPC.cpp +++ b/torch/lib/THD/master_worker/common/RPC.cpp @@ -104,21 +104,22 @@ long long unpackInteger(RPCMessage& raw_message) { else if (type == Type::LONG_LONG) return unpackScalar(raw_message); - throw std::invalid_argument("wrong integer type in the raw message"); + throw std::invalid_argument(std::string("wrong integer type in the raw message (") + + std::to_string(static_cast(type)) + ")"); } -Tensor *unpackTensor(RPCMessage& raw_message) { +object_id_type unpackTensor(RPCMessage& raw_message) { Type type = unpackType(raw_message); if (type == Type::TENSOR) - return NULL; //unpackScalar(raw_message); TODO + return unpackScalar(raw_message); throw std::invalid_argument("expected tensor in the raw message"); } -tensor_id_type unpackTensorAsId(RPCMessage& raw_message) { +object_id_type unpackStorage(RPCMessage& raw_message) { Type type = unpackType(raw_message); - if (type == Type::TENSOR) - return unpackScalar(raw_message); - throw 
std::invalid_argument("expected tensor in the raw message"); + if (type == Type::STORAGE) + return unpackScalar(raw_message); + throw std::invalid_argument("expected storage in the raw message"); } THLongStorage* unpackTHLongStorage(RPCMessage& raw_message) { diff --git a/torch/lib/THD/master_worker/common/RPC.hpp b/torch/lib/THD/master_worker/common/RPC.hpp index 9da1484fa655db..dc0ca8b3982657 100644 --- a/torch/lib/THD/master_worker/common/RPC.hpp +++ b/torch/lib/THD/master_worker/common/RPC.hpp @@ -1,6 +1,7 @@ #pragma once #include "../../base/Tensor.hpp" +#include "../../base/Storage.hpp" #include "../master/THDTensor.h" #include "ByteArray.hpp" #include "TH/THStorage.h" @@ -11,7 +12,7 @@ namespace thd { -using tensor_id_type = std::uint64_t; +using object_id_type = std::uint64_t; namespace rpc { @@ -47,8 +48,8 @@ Type unpackType(RPCMessage& raw_message); double unpackFloat(RPCMessage& raw_message); std::uint16_t unpackFunctionId(RPCMessage& raw_message); long long unpackInteger(RPCMessage& raw_message); -Tensor* unpackTensor(RPCMessage& raw_message); -tensor_id_type unpackTensorAsId(RPCMessage& raw_message); +object_id_type unpackTensor(RPCMessage& raw_message); +object_id_type unpackStorage(RPCMessage& raw_message); THLongStorage* unpackTHLongStorage(RPCMessage& raw_message); }} // namespace rpc, thd diff --git a/torch/lib/THD/master_worker/common/Traits.hpp b/torch/lib/THD/master_worker/common/Traits.hpp new file mode 100644 index 00000000000000..3e97e9408ebdc0 --- /dev/null +++ b/torch/lib/THD/master_worker/common/Traits.hpp @@ -0,0 +1,60 @@ +#include +#include + +#include "master_worker/master/THDTensor.h" +#include "master_worker/master/THDStorage.h" + +namespace thd { + +template +struct or_trait : std::false_type {}; + +template +struct or_trait : T {}; + +template +struct or_trait + : std::conditional>::type {}; + +template +struct is_any_of : std::false_type {}; + +template +struct is_any_of> : std::is_same {}; + +template +struct is_any_of> + : or_trait, is_any_of>> {}; + +using THDTensorTypes = std::tuple< + THDByteTensor, + THDCharTensor, + THDShortTensor, + THDIntTensor, + THDLongTensor, + THDFloatTensor, + THDDoubleTensor +>; + +using THDStorageTypes = std::tuple< + THDByteStorage, + THDCharStorage, + THDShortStorage, + THDIntStorage, + THDLongStorage, + THDFloatStorage, + THDDoubleStorage +>; + +template +struct map_to_ptr {}; + +template +struct map_to_ptr> { + using type = std::tuple::type...>; +}; + +using THDTensorPtrTypes = map_to_ptr::type; +using THDStoragePtrTypes = map_to_ptr::type; + +} // namespace thd diff --git a/torch/lib/THD/master_worker/master/Master.cpp b/torch/lib/THD/master_worker/master/Master.cpp index 6f02c637c0d3b5..6ecc8a9bcadb05 100644 --- a/torch/lib/THD/master_worker/master/Master.cpp +++ b/torch/lib/THD/master_worker/master/Master.cpp @@ -25,6 +25,7 @@ bool THDMasterWorkerInit(THDChannelType channel_type) { } // TODO: initialize master + thd::master::masterCommandChannel.reset(new MasterCommandChannel()); return true; } diff --git a/torch/lib/THD/master_worker/master/THDStorage.cpp b/torch/lib/THD/master_worker/master/THDStorage.cpp new file mode 100644 index 00000000000000..6be74716356fa0 --- /dev/null +++ b/torch/lib/THD/master_worker/master/THDStorage.cpp @@ -0,0 +1,12 @@ +#include "THD.h" +#include "base/Traits.hpp" +#include "State.hpp" +#include "master_worker/common/RPC.hpp" +#include "master_worker/common/Functions.hpp" +#include "master_worker/master/Master.hpp" + +#include + +#include "master_worker/master/generic/THDStorage.cpp" 
+#include "TH/THGenerateAllTypes.h" + diff --git a/torch/lib/THD/master_worker/master/THDTensor.cpp b/torch/lib/THD/master_worker/master/THDTensor.cpp index 17db39c3d52954..6f3c0b02a02081 100644 --- a/torch/lib/THD/master_worker/master/THDTensor.cpp +++ b/torch/lib/THD/master_worker/master/THDTensor.cpp @@ -1,5 +1,5 @@ #include "THDTensor.h" -#include "base/TensorTraits.hpp" +#include "base/Traits.hpp" #include "State.hpp" #include "master_worker/common/RPC.hpp" #include "master_worker/common/Functions.hpp" diff --git a/torch/lib/THD/master_worker/master/generic/THDStorage.cpp b/torch/lib/THD/master_worker/master/generic/THDStorage.cpp index d3315fce2f5a15..328d775ed1d7db 100644 --- a/torch/lib/THD/master_worker/master/generic/THDStorage.cpp +++ b/torch/lib/THD/master_worker/master/generic/THDStorage.cpp @@ -7,51 +7,103 @@ using namespace rpc; using namespace master; static THDStorage* THDStorage_(_alloc)() { - THDStorage* new_tensor = new THDStorage(); - new_tensor->storage_id = THDState::s_nextId++; - new_tensor->node_id = THDState::s_current_worker; - return new_tensor; + THDStorage* new_storage = new THDStorage(); + std::memset(reinterpret_cast(new_storage), 0, sizeof(new_storage)); + new_storage->refcount = 1; + new_storage->storage_id = THDState::s_nextId++; + new_storage->node_id = THDState::s_current_worker; + new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + return new_storage; } THDStorage* THDStorage_(new)() { - THDStorage* tensor = THDStorage_(_alloc)(); - std::unique_ptr construct_message = packMessage( - Functions::construct, - static_cast(tensor_type_traits::type), - *tensor - ); + THDStorage* storage = THDStorage_(_alloc)(); + Type type = type_traits::type; masterCommandChannel->sendMessage( - std::move(construct_message), - THDState::s_current_worker + packMessage( + Functions::storageConstruct, + type, + storage + ), + THDState::s_current_worker ); - return tensor; + return storage; } -THDStorage* THDStorage_(newWithSize)(THLongStorage *sizes, THLongStorage *strides) { - THDStorage* tensor = THDStorage_(_alloc)(); - std::unique_ptr construct_message = packMessage( - Functions::constructWithSize, - static_cast(tensor_type_traits::type), - *tensor, - sizes, - strides - ); +void THDStorage_(resize)(THDStorage *storage, ptrdiff_t size) +{ + if(!(storage->flag & TH_STORAGE_RESIZABLE)) + THError("Trying to resize storage that is not resizable"); + + storage->size = size; masterCommandChannel->sendMessage( - std::move(construct_message), - THDState::s_current_worker + packMessage( + Functions::storageResize, + storage + ), + THDState::s_current_worker ); - return tensor; } -void THDStorage_(free)(THDStorage *tensor) { - std::unique_ptr free_message = packMessage( - Functions::free, - tensor->tensor_id +void THDStorage_(free)(THDStorage *storage) +{ + if(!storage || !(storage->flag & TH_STORAGE_REFCOUNTED)) return; + + if (THAtomicDecrementRef(&storage->refcount)) { + masterCommandChannel->sendMessage( + packMessage( + Functions::storageFree, + storage + ), + THDState::s_current_worker + ); + + if(storage->flag & TH_STORAGE_VIEW) + THDStorage_(free)(storage->view); + delete storage; + } +} + +void THDStorage_(retain)(THDStorage *storage) { + if(storage && (storage->flag & TH_STORAGE_REFCOUNTED)) + THAtomicIncrementRef(&storage->refcount); +} + +ptrdiff_t THDStorage_(size)(const THDStorage* storage) { + return storage->size; +} + +THDStorage* THDStorage_(newWithSize)(ptrdiff_t size) { + Type type = type_traits::type; + THDStorage *storage = 
THDStorage_(_alloc)(); + storage->size = size; + masterCommandChannel->sendMessage( + packMessage( + Functions::storageConstructWithSize, + type, + storage, + size + ), + THDState::s_current_worker ); + return storage; +} + +void THDStorage_(set)(THDStorage* storage, ptrdiff_t offset, real value) { masterCommandChannel->sendMessage( - std::move(free_message), - THDState::s_current_worker + packMessage( + Functions::storageSet, + storage, + offset, + value + ), + THDState::s_current_worker ); } +real THDStorage_(get)(const THDStorage* storage, ptrdiff_t offset) { + THError("get not supported yet"); + return 0; +} + #endif diff --git a/torch/lib/THD/master_worker/master/generic/THDStorage.h b/torch/lib/THD/master_worker/master/generic/THDStorage.h index 54f552d2f42f75..e6698ad737bd5c 100644 --- a/torch/lib/THD/master_worker/master/generic/THDStorage.h +++ b/torch/lib/THD/master_worker/master/generic/THDStorage.h @@ -2,12 +2,16 @@ #define TH_GENERIC_FILE "master_worker/master/generic/THDStorage.h" #else -typedef struct { +typedef struct THDStorage { unsigned long long storage_id; ptrdiff_t size; int refcount; char flag; - // TODO: what about allocators? + // these are here only so that the struct has a similar structure to TH + void* allocator; + void* allocatorContext; + struct THDStorage *view; + // Additional fields int node_id; int device_id; // unused at the moment } THDStorage; diff --git a/torch/lib/THD/master_worker/master/generic/THDTensor.cpp b/torch/lib/THD/master_worker/master/generic/THDTensor.cpp index 6a1f8931a39dac..1c20c90f8953cf 100644 --- a/torch/lib/THD/master_worker/master/generic/THDTensor.cpp +++ b/torch/lib/THD/master_worker/master/generic/THDTensor.cpp @@ -6,16 +6,72 @@ using namespace thd; using namespace rpc; using namespace master; +THDStorage *THDTensor_(storage)(const THDTensor *self) +{ + return self->storage; +} + +ptrdiff_t THDTensor_(storageOffset)(const THDTensor *self) +{ + return self->storageOffset; +} + +int THDTensor_(nDimension)(const THDTensor *self) +{ + return self->nDimension; +} + +long THDTensor_(size)(const THDTensor *self, int dim) +{ + THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "dimension %d out of range of %dD tensor", + dim+1, THDTensor_(nDimension)(self)); + return self->size[dim]; +} + +long THDTensor_(stride)(const THDTensor *self, int dim) +{ + THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "dimension %d out of range of %dD tensor", dim+1, + THDTensor_(nDimension)(self)); + return self->stride[dim]; +} + +THLongStorage *THDTensor_(newSizeOf)(THDTensor *self) +{ + THLongStorage *size = THLongStorage_newWithSize(self->nDimension); + THLongStorage_rawCopy(size, self->size); + return size; +} + +THLongStorage *THDTensor_(newStrideOf)(THDTensor *self) +{ + THLongStorage *stride = THLongStorage_newWithSize(self->nDimension); + THLongStorage_rawCopy(stride, self->stride); + return stride; +} + +void THDTensor_(setFlag)(THDTensor *self, const char flag) +{ + self->flag |= flag; +} + +void THDTensor_(clearFlag)(THDTensor *self, const char flag) +{ + self->flag &= ~flag; +} + static THDTensor* THDTensor_(_alloc)() { THDTensor* new_tensor = new THDTensor(); + std::memset(reinterpret_cast(new_tensor), 0, sizeof(THDTensor)); new_tensor->tensor_id = THDState::s_nextId++; new_tensor->refcount = 1; + new_tensor->flag = TH_TENSOR_REFCOUNTED; + // TODO: allocate storage return new_tensor; } THDTensor* THDTensor_(new)() { THDTensor* tensor = THDTensor_(_alloc)(); - Type constructed_type = tensor_type_traits::type; + Type constructed_type 
= type_traits::type; masterCommandChannel->sendMessage( packMessage( Functions::construct, @@ -29,7 +85,7 @@ THDTensor* THDTensor_(new)() { THDTensor* THDTensor_(newWithSize)(THLongStorage *sizes, THLongStorage *strides) { THDTensor* tensor = THDTensor_(_alloc)(); - Type constructed_type = tensor_type_traits::type; + Type constructed_type = type_traits::type; masterCommandChannel->sendMessage( packMessage( Functions::constructWithSize, @@ -43,16 +99,476 @@ THDTensor* THDTensor_(newWithSize)(THLongStorage *sizes, THLongStorage *strides) return tensor; } +// taken from TH (generic/THTensor.c) +static void THDTensor_(_resize)(THDTensor *self, int nDimension, long *size, long *stride) +{ + int d; + int nDimension_; + ptrdiff_t totalSize; + int hascorrectsize = 1; + + nDimension_ = 0; + for(d = 0; d < nDimension; d++) + { + if(size[d] > 0) + { + nDimension_++; + if((self->nDimension > d) && (size[d] != self->size[d])) + hascorrectsize = 0; + + if((self->nDimension > d) && stride && (stride[d] >= 0) && (stride[d] != self->stride[d])) + hascorrectsize = 0; + } + else + break; + } + nDimension = nDimension_; + + if(nDimension != self->nDimension) + hascorrectsize = 0; + + if(hascorrectsize) + return; + + if(nDimension > 0) + { + if(nDimension != self->nDimension) + { + self->size = reinterpret_cast(THRealloc(self->size, sizeof(long)*nDimension)); + self->stride = reinterpret_cast(THRealloc(self->stride, sizeof(long)*nDimension)); + self->nDimension = nDimension; + } + + totalSize = 1; + for(d = self->nDimension-1; d >= 0; d--) + { + self->size[d] = size[d]; + if(stride && (stride[d] >= 0) ) + self->stride[d] = stride[d]; + else + { + if(d == self->nDimension-1) + self->stride[d] = 1; + else + self->stride[d] = self->size[d+1]*self->stride[d+1]; + } + totalSize += (self->size[d]-1)*self->stride[d]; + } + + if(totalSize + self->storageOffset > 0) + { + if(!self->storage) + self->storage = THDStorage_(new)(); + if(totalSize+self->storageOffset > self->storage->size) + THDStorage_(resize)(self->storage, totalSize+self->storageOffset); + } + } + else + self->nDimension = 0; +} + +void THDTensor_(resize)(THDTensor *tensor, THLongStorage *size, THLongStorage *stride) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resize, + tensor, + size, + stride + ), + THDState::s_current_worker + ); + THDTensor_(_resize)(tensor, size->size, size->data, stride ? 
stride->data : nullptr); +} + +void THDTensor_(resizeAs)(THDTensor *tensor, THDTensor *src) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resizeAs, + tensor, + src + ), + THDState::s_current_worker + ); + THDTensor_(_resize)(tensor, src->nDimension, src->size, nullptr); +} + +void THDTensor_(resize1d)(THDTensor *tensor, long size0) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resize1d, + tensor, + size0 + ), + THDState::s_current_worker + ); + THDTensor_(_resize)(tensor, 1, &size0, nullptr); +} + +void THDTensor_(resize2d)(THDTensor *tensor, long size0, long size1) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resize2d, + tensor, + size0, + size1 + ), + THDState::s_current_worker + ); + long sizes[] = {size0, size1}; + THDTensor_(_resize)(tensor, 2, sizes, nullptr); +} + +void THDTensor_(resize3d)(THDTensor *tensor, long size0, long size1, long size2) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resize3d, + tensor, + size0, + size1, + size2 + ), + THDState::s_current_worker + ); + long sizes[] = {size0, size1, size2}; + THDTensor_(_resize)(tensor, 3, sizes, nullptr); +} + +void THDTensor_(resize4d)(THDTensor *tensor, long size0, long size1, long size2, long size3) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resize4d, + tensor, + size0, + size1, + size2, + size3 + ), + THDState::s_current_worker + ); + long sizes[] = {size0, size1, size2, size3}; + THDTensor_(_resize)(tensor, 4, sizes, nullptr); +} + +void THDTensor_(resize5d)(THDTensor *tensor, long size0, long size1, long size2, long size3, long size4_) { + masterCommandChannel->sendMessage( + packMessage( + Functions::resize5d, + tensor, + size0, + size1, + size2, + size3, + size4_ + ), + THDState::s_current_worker + ); + long sizes[] = {size0, size1, size2, size3, size4_}; + THDTensor_(_resize)(tensor, 5, sizes, nullptr); +} + +static void THDTensor_(_set)(THDTensor *self, THDStorage *storage, + ptrdiff_t storageOffset, int nDimension, long *size, long *stride) +{ + /* storage */ + if(self->storage != storage) + { + if(self->storage) + THDStorage_(free)(self->storage); + + if(storage) + { + self->storage = storage; + THDStorage_(retain)(self->storage); + } + else + self->storage = NULL; + } + + /* storageOffset */ + if(storageOffset < 0) + THError("can't set negative storage offset"); + self->storageOffset = storageOffset; + + /* size and stride */ + THDTensor_(_resize)(self, nDimension, size, stride); +} + +void THDTensor_(set)(THDTensor *self, THDTensor *src) { + if (self == src) + return; + + masterCommandChannel->sendMessage( + packMessage( + Functions::set, + self, + src + ), + THDState::s_current_worker + ); + THDTensor_(_set)(self, src->storage, src->storageOffset, + src->nDimension, src->size, src->stride); +} + +void THDTensor_(setStorage)(THDTensor *self, THDStorage *storage, + ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride) { + masterCommandChannel->sendMessage( + packMessage( + Functions::setStorage, + self, + storage, + storageOffset, + size, + stride + ), + THDState::s_current_worker + ); + if (size && stride) + THArgCheck(size->size == stride->size, 5, "inconsistent number of sizes and strides"); + + THDTensor_(_set)( + self, + storage, + storageOffset, + (size ? size->size : (stride ? stride->size : 0)), + (size ? size->data : nullptr), + (stride ? 
stride->data : nullptr) + ); +} + +void THDTensor_(setStorage1d)(THDTensor *self, THDStorage *storage, ptrdiff_t storageOffset, + long size0, long stride0) { + masterCommandChannel->sendMessage( + packMessage( + Functions::setStorage1d, + self, + storage, + storageOffset, + size0, + stride0 + ), + THDState::s_current_worker + ); + long size[] = {size0}; + long stride[] = {stride0}; + THDTensor_(_set)( + self, + storage, + storageOffset, + 1, + size, + stride + ); +} + +void THDTensor_(setStorage2d)(THDTensor *self, THDStorage *storage, ptrdiff_t storageOffset, + long size0, long stride0, + long size1, long stride1) { + masterCommandChannel->sendMessage( + packMessage( + Functions::setStorage2d, + self, + storage, + storageOffset, + size0, + size1, + stride0, + stride1 + ), + THDState::s_current_worker + ); + long size[] = {size0, size1}; + long stride[] = {stride0, stride1}; + THDTensor_(_set)( + self, + storage, + storageOffset, + 2, + size, + stride + ); +} + +void THDTensor_(setStorage3d)(THDTensor *self, THDStorage *storage, ptrdiff_t storageOffset, + long size0, long stride0, + long size1, long stride1, + long size2, long stride2) { + masterCommandChannel->sendMessage( + packMessage( + Functions::setStorage3d, + self, + storage, + storageOffset, + size0, + size1, + size2, + stride0, + stride1, + stride2 + ), + THDState::s_current_worker + ); + long size[] = {size0, size1, size2}; + long stride[] = {stride0, stride1, stride2}; + THDTensor_(_set)( + self, + storage, + storageOffset, + 3, + size, + stride + ); +} +void THDTensor_(setStorage4d)(THDTensor *self, THDStorage *storage, ptrdiff_t storageOffset, + long size0, long stride0, + long size1, long stride1, + long size2, long stride2, + long size3, long stride3) { + masterCommandChannel->sendMessage( + packMessage( + Functions::setStorage4d, + self, + storage, + storageOffset, + size0, + size1, + size2, + size3, + stride0, + stride1, + stride2, + stride3 + ), + THDState::s_current_worker + ); + long size[] = {size0, size1, size2, size3}; + long stride[] = {stride0, stride1, stride2, stride3}; + THDTensor_(_set)( + self, + storage, + storageOffset, + 4, + size, + stride + ); +} + + void THDTensor_(free)(THDTensor *tensor) { // TODO: refcount masterCommandChannel->sendMessage( packMessage( Functions::free, - tensor->tensor_id + tensor ), THDState::s_current_worker ); } +ptrdiff_t THDTensor_(nElement)(const THDTensor *self) +{ + if(self->nDimension == 0) return 0; + + ptrdiff_t nElement = 1; + int d; + for(d = 0; d < self->nDimension; d++) + nElement *= self->size[d]; + return nElement; +} + +void THDTensor_(narrow)(THDTensor *self, THDTensor *src, int dimension, + long firstIndex, long size) { + if(!src) src = self; + + THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range"); + THArgCheck((firstIndex >= 0) && (firstIndex < src->size[dimension]), 3, "out of range"); + THArgCheck((size > 0) && (firstIndex <= src->size[dimension] - size), 4, "out of range"); + + THDTensor_(set)(self, src); + + if(firstIndex > 0) + self->storageOffset += firstIndex*self->stride[dimension]; + + self->size[dimension] = size; + + masterCommandChannel->sendMessage( + packMessage( + Functions::narrow, + self, + src, + dimension, + firstIndex, + size + ), + THDState::s_current_worker + ); +} + +void THDTensor_(select)(THDTensor *self, THDTensor *src, int dimension, + long sliceIndex) { + if(!src) + src = self; + + THArgCheck(src->nDimension > 1, 1, "cannot select on a vector"); + THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range"); + THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range"); + + THDTensor_(set)(self, src); + THDTensor_(narrow)(self, NULL, dimension, sliceIndex, 1); + for(int d = dimension; d < self->nDimension-1; d++) { + self->size[d] = self->size[d+1]; + self->stride[d] = self->stride[d+1]; + } + self->nDimension--; + + masterCommandChannel->sendMessage( + packMessage( + Functions::select, + self, + src, + dimension, + sliceIndex + ), + THDState::s_current_worker + ); +}
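For reference, THDTensor_(_resize) above derives default strides for the stride == nullptr case the same way TH does: the innermost stride is 1 and each outer stride is the product of the inner size and inner stride, while totalSize accumulates the number of elements the view spans. A standalone illustration of that rule (the main function is illustrative, not THD code):

#include <cstdio>

int main() {
  long size[] = {2, 3, 4};
  long stride[3];
  const int nDim = 3;
  long totalSize = 1;  // mirrors totalSize += (size[d]-1)*stride[d] above
  for (int d = nDim - 1; d >= 0; --d) {
    stride[d] = (d == nDim - 1) ? 1 : size[d + 1] * stride[d + 1];
    totalSize += (size[d] - 1) * stride[d];
  }
  std::printf("strides: %ld %ld %ld, elements spanned: %ld\n",
              stride[0], stride[1], stride[2], totalSize);  // 12 4 1, 24
  return 0;
}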
src->nDimension), 2, "out of range"); + THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range"); + + THDTensor_(set)(self, src); + THDTensor_(narrow)(self, NULL, dimension, sliceIndex, 1); + for(int d = dimension; d < self->nDimension-1; d++) { + self->size[d] = self->size[d+1]; + self->stride[d] = self->stride[d+1]; + } + self->nDimension--; + + masterCommandChannel->sendMessage( + packMessage( + Functions::select, + self, + src, + dimension, + sliceIndex + ), + THDState::s_current_worker + ); +} + +THDTensor *THDTensor_(newWithStorage1d)(THDStorage *storage_, + ptrdiff_t storageOffset_, long size0_, long stride0_) { + THError("newWithStorage1d not supported yet"); + return nullptr; +} + +THDTensor *THDTensor_(newWithTensor)(THDTensor *tensor) { + THError("newWithTensor not supported yet"); + return nullptr; +} + + +void THDTensor_(fill)(THDTensor *tensor, real value) { + masterCommandChannel->sendMessage( + packMessage( + Functions::fill, + tensor, + value + ), + THDState::s_current_worker + ); +} + +void THDTensor_(zeros)(THDTensor *tensor, THLongStorage *size) { + THDTensor_(resize)(tensor, size, nullptr); + THDTensor_(fill)(tensor, 0); +} + +void THDTensor_(ones)(THDTensor *tensor, THLongStorage *size) { + THDTensor_(resize)(tensor, size, nullptr); + THDTensor_(fill)(tensor, 0); +} + +ptrdiff_t THDTensor_(numel)(THDTensor *t) { + return THDTensor_(nElement)(t); +} + + #endif diff --git a/torch/lib/THD/master_worker/master/generic/THDTensor.h b/torch/lib/THD/master_worker/master/generic/THDTensor.h index 70543ed0b1077d..d0efec16d986c8 100644 --- a/torch/lib/THD/master_worker/master/generic/THDTensor.h +++ b/torch/lib/THD/master_worker/master/generic/THDTensor.h @@ -3,8 +3,6 @@ #else typedef struct { - unsigned long long tensor_id; - long *size; long *stride; int nDimension; @@ -14,6 +12,9 @@ typedef struct { int refcount; char flag; + + // Additional fields + unsigned long long tensor_id; } THDTensor; /**** access methods ****/ @@ -24,7 +25,6 @@ THD_API long THDTensor_(size)(const THDTensor *self, int dim); THD_API long THDTensor_(stride)(const THDTensor *self, int dim); THD_API THLongStorage *THDTensor_(newSizeOf)(THDTensor *self); THD_API THLongStorage *THDTensor_(newStrideOf)(THDTensor *self); -THD_API real *THDTensor_(data)(const THDTensor *self); THD_API void THDTensor_(setFlag)(THDTensor *self, const char flag); THD_API void THDTensor_(clearFlag)(THDTensor *self, const char flag); diff --git a/torch/lib/THD/master_worker/worker/Dispatch.cpp b/torch/lib/THD/master_worker/worker/Dispatch.cpp index e698ae0d7951e6..2763b3e55c4923 100644 --- a/torch/lib/THD/master_worker/worker/Dispatch.cpp +++ b/torch/lib/THD/master_worker/worker/Dispatch.cpp @@ -6,7 +6,10 @@ #include #include -#include "../../base/TensorTraits.hpp" +#include "../../process_group/General.hpp" +#include "../../base/Tensor.hpp" +#include "../../base/Traits.hpp" +#include "../../base/storages/THStorage.hpp" #include "../../base/tensors/THTensor.hpp" #include "../common/Functions.hpp" #include "../common/RPC.hpp" @@ -18,6 +21,14 @@ namespace worker { namespace detail { +Tensor* unpackRetrieveTensor(rpc::RPCMessage& message) { + return workerTensors.at(unpackTensor(message)).get(); +} + +Storage* unpackRetrieveStorage(rpc::RPCMessage& message) { + return workerStorages.at(unpackStorage(message)).get(); +} + static std::unique_ptr createTensor(Type type) { if (type == Type::UCHAR) return std::unique_ptr(new THTensor()); @@ -36,37 +47,40 @@ static std::unique_ptr createTensor(Type type) { 
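The narrow/select pair above only manipulates view metadata on the master before notifying the worker: narrow shifts the storage offset and shrinks one size, and select additionally drops the selected dimension. A compact standalone model of that arithmetic (View, narrow and select here are illustrative names, not the THD API):

#include <cstdio>
#include <vector>

struct View { long offset; std::vector<long> size, stride; };

View narrow(View v, int dim, long first, long len) {
  v.offset += first * v.stride[dim];  // skip `first` slices along dim
  v.size[dim] = len;
  return v;
}

View select(View v, int dim, long index) {
  v = narrow(v, dim, index, 1);
  v.size.erase(v.size.begin() + dim);     // drop the selected dimension
  v.stride.erase(v.stride.begin() + dim);
  return v;
}

int main() {
  View v{0, {4, 5}, {5, 1}};              // 4x5 row-major matrix
  View row2 = select(v, 0, 2);            // offset 10, size {5}, stride {1}
  std::printf("offset=%ld ndim=%zu\n", row2.offset, row2.size.size());
  return 0;
}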
diff --git a/torch/lib/THD/master_worker/worker/Dispatch.cpp b/torch/lib/THD/master_worker/worker/Dispatch.cpp index e698ae0d7951e6..2763b3e55c4923 100644 --- a/torch/lib/THD/master_worker/worker/Dispatch.cpp +++ b/torch/lib/THD/master_worker/worker/Dispatch.cpp @@ -6,7 +6,10 @@ #include #include -#include "../../base/TensorTraits.hpp" +#include "../../process_group/General.hpp" +#include "../../base/Tensor.hpp" +#include "../../base/Traits.hpp" +#include "../../base/storages/THStorage.hpp" #include "../../base/tensors/THTensor.hpp" #include "../common/Functions.hpp" #include "../common/RPC.hpp" @@ -18,6 +21,14 @@ namespace worker { namespace detail { +Tensor* unpackRetrieveTensor(rpc::RPCMessage& message) { + return workerTensors.at(unpackTensor(message)).get(); +} + +Storage* unpackRetrieveStorage(rpc::RPCMessage& message) { + return workerStorages.at(unpackStorage(message)).get(); +} + static std::unique_ptr<Tensor> createTensor(Type type) { if (type == Type::UCHAR) return std::unique_ptr<Tensor>(new THTensor<unsigned char>()); @@ -36,37 +47,40 @@ static std::unique_ptr<Tensor> createTensor(Type type) { throw std::invalid_argument("passed character doesn't represent a tensor type"); } -static void construct(rpc::RPCMessage& raw_message) { - // TODO: assert_empty(raw_message) - Type type = rpc::unpackType(raw_message); - thd::tensor_id_type id = rpc::unpackTensorAsId(raw_message); - workerTensors.insert(std::make_pair( - id, - createTensor(type) - )); +static std::unique_ptr<Storage> createStorage(Type type) { + if (type == Type::UCHAR) + return std::unique_ptr<Storage>(new THStorage<unsigned char>()); + else if (type == Type::CHAR) + return std::unique_ptr<Storage>(new THStorage<char>()); + else if (type == Type::SHORT) + return std::unique_ptr<Storage>(new THStorage<short>()); + else if (type == Type::INT) + return std::unique_ptr<Storage>(new THStorage<int>()); + else if (type == Type::LONG) + return std::unique_ptr<Storage>(new THStorage<long>()); + else if (type == Type::FLOAT) + return std::unique_ptr<Storage>(new THStorage<float>()); + else if (type == Type::DOUBLE) + return std::unique_ptr<Storage>(new THStorage<double>()); + throw std::invalid_argument("passed character doesn't represent a storage type"); } -static void constructWithSize(rpc::RPCMessage& raw_message) { - // TODO: assert_empty(raw_message) - Type type = rpc::unpackType(raw_message); - tensor_id_type id = rpc::unpackTensorAsId(raw_message); - THLongStorage *sizes = rpc::unpackTHLongStorage(raw_message); - THLongStorage *strides = rpc::unpackTHLongStorage(raw_message); +static std::unique_ptr<Storage> createStorage(Type type, std::size_t size) { + std::unique_ptr<Storage> storage = createStorage(type); + storage->resize(size); + return storage; } -static void add(rpc::RPCMessage& raw_message) { -//THTensor& result = parse_tensor(raw_message); - //THTensor& source = parse_tensor(raw_message); - //double x = parse_scalar(raw_message); - //assert_end(raw_message); - //result.add(source, x); -} -static void free(rpc::RPCMessage& raw_message) { - unsigned long long tensor_id = unpackInteger(raw_message); - (void)workerTensors.erase(tensor_id); +static void finalize(rpc::RPCMessage& raw_message) { + if (raw_message.remaining() > 0) + throw std::invalid_argument("message is too long"); } +#include "dispatch/Storage.cpp" +#include "dispatch/Tensor.cpp" +#include "dispatch/Communication.cpp" + using dispatch_fn = void (*)(rpc::RPCMessage&); using Functions = thd::Functions; @@ -75,7 +89,12 @@ static const std::unordered_map<std::uint16_t, dispatch_fn> functions { {Functions::construct, construct}, {Functions::constructWithSize, constructWithSize}, {Functions::add, add}, - {Functions::free, free} + {Functions::free, free}, + {Functions::storageConstruct, storageConstruct}, + {Functions::storageConstructWithSize, storageConstructWithSize}, + {Functions::storageFree, storageFree}, + {Functions::sendTensor, sendTensor}, + {Functions::sendStorage, sendStorage} }; } // namespace detail @@ -89,7 +108,7 @@ std::string execute(std::unique_ptr<rpc::RPCMessage> raw_message_ptr) { if (iter != detail::functions.end()) (*iter->second)(raw_message); else - throw std::invalid_argument("invalid function id"); + throw std::invalid_argument(std::string("invalid function id: ") + std::to_string(fid)); return std::string(); } catch(std::exception& e) { return std::string(e.what());
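execute() above resolves a std::uint16_t function id against a static handler table and now reports the offending id in the error message. A minimal standalone model of that dispatch pattern (Message, ping, table and dispatch are placeholder names, not the THD implementation):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Message { /* opaque payload */ };
using handler = void (*)(Message&);

static void ping(Message&) {}  // stand-in for construct, free, etc.

static const std::unordered_map<std::uint16_t, handler> table{
  {0, ping},
};

void dispatch(std::uint16_t fid, Message& msg) {
  auto it = table.find(fid);
  if (it == table.end())  // unknown id: report which one, as execute() does
    throw std::invalid_argument("invalid function id: " + std::to_string(fid));
  it->second(msg);
}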
"Dispatch.hpp" #include "Worker.h" #include "Worker.hpp" @@ -8,18 +10,23 @@ namespace thd { namespace worker { std::unique_ptr workerCommandChannel; -std::unordered_map> workerTensors; +std::unordered_map> workerTensors; +std::unordered_map> workerStorages; } // namespace worker } // namespace thd +using namespace thd; + void THDWorkerMain() { // TODO: initialize worker - thd::worker::workerCommandChannel = - std::unique_ptr(new thd::WorkerCommandChannel(1)); + worker::workerCommandChannel.reset( + new WorkerCommandChannel(dataChannel->getRank())); std::unique_ptr command; for (;;) { - command = thd::worker::workerCommandChannel->recvMessage(); - thd::worker::execute(std::move(command)); + command = worker::workerCommandChannel->recvMessage(); + auto msg = worker::execute(std::move(command)); + if (msg != "") + fprintf(stderr, "WORKER %d: %s\n", (int)dataChannel->getRank(), msg.c_str()); } } diff --git a/torch/lib/THD/master_worker/worker/Worker.hpp b/torch/lib/THD/master_worker/worker/Worker.hpp index 7dae7efa815f69..847600bf58848d 100644 --- a/torch/lib/THD/master_worker/worker/Worker.hpp +++ b/torch/lib/THD/master_worker/worker/Worker.hpp @@ -1,11 +1,13 @@ #pragma once #include "../common/CommandChannel.hpp" +#include "../../base/Storage.hpp" #include "../../base/Tensor.hpp" #include namespace thd { namespace worker { extern std::unique_ptr workerCommandChannel; -extern std::unordered_map> workerTensors; +extern std::unordered_map> workerTensors; +extern std::unordered_map> workerStorages; }} // namespace worker, thd diff --git a/torch/lib/THD/master_worker/worker/dispatch/Communication.cpp b/torch/lib/THD/master_worker/worker/dispatch/Communication.cpp new file mode 100644 index 00000000000000..95dbfa5e5211c2 --- /dev/null +++ b/torch/lib/THD/master_worker/worker/dispatch/Communication.cpp @@ -0,0 +1,14 @@ + +static void sendTensor(rpc::RPCMessage& raw_message) { + Tensor *tensor = unpackRetrieveTensor(raw_message); + int dst_rank = unpackInteger(raw_message); + finalize(raw_message); + dataChannel->send(*tensor, dst_rank); +} + +static void sendStorage(rpc::RPCMessage& raw_message) { + Storage *storage = unpackRetrieveStorage(raw_message); + int dst_rank = unpackInteger(raw_message); + finalize(raw_message); + fprintf(stderr, "sending storage (to be implemented)\n"); +} diff --git a/torch/lib/THD/master_worker/worker/dispatch/Storage.cpp b/torch/lib/THD/master_worker/worker/dispatch/Storage.cpp new file mode 100644 index 00000000000000..d0d5f53fa1bd5b --- /dev/null +++ b/torch/lib/THD/master_worker/worker/dispatch/Storage.cpp @@ -0,0 +1,27 @@ + +static void storageConstruct(rpc::RPCMessage& raw_message) { + Type storage_type = unpackType(raw_message); + object_id_type storage_id = unpackStorage(raw_message); + finalize(raw_message); + workerStorages.emplace( + storage_id, + createStorage(storage_type) + ); +} + +static void storageConstructWithSize(rpc::RPCMessage& raw_message) { + Type storage_type = unpackType(raw_message); + object_id_type storage_id = unpackStorage(raw_message); + long long size = unpackInteger(raw_message); + finalize(raw_message); + workerStorages.emplace( + storage_id, + createStorage(storage_type, size) + ); +} + +static void storageFree(rpc::RPCMessage& raw_message) { + object_id_type storage_id = unpackStorage(raw_message); + workerTensors.erase(storage_id); +} + diff --git a/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp b/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp new file mode 100644 index 00000000000000..95a0111376fe31 --- /dev/null 
diff --git a/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp b/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp new file mode 100644 index 00000000000000..95a0111376fe31 --- /dev/null +++ b/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp @@ -0,0 +1,32 @@ + +static void construct(rpc::RPCMessage& raw_message) { + // TODO: assert_empty(raw_message) + Type type = rpc::unpackType(raw_message); + thd::object_id_type id = rpc::unpackTensor(raw_message); + workerTensors.emplace( + id, + createTensor(type) + ); +} + +static void constructWithSize(rpc::RPCMessage& raw_message) { + // TODO: assert_empty(raw_message) + Type type = rpc::unpackType(raw_message); + object_id_type id = rpc::unpackTensor(raw_message); + THLongStorage *sizes = rpc::unpackTHLongStorage(raw_message); + THLongStorage *strides = rpc::unpackTHLongStorage(raw_message); +} + +static void add(rpc::RPCMessage& raw_message) { +//THTensor& result = parse_tensor(raw_message); + //THTensor& source = parse_tensor(raw_message); + //double x = parse_scalar(raw_message); + //assert_end(raw_message); + //result.add(source, x); +} + +static void free(rpc::RPCMessage& raw_message) { + object_id_type tensor_id = unpackTensor(raw_message); + (void)workerTensors.erase(tensor_id); +}
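Across the patch, RPC arguments travel as a Type tag followed by the raw scalar bytes, with tensors and storages reduced to their uint64 ids (see _appendData in RPC-inl.hpp and unpackTensor/unpackStorage in RPC.cpp). A standalone sketch of that framing; the concrete tag values and helper names below are invented for illustration:

#include <cstdint>
#include <cstring>
#include <string>

enum class Type : char { CHAR = 'c', TENSOR = 'T', STORAGE = 'S' };  // tag values assumed

template <typename T>
void appendScalar(std::string& buf, T value) {
  // raw little-or-big-endian bytes of the scalar, as on the sending host
  char bytes[sizeof(T)];
  std::memcpy(bytes, &value, sizeof(T));
  buf.append(bytes, sizeof(T));
}

void appendTensorId(std::string& buf, std::uint64_t tensor_id) {
  buf.push_back(static_cast<char>(Type::TENSOR));  // one-byte tag first
  appendScalar(buf, tensor_id);                    // then the object id
}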