use memory view for PyTorch tensor objects
yihming committed Jun 4, 2021
1 parent a9d6e2e commit 3c2a599
Showing 1 changed file with 8 additions and 6 deletions.
ext_modules/nnls_bpp_utils.pyx (8 additions, 6 deletions)
@@ -45,14 +45,16 @@ cpdef _nnls_bpp(float[:, :] CTC, float[:, :] CTB, float[:, :] X, str device_type
             Y[i, j] = -CTB[i, j]
             V[i, j] = Y[i, j] < 0
             Vsize[j] += V[i, j]
 
         if Vsize[j] > 0:
             I[size_I] = j
             size_I += 1
 
 
-    CTC_L = torch.zeros((q, q), dtype=torch.float, device=device_type)
-    CTB_L = torch.zeros((q, r), dtype=torch.float, device=device_type)
+    CTC_L_tensor = torch.zeros((q, q), dtype=torch.float, device=device_type)
+    cdef float[:, :] CTC_L = CTC_L_tensor.numpy()
+    CTB_L_tensor = torch.zeros((q, r), dtype=torch.float, device=device_type)
+    cdef float[:, :] CTB_L = CTB_L_tensor.numpy()
     cdef Py_ssize_t CTC_L_M, CTC_L_N, CTB_L_M, CTB_L_N
 
 
@@ -140,7 +142,7 @@ cpdef _nnls_bpp(float[:, :] CTC, float[:, :] CTB, float[:, :] X, str device_type
                 CTC_L_M += 1
         assert CTC_L_M == CTC_L_N, f"CTC_L of shape ({CTC_L_M}, {CTC_L_N}) is not square!"
 
-        L = torch.cholesky(CTC_L[0:CTC_L_M, 0:CTC_L_N])
+        L = torch.cholesky(CTC_L_tensor[0:CTC_L_M, 0:CTC_L_N])
 
         # CTB_L = CTB[Fvec, Ii]
         CTB_L_M = 0
@@ -153,7 +155,7 @@ cpdef _nnls_bpp(float[:, :] CTC, float[:, :] CTB, float[:, :] X, str device_type
                 CTB_L_M += 1
         assert CTB_L_M==CTC_L_M and CTB_L_N==size_Ii, f"CTB_L has shape ({CTB_L_M}, {CTB_L_N}), but expect ({CTC_L_M}, {size_Ii})."
 
-        x = torch.cholesky_solve(CTB_L[0:CTB_L_M, 0:CTB_L_N], L)
+        x = torch.cholesky_solve(CTB_L_tensor[0:CTB_L_M, 0:CTB_L_N], L)
 
         k = 0
         for i in range(q):
@@ -176,7 +178,7 @@ cpdef _nnls_bpp(float[:, :] CTC, float[:, :] CTB, float[:, :] X, str device_type
                     CTC_L_N += 1
                 CTC_L_M += 1
             assert CTC_L_M + CTC_L_N == q, "CTC_L has shape ({CTC_L_M}, {CTC_L_N}), but expect sum of dims = {q}!"
-            y_tensor = CTC_L[0:CTC_L_M, 0:CTC_L_N] @ x
+            y_tensor = CTC_L_tensor[0:CTC_L_M, 0:CTC_L_N] @ x
             y = y_tensor.cpu().numpy()
 
             k = 0
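The change hinges on the fact that, for a CPU tensor, torch.Tensor.numpy() returns a zero-copy NumPy view of the tensor's storage. Binding a Cython typed memoryview (cdef float[:, :]) to that view lets the tight Cython loops fill the buffer element-wise, while the original tensor handles (CTC_L_tensor, CTB_L_tensor) can still be sliced and passed to the torch solvers with no copying. Below is a minimal plain-Python sketch of the same idea, not the repository's code: shapes and fill values are made up, plain NumPy indexing stands in for the cdef memoryview, and torch.linalg.cholesky stands in for the torch.cholesky call used in the .pyx file.

import torch

q, r = 4, 3

# Allocate through PyTorch, then take a zero-copy NumPy view of the same buffer
# (CPU tensors only). In the .pyx file this view is bound to a typed memoryview:
#     cdef float[:, :] CTC_L = CTC_L_tensor.numpy()
CTC_L_tensor = torch.zeros((q, q), dtype=torch.float)
CTC_L = CTC_L_tensor.numpy()

CTB_L_tensor = torch.zeros((q, r), dtype=torch.float)
CTB_L = CTB_L_tensor.numpy()

# Fill the buffers element-wise through the views, as the Cython loops do.
# (Illustrative values: a symmetric positive-definite tridiagonal matrix.)
for i in range(q):
    CTC_L[i, i] = 2.0
    if i + 1 < q:
        CTC_L[i, i + 1] = CTC_L[i + 1, i] = -1.0
    for j in range(r):
        CTB_L[i, j] = 1.0

# The tensor handles see the writes immediately (same memory, no copy),
# so they can be sliced and handed straight to the torch routines.
L = torch.linalg.cholesky(CTC_L_tensor[0:q, 0:q])
x = torch.cholesky_solve(CTB_L_tensor[0:q, 0:r], L)
print(x)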
