Commit 45020a7

remove inplace pow and fix contiguous -> coalesce (pytorch#1398)
soumith authored Apr 28, 2017
1 parent 9c01f5d commit 45020a7
Showing 4 changed files with 40 additions and 57 deletions.
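In user-facing terms, sparse tensors lose the in-place pow_ method and keep the out-of-place pow, which the updated test below exercises. A minimal sketch of the visible change, assuming a 2017-era PyTorch build with sparse support; the indices, values, and shape are made-up example data:

    import torch

    # Hypothetical example data: a 2x3 sparse tensor with three nonzeros.
    i = torch.LongTensor([[0, 1, 1],
                          [2, 0, 2]])
    v = torch.FloatTensor([3, 4, 5])
    x = torch.sparse.FloatTensor(i, v, torch.Size([2, 3]))

    y = x.pow(2)   # out-of-place pow is still supported
    # x.pow_(2)    # in-place pow_ is removed by this commit (TODO: add back)
    print(y.to_dense())

The rest of the diff removes the generated pow_ binding and reworks both pow kernels to coalesce their input first.
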
3 changes: 2 additions & 1 deletion test/test_sparse.py
@@ -478,9 +478,10 @@ def _test_basic_ops_shape(self, is_cuda, shape_i, shape_v=None):
         self.assertEqual(y1.to_dense(), expected)
         self.assertEqual(y2.to_dense(), expected)

+        # TODO: add back inplace support
         y1 = x1 ** 2
         y2 = x1.clone()
-        y2.pow_(2)
+        y2 = y2.pow(2)
         expected = x1.to_dense() ** 2
         self.assertEqual(y1.to_dense(), expected)
         self.assertEqual(y2.to_dense(), expected)
12 changes: 0 additions & 12 deletions torch/csrc/generic/methods/SparseTensor.cwrap
@@ -413,18 +413,6 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
     - real value
 ]]

-[[
-  name: pow_
-  defined_if: defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || CUDA_FLOAT || CUDA_DOUBLE
-  sparse: yes
-  cname: pow
-  return: argument 0
-  arguments:
-    - THSTensor* self
-    - THSTensor* self
-    - real value
-]]
-
 [[
   name: sparse_mask
   cname: sparseMask
39 changes: 17 additions & 22 deletions torch/lib/THCS/generic/THCSTensorMath.cu
@@ -460,34 +460,29 @@ void THCSTensor_(cmul)(THCState *state, THCSTensor *r_, THCSTensor *t_, THCSTens
 }

 #if defined(THCS_REAL_IS_FLOAT) || defined(THCS_REAL_IS_DOUBLE)
-void THCSTensor_(pow)(THCState *state, THCSTensor *r_, THCSTensor *t, real value) {
+void THCSTensor_(pow)(THCState *state, THCSTensor *r_, THCSTensor *t_, real value) {
   if (value == 0) {
     THError("cannot raise to zeroth power on sparse tensor");
   }
-  THCSTensor_(contiguous)(state, t);
-  if (r_ == t) {
-    THCTensor *r_values_ = THCSTensor_(newValues)(state, r_);
-    THCTensor_(pow)(state, r_values_, r_values_, value);
-    THCTensor_(free)(state, r_values_);
-  } else {
-    THCSTensor_(resizeAs)(state, r_, t);
+  THCSTensor *t = THCSTensor_(newCoalesce)(state, t_);
+  THCSTensor_(resizeAs)(state, r_, t);

-    THCIndexTensor *r_indices_ = THCSTensor_(newIndices)(state, r_);
-    THCTensor *r_values_ = THCSTensor_(newValues)(state, r_);
-    THCIndexTensor *t_indices_ = THCSTensor_(newIndices)(state, t);
-    THCTensor *t_values_ = THCSTensor_(newValues)(state, t);
+  THCIndexTensor *r_indices_ = THCSTensor_(newIndices)(state, r_);
+  THCTensor *r_values_ = THCSTensor_(newValues)(state, r_);
+  THCIndexTensor *t_indices_ = THCSTensor_(newIndices)(state, t);
+  THCTensor *t_values_ = THCSTensor_(newValues)(state, t);

-    THCIndexTensor_(resizeAs)(state, r_indices_, t_indices_);
-    THCIndexTensor_(copy)(state, r_indices_, t_indices_);
-    THCTensor_(pow)(state, r_values_, t_values_, value);
-    r_->nnz = t->nnz;
-    r_->contiguous = t->contiguous;
+  THCIndexTensor_(resizeAs)(state, r_indices_, t_indices_);
+  THCIndexTensor_(copy)(state, r_indices_, t_indices_);
+  THCTensor_(pow)(state, r_values_, t_values_, value);
+  r_->nnz = t->nnz;
+  r_->coalesced = t->coalesced;

-    THCIndexTensor_(free)(state, r_indices_);
-    THCTensor_(free)(state, r_values_);
-    THCIndexTensor_(free)(state, t_indices_);
-    THCTensor_(free)(state, t_values_);
-  }
+  THCIndexTensor_(free)(state, r_indices_);
+  THCTensor_(free)(state, r_values_);
+  THCIndexTensor_(free)(state, t_indices_);
+  THCTensor_(free)(state, t_values_);
+  THCSTensor_(free)(state, t);
 }
 #endif
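Both the CUDA kernel above and the CPU kernel below keep the guard against value == 0: raising a sparse tensor to the zeroth power would turn every implicit zero into a 1 and densify the result, so it is rejected. A sketch of the expected failure mode (the data is made up, and mapping THError to a Python RuntimeError is an assumption):

    import torch

    # Hypothetical 1x1 sparse tensor with a single nonzero.
    i = torch.LongTensor([[0], [0]])
    v = torch.FloatTensor([3])
    x = torch.sparse.FloatTensor(i, v, torch.Size([1, 1]))

    try:
        x.pow(0)
    except RuntimeError as e:
        print(e)  # cannot raise to zeroth power on sparse tensor
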
43 changes: 21 additions & 22 deletions torch/lib/THS/generic/THSTensorMath.c
@@ -42,35 +42,34 @@ void THSTensor_(mul)(THSTensor *r_, THSTensor *t, real value) {
 }

 /* floating point only, because that is what TH supports */
+/* TODO: add in-place support */
 #if defined(THS_REAL_IS_FLOAT) || defined(THS_REAL_IS_DOUBLE)
-void THSTensor_(pow)(THSTensor *r_, THSTensor *t, real value) {
+void THSTensor_(pow)(THSTensor *r_, THSTensor *t_, real value) {
   if (value == 0) {
     THError("cannot raise to zeroth power on sparse tensor");
   }
-  THSTensor_(contiguous)(t);
-  if (r_ == t) {
-    THTensor *r_values_ = THSTensor_(newValues)(r_);
-    THTensor_(pow)(r_values_, r_values_, value);
-    THTensor_(free)(r_values_);
-  } else {
-    THSTensor_(resizeAs)(r_, t);
+  THSTensor* t = THSTensor_(newCoalesce)(t_);
+  THSTensor_(resizeAs)(r_, t);

-    THLongTensor *r_indices_ = THSTensor_(newIndices)(r_);
-    THTensor *r_values_ = THSTensor_(newValues)(r_);
-    THLongTensor *t_indices_ = THSTensor_(newIndices)(t);
-    THTensor *t_values_ = THSTensor_(newValues)(t);
+  THLongTensor *r_indices_ = THSTensor_(newIndices)(r_);
+  THTensor *r_values_ = THSTensor_(newValues)(r_);
+  THLongTensor *t_indices_ = THSTensor_(newIndices)(t);
+  THTensor *t_values_ = THSTensor_(newValues)(t);

-    THLongTensor_resizeAs(r_indices_, t_indices_);
-    THLongTensor_copy(r_indices_, t_indices_);
-    THTensor_(pow)(r_values_, t_values_, value);
-    r_->nnz = t->nnz;
-    r_->contiguous = t->contiguous;
+  THLongTensor_resizeAs(r_indices_, t_indices_);
+  THLongTensor_copy(r_indices_, t_indices_);
+  THTensor_(pow)(r_values_, t_values_, value);
+  r_->nnz = t->nnz;
+  r_->coalesced = t->coalesced;

-    THLongTensor_free(r_indices_);
-    THTensor_(free)(r_values_);
-    THLongTensor_free(t_indices_);
-    THTensor_(free)(t_values_);
-  }
+  THLongTensor_free(r_indices_);
+  THTensor_(free)(r_values_);
+  THLongTensor_free(t_indices_);
+  THTensor_(free)(t_values_);
+
+  THSTensor_(free)(t);
 }
 #endif
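The contiguous -> coalesce half of the commit matters for correctness: a sparse tensor can carry duplicate index entries, and squaring the raw values is not the same as squaring their coalesced sums, since (a + b)^2 != a^2 + b^2 in general. That is why both kernels now call newCoalesce on the input before applying pow to the values tensor. A small illustration under the same assumed 2017-era sparse API; the data is made up:

    import torch

    # Hypothetical duplicate-index input: entry (0, 0) is listed twice (2 and 3),
    # so the tensor logically holds 5 at that position.
    i = torch.LongTensor([[0, 0],
                          [0, 0]])
    v = torch.FloatTensor([2, 3])
    x = torch.sparse.FloatTensor(i, v, torch.Size([1, 1]))

    # Squaring the raw values would give 2**2 + 3**2 = 13; the dense-equivalent
    # answer is (2 + 3)**2 = 25, which coalescing first produces.
    print(x.pow(2).to_dense())  # expected: 25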
