Skip to content

Commit

Permalink
revert sparse cuda index type change
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Raison authored and soumith committed Apr 18, 2017
1 parent 88b4232 commit 01d84c5
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 38 deletions.
2 changes: 1 addition & 1 deletion test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
type_triplets = [cpu_triplet]
if torch.cuda.is_available():
cuda_triplet = (
torch.cuda.IntTensor,
torch.cuda.LongTensor,
torch.cuda.DoubleTensor,
torch.cuda.sparse.DoubleTensor)
type_triplets.append(cuda_triplet)
Expand Down
2 changes: 1 addition & 1 deletion torch/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _cuda(self, device=None, async=False):
with torch.cuda.device(device):
if self.is_sparse:
new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
indices = self.indices().cuda(device, async).int()
indices = self.indices().cuda(device, async)
values = self.values().cuda(device, async)
return new_type(indices, values, self.size())
else:
Expand Down
6 changes: 3 additions & 3 deletions torch/csrc/generic/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ static void THSPTensor_(dealloc)(THSPTensor* self)
static PyObject * THSPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
#ifdef THC_GENERIC_FILE
#define THPIndexTensor_Check THCPIntTensor_Check
#define THPIndexTensor THCPIntTensor
#define THIndexTensor THCudaIntTensor
#define THPIndexTensor_Check THCPLongTensor_Check
#define THPIndexTensor THCPLongTensor
#define THIndexTensor THCudaLongTensor
#else
#define THPIndexTensor_Check THPLongTensor_Check
#define THPIndexTensor THPLongTensor
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/generic/TensorMethods.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
#define THIndexTensor THCudaLongTensor
#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME)
#define THPIndexTensor THCPLongTensor
#define THPIndexTensor_(NAME) TH_CONCAT_2(THCPLongTensor_,NAME)
#define THPIndexTensorClass THCPLongTensorClass
#else
#define THIndexTensor THLongTensor
#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME)
#define THPIndexTensor THPLongTensor
#define THPIndexTensor_(NAME) TH_CONCAT_2(THPLongTensor_,NAME)
#define THPIndexTensorClass THPLongTensorClass
#endif

Expand Down
12 changes: 1 addition & 11 deletions torch/csrc/generic/methods/SparseTensor.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,8 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)

[[
name: indices
defined_if: "!IS_CUDA"
sparse: yes
return: THLongTensor*
arguments:
- THSTensor* self
]]

[[
name: indices
defined_if: "IS_CUDA"
sparse: yes
return: THCudaIntTensor*
return: THIndexTensor*
arguments:
- THSTensor* self
]]
Expand Down
7 changes: 3 additions & 4 deletions torch/lib/THCS/THCSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
#define THCSTensor TH_CONCAT_3(THCS,Real,Tensor)
#define THCSTensor_(NAME) TH_CONCAT_4(THCS,Real,Tensor_,NAME)

// Using int for indices because that's what cuSparse uses...
#define THCIndexTensor THCudaIntTensor
#define THCIndexTensor_(NAME) THCudaIntTensor_ ## NAME
#define integer int
#define THCIndexTensor THCudaLongTensor
#define THCIndexTensor_(NAME) THCudaLongTensor_ ## NAME
#define integer long

#include "generic/THCSTensor.h"
#include "THCSGenerateAllTypes.h"
Expand Down
9 changes: 2 additions & 7 deletions torch/lib/THCS/generic/THCSTensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ THCSTensor *THCSTensor_(newWithTensorAndSize)(THCState *state, THCIndexTensor *i
// TODO make sure this doesn't sync the hell out of everything
// Should be fine according to sam's memory manager.
computed_sizes = THLongTensor_newWithSize(THCIndexTensor_(newSizeOf)(state, s), NULL);
THLongTensor_copyCudaInt(state, computed_sizes, s);
THLongTensor_copyCudaLong(state, computed_sizes, s);
THCSTensor_(rawResize)(state, self, nDimI, nDimV, THLongTensor_data(computed_sizes));

THCIndexTensor_(free)(state, s);
Expand Down Expand Up @@ -424,14 +424,10 @@ void THCTensor_(sparseMask)(THCState *state, THCSTensor *r_, THCTensor *t, THCST
THCudaLongTensor *indices = THCudaLongTensor_newWithSize1d(state, mask->nnz);
THCudaLongTensor *indicesBuffer = THCudaLongTensor_new(state);

// FIXME remove after fixing CUDA index type
THCudaLongTensor *maskIndicesLong = THCudaLongTensor_newWithSize2d(state, maskIndices->size[0], maskIndices->size[1]);
THCudaLongTensor_copyCudaInt(state, maskIndicesLong, maskIndices);

THCudaLongTensor_zero(state, indices);
for (long d = 0; d < mask->nDimensionI; d++) {
THCudaLongTensor_mul(state, indices, indices, mask->size[d]);
THCudaLongTensor_select(state, indicesBuffer, maskIndicesLong, 0, d);
THCudaLongTensor_select(state, indicesBuffer, maskIndices, 0, d);
THCudaLongTensor_cadd(state, indices, indices, 1, indicesBuffer);
}
THLongStorage *viewSize = THLongStorage_newWithSize(1 + mask->nDimensionV);
Expand All @@ -442,7 +438,6 @@ void THCTensor_(sparseMask)(THCState *state, THCSTensor *r_, THCTensor *t, THCST
THCTensor *t_view = THCTensor_(newView)(state, t, viewSize);
THCTensor_(indexSelect)(state, rValues, t_view, 0, indices);

THCudaLongTensor_free(state, maskIndicesLong);
THCudaLongTensor_free(state, indices);
THCudaLongTensor_free(state, indicesBuffer);
THLongStorage_free(viewSize);
Expand Down
34 changes: 23 additions & 11 deletions torch/lib/THCS/generic/THCSTensorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
#define I_INFO(tensor) getTensorInfo<THCIndexTensor, unsigned long>(state, tensor)
#define V_INFO(tensor) getTensorInfo<THCTensor, unsigned long>(state, tensor)

THCIndexTensor *THCSTensor_(toCSR)(THCState *state, integer const *indices, long dim, long nnz) {
THCIndexTensor *csr = THCIndexTensor_(newWithSize1d)(state, dim + 1);
THCudaSparse_Xcoo2csr(state, indices, nnz, dim, THCIndexTensor_(data)(state, csr));
// Convert the COO row-index vector of a sparse tensor into CSR row-pointer
// form on the GPU. Indices arrive as a long tensor (rowIndices) but are
// copied into a temporary int tensor first — presumably because the cuSparse
// Xcoo2csr routine operates on 32-bit indices (NOTE(review): confirm against
// the cuSparse API; the copy via THCudaIntTensor_copyCudaLong narrows
// long -> int, so nnz/dim values beyond INT_MAX would overflow).
//
// Ownership: returns a freshly allocated THCudaIntTensor of size dim + 1;
// the caller is responsible for freeing it. The temporary rowIndicesInt is
// freed here before returning.
THCudaIntTensor *THCSTensor_(toCSR)(THCState *state, THCIndexTensor *rowIndices, long dim, long nnz) {
// Allocate the CSR row-pointer array: one entry per row plus the trailing
// total-nnz slot, hence dim + 1.
THCudaIntTensor *csr = THCudaIntTensor_newWithSize1d(state, dim + 1);
// Temporary int copy of the long row indices for the cuSparse call.
THCudaIntTensor *rowIndicesInt = THCudaIntTensor_newWithSize1d(state, rowIndices->size[0]);
THCudaIntTensor_copyCudaLong(state, rowIndicesInt, rowIndices);
// Compress the COO row list (length nnz) into CSR offsets (length dim + 1).
THCudaSparse_Xcoo2csr(
state, THCudaIntTensor_data(state, rowIndicesInt), nnz, dim, THCudaIntTensor_data(state, csr));
THCudaIntTensor_free(state, rowIndicesInt);
return csr;
}

Expand All @@ -35,7 +39,8 @@ void THCTensor_(spaddcdiv)(THCState *state, THCTensor *r_, THCTensor *t, real va
void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCSTensor *sparse, THCTensor *dense) {
#if defined(THCS_REAL_IS_FLOAT) || defined(THCS_REAL_IS_DOUBLE)
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 1, 4, sparse, r_, t, dense));
THCIndexTensor *csr, *indices;
THCudaIntTensor *csr;
THCIndexTensor *indices;
THCTensor *values, *r__, *dense_;

THArgCheck(sparse->nDimensionI == 2, 2,
Expand All @@ -62,9 +67,14 @@ void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *
indices = THCSTensor_(indices)(state, sparse);
values = THCSTensor_(values)(state, sparse);

csr = THCSTensor_(toCSR)(state, THCIndexTensor_(data)(state, indices), m, nnz);
THCIndexTensor *colindices = THCIndexTensor_(new)(state);
THCIndexTensor_(select)(state, colindices, indices, 0, 1);
THCIndexTensor *rowIndices = THCIndexTensor_(new)(state);
THCIndexTensor *colIndices = THCIndexTensor_(new)(state);
THCIndexTensor_(select)(state, rowIndices, indices, 0, 0);
THCIndexTensor_(select)(state, colIndices, indices, 0, 1);
csr = THCSTensor_(toCSR)(state, rowIndices, m, nnz);
THCudaIntTensor *colIndicesInt = THCudaIntTensor_newWithSize1d(state, colIndices->size[0]);
THCudaIntTensor_copyCudaLong(state, colIndicesInt, colIndices);


char transpose_dense;

Expand Down Expand Up @@ -109,8 +119,8 @@ void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *
nnz,
alpha,
THCTensor_(data)(state, values),
THCIndexTensor_(data)(state, csr),
THCIndexTensor_(data)(state, colindices),
THCudaIntTensor_data(state, csr),
THCudaIntTensor_data(state, colIndicesInt),
THCTensor_(data)(state, dense_),
(transpose_dense == 'n' ? dense_->stride[1] : dense_->stride[0]),
beta,
Expand All @@ -126,9 +136,11 @@ void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *
THCTensor_(freeCopyTo)(state, r__, r_);
}

THCIndexTensor_(free)(state, csr);
THCudaIntTensor_free(state, colIndicesInt);
THCudaIntTensor_free(state, csr);
THCIndexTensor_(free)(state, indices);
THCIndexTensor_(free)(state, colindices);
THCIndexTensor_(free)(state, rowIndices);
THCIndexTensor_(free)(state, colIndices);
THCTensor_(free)(state, values);
#else
THError("unimplemented data type");
Expand Down

0 comments on commit 01d84c5

Please sign in to comment.