forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Lang] Experimental sparse matrix support on CPUs (taichi-dev#2792)
Co-authored-by: Yuanming Hu <[email protected]> Co-authored-by: Ye Kuang <[email protected]>
- Loading branch information
1 parent
02e923b
commit fa45dbb
Showing
16 changed files
with
612 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import taichi as ti | ||
|
||
ti.init(arch=ti.x64) | ||
|
||
n = 8 | ||
|
||
K = ti.SparseMatrixBuilder(n, n, max_num_triplets=100) | ||
f = ti.SparseMatrixBuilder(n, 1, max_num_triplets=100) | ||
|
||
|
||
@ti.kernel | ||
def fill(A: ti.sparse_matrix_builder(), b: ti.sparse_matrix_builder(), | ||
interval: ti.i32): | ||
for i in range(n): | ||
if i > 0: | ||
A[i - 1, i] += -1.0 | ||
A[i, i] += 1 | ||
if i < n - 1: | ||
A[i + 1, i] += -1.0 | ||
A[i, i] += 1.0 | ||
|
||
if i % interval == 0: | ||
b[i, 0] += 1.0 | ||
|
||
|
||
fill(K, f, 3) | ||
|
||
print(">>>> K.print_triplets()") | ||
K.print_triplets() | ||
|
||
A = K.build() | ||
|
||
print(">>>> A = K.build()") | ||
print(A) | ||
|
||
print(">>>> Summation: C = A + A") | ||
C = A + A | ||
print(C) | ||
|
||
print(">>>> Subtraction: D = A - A") | ||
D = A - A | ||
print(D) | ||
|
||
print(">>>> Multiplication with a scalar on the right: E = A * 3.0") | ||
E = A * 3.0 | ||
print(E) | ||
|
||
print(">>>> Multiplication with a scalar on the left: E = 3.0 * A") | ||
E = 3.0 * A | ||
print(E) | ||
|
||
print(">>>> Transpose: F = A.transpose()") | ||
F = A.transpose() | ||
print(F) | ||
|
||
print(">>>> Matrix multiplication: G = E @ A") | ||
G = E @ A | ||
print(G) | ||
|
||
print(">>>> Element-wise multiplication: H = E * A") | ||
H = E * A | ||
print(H) | ||
|
||
print(f">>>> Element Access: A[0,0] = {A[0,0]}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
class SparseMatrix: | ||
def __init__(self, n=None, m=None, sm=None): | ||
if sm is None: | ||
self.n = n | ||
self.m = m if m else n | ||
from taichi.core.util import ti_core as _ti_core | ||
self.matrix = _ti_core.create_sparse_matrix(n, m) | ||
else: | ||
self.n = sm.num_rows() | ||
self.m = sm.num_cols() | ||
self.matrix = sm | ||
|
||
def __add__(self, other): | ||
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})" | ||
sm = self.matrix + other.matrix | ||
return SparseMatrix(sm=sm) | ||
|
||
def __sub__(self, other): | ||
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})" | ||
sm = self.matrix - other.matrix | ||
return SparseMatrix(sm=sm) | ||
|
||
def __mul__(self, other): | ||
if isinstance(other, float): | ||
sm = self.matrix * other | ||
return SparseMatrix(sm=sm) | ||
elif isinstance(other, SparseMatrix): | ||
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})" | ||
sm = self.matrix * other.matrix | ||
return SparseMatrix(sm=sm) | ||
|
||
def __rmul__(self, other): | ||
if isinstance(other, float): | ||
sm = other * self.matrix | ||
return SparseMatrix(sm=sm) | ||
|
||
def transpose(self): | ||
sm = self.matrix.transpose() | ||
return SparseMatrix(sm=sm) | ||
|
||
def __matmul__(self, other): | ||
assert self.m == other.n, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})" | ||
sm = self.matrix.matmul(other.matrix) | ||
return SparseMatrix(sm=sm) | ||
|
||
def __getitem__(self, indices): | ||
return self.matrix.get_element(indices[0], indices[1]) | ||
|
||
def __str__(self): | ||
return self.matrix.to_string() | ||
|
||
def __repr__(self): | ||
return self.matrix.to_string() | ||
|
||
|
||
class SparseMatrixBuilder: | ||
def __init__(self, num_rows=None, num_cols=None, max_num_triplets=0): | ||
self.num_rows = num_rows | ||
self.num_cols = num_cols if num_cols else num_rows | ||
if num_rows is not None: | ||
from taichi.core.util import ti_core as _ti_core | ||
self.ptr = _ti_core.create_sparse_matrix_builder( | ||
num_rows, num_cols, max_num_triplets) | ||
|
||
def get_addr(self): | ||
return self.ptr.get_addr() | ||
|
||
def print_triplets(self): | ||
self.ptr.print_triplets() | ||
|
||
def build(self): | ||
sm = self.ptr.build() | ||
return SparseMatrix(sm=sm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.