Skip to content

Commit

Permalink
add DynamicType variants for ATen functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
zdevito authored and soumith committed Jul 11, 2017
1 parent 9d8cff9 commit 2ecb188
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
57 changes: 57 additions & 0 deletions torch/csrc/DynamicTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,53 @@ static std::unordered_map<std::string, Type> type_names = {
{"Int", Type::INT},
{"Long", Type::LONG},
};

// Maps the scalar-type name embedded in Python tensor type names
// (e.g. "Float" in torch.FloatTensor) to the ATen scalar type.
static std::unordered_map<std::string, at::ScalarType> attype_names = {
  {"Byte",   at::kByte},
  {"Char",   at::kChar},
  {"Short",  at::kShort},
  {"Int",    at::kInt},
  {"Long",   at::kLong},
  {"Half",   at::kHalf},
  {"Float",  at::kFloat},
  {"Double", at::kDouble},
};
// Bidirectional mapping between Python tensor type objects and the
// THPP TensorType descriptors; filled in by registerPyTypeObject.
static std::unordered_map<PyTypeObject*, TensorType> pytype_to_tensortype;
static std::unordered_map<TensorType, PyTypeObject*, TensorTypeHasher> tensortype_to_pytype;

// Parallel bidirectional mapping for ATen types, also filled in by
// registerPyTypeObject (except for sparse half, which ATen lacks).
static std::unordered_map<PyTypeObject*, at::Type*> pytype_to_attype;
static std::unordered_map<at::Type*, PyTypeObject*> attype_to_pytype;

// Registers the Python type object implementing the tensor type described
// by (name, is_cuda, is_sparse), wiring up the bidirectional lookup tables
// used by getPyTypeObject/createTensor below. The ATen tables are skipped
// for sparse half tensors, which have no ATen counterpart.
void registerPyTypeObject(PyTypeObject *pytype, const std::string& name, bool is_cuda, bool is_sparse)
{
  at::Backend backend;
  if (is_cuda) {
    backend = is_sparse ? at::kSparseCUDA : at::kCUDA;
  } else {
    backend = is_sparse ? at::kSparseCPU : at::kCPU;
  }

  TensorType ttype;
  ttype.data_type = type_names.at(name);
  ttype.is_cuda = is_cuda;
  ttype.is_sparse = is_sparse;

  pytype_to_tensortype[pytype] = ttype;
  tensortype_to_pytype[ttype] = pytype;

  // ATen has no sparse half type, so it gets no entry in the ATen maps.
  if (!(is_sparse && name == "Half")) {
    at::Type * attype = &at::getType(backend, attype_names.at(name));
    pytype_to_attype[pytype] = attype;
    attype_to_pytype[attype] = pytype;
  }
}

PyTypeObject* getPyTypeObject(const thpp::Tensor& tensor)
Expand All @@ -81,6 +116,12 @@ PyTypeObject* getPyTypeObject(const thpp::Tensor& tensor)

return tensortype_to_pytype.at(type);
}
// Returns the Python type object registered for `tensor`'s ATen type.
// Throws std::invalid_argument if no Python type was registered for it
// (e.g. sparse half tensors, which are never added to attype_to_pytype).
PyTypeObject* getPyTypeObject(const at::Tensor& tensor)
{
  // Single hash lookup instead of count() followed by at().
  auto it = attype_to_pytype.find(&tensor.type());
  if (it == attype_to_pytype.end())
    throw std::invalid_argument("unsupported Tensor type.");
  return it->second;
}

static std::unique_ptr<Tensor> createTensor(void *tensor, Type type, bool is_cuda, bool is_sparse)
{
Expand Down Expand Up @@ -167,6 +208,22 @@ std::unique_ptr<Tensor> createTensor(PyObject *data)
wrapper->retain();
return wrapper;
}
// rename to createTensor when THPP is removed
// Wraps the TH tensor held by the Python object `data` in an at::Tensor
// of the ATen type registered for the object's Python type. Throws
// std::out_of_range if that Python type was never registered.
at::Tensor createTensorAT(PyObject *data)
{
  at::Type * attype = pytype_to_attype.at(Py_TYPE(data));
  return attype->unsafeTensorFromTH(((THPVoidTensor *)data)->cdata);
}
// Creates a new Python tensor object wrapping `tensor`. Returns NULL on
// allocation failure (tp_alloc sets the Python error in that case);
// throws std::invalid_argument if tensor's type has no registered
// Python type (see getPyTypeObject above).
PyObject* createPyObject(at::Tensor tensor)
{
auto type = getPyTypeObject(tensor);
PyObject *obj = type->tp_alloc(type, 0);
if (obj) {
// NOTE(review): detach() here presumably releases ownership of the
// underlying TH tensor to the new Python object — confirm the refcount
// semantics; on the NULL path the by-value `tensor` keeps ownership.
((THPVoidTensor*)obj)->cdata = (THVoidTensor *)tensor.detach()->unsafeGetTH();
}
return obj;
}

PyObject* createPyObject(const thpp::Tensor& tensor)
{
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/DynamicTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <Python.h>
#include <memory>
#include <THPP/THPP.h>
#include <ATen/ATen.h>

namespace torch {

Expand All @@ -22,4 +23,9 @@ std::unique_ptr<thpp::Tensor> createTensor(PyObject *data);
// Creates Python tensor object from a Tensor
PyObject* createPyObject(const thpp::Tensor& tensor);

// Creates a Python tensor object from an at::Tensor
PyObject* createPyObject(at::Tensor tensor);
// Returns the Python type object matching tensor's ATen type
PyTypeObject* getPyTypeObject(const at::Tensor& tensor);
// rename to createTensor when THPP is removed (the .cpp comment agrees;
// the previous "createPyObject" here was a typo)
at::Tensor createTensorAT(PyObject *data);

} // namespace torch

0 comments on commit 2ecb188

Please sign in to comment.