forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_external_func.py
77 lines (59 loc) · 1.37 KB
/
demo_external_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import ctypes
import os
import subprocess

import taichi as ti
# Taichi's external_func_call invokes host (ctypes) function pointers and is
# only supported on CPU backends, so request the CPU backend explicitly instead
# of relying on ti.init()'s default architecture.
ti.init(arch=ti.cpu)

# Problem size and the per-element fields used by the parallel demo:
# x holds the bases, y the exponents, and z receives x[i] ** y[i].
N = 1024
x = ti.field(ti.i32, shape=N)
y = ti.field(ti.i32, shape=N)
z = ti.field(ti.i32, shape=N)
# C++ source for the external functions. The extern "C" block disables C++
# name mangling so ctypes can look the symbols up by their plain names.
source = '''
extern "C" {
void add_and_mul(float a, float b, float *c, float *d, int *e) {
*c = a + b;
*d = a * b;
*e = int(a * b + a);
}
void pow_int(int a, int b, int *c) {
int ret = 1;
for (int i = 0; i < b; i++)
ret = ret * a;
*c = ret;
}
}
'''

# Write the source, compile it into a shared library and load it with ctypes.
with open('a.cpp', 'w', encoding='utf-8') as f:
    f.write(source)

# Run g++ without a shell and fail fast. check=True raises CalledProcessError
# on a non-zero exit status, and a missing g++ raises FileNotFoundError.
# (os.system ignored both, so the failure only surfaced later as an opaque
# "cannot open shared object file" error from ctypes.CDLL.)
subprocess.run(["g++", "a.cpp", "-o", "a.so", "-fPIC", "-shared"], check=True)
so = ctypes.CDLL("./a.so")
# Chain two external calls inside a single kernel.
# add_and_mul(2.0, 3.0) fills three outputs: the sum, the product, and
# int(product + first operand). The sum plus the product then becomes the
# base for pow_int, and the third output is the exponent. The final power is
# returned.
@ti.kernel
def call_ext() -> ti.i32:
    lhs = 2.0
    rhs = 3.0
    total = 0.0
    product = 0.0
    exponent = 3  # integer output slot; overwritten by add_and_mul
    ti.external_func_call(func=so.add_and_mul,
                          args=(lhs, rhs),
                          outputs=(total, product, exponent))
    result = 0
    ti.external_func_call(func=so.pow_int,
                          args=(int(total + product), exponent),
                          outputs=(result, ))
    return result
# Expose the external pow_int as an ordinary Taichi function so kernels can
# call it inline. Both operands are cast to int before the call, and the
# integer written by pow_int is returned.
@ti.func
def pow_int_wrapper(a, b):
    out = 0
    base = int(a)
    exponent = int(b)
    ti.external_func_call(func=so.pow_int, args=(base, exponent), outputs=(out, ))
    return out
# Fill z element-wise with x[i] raised to y[i], computed by the external
# pow_int through pow_int_wrapper. The loop is the kernel's outermost loop,
# which Taichi parallelizes.
@ti.kernel
def call_parallel():
    for idx in range(N):
        base = x[idx]
        exponent = y[idx]
        z[idx] = pow_int_wrapper(base, exponent)
# Run both demos. The generated source and library are removed in a
# `finally`, so a failing check no longer leaves a.cpp / a.so behind.
try:
    # add_and_mul(2.0, 3.0) -> sum 5, product 6, int(6 + 2) = 8,
    # so pow_int(5 + 6, 8) must yield 11 ** 8.
    assert call_ext() == 11**8

    # Bases 0..N-1, all raised to the third power in parallel.
    for i in range(N):
        x[i] = i
        y[i] = 3
    call_parallel()
    for i in range(N):
        assert z[i] == i**3
finally:
    os.remove('a.cpp')
    os.remove('a.so')