-
Notifications
You must be signed in to change notification settings - Fork 204
/
Copy pathinit_multigrid.py
59 lines (46 loc) · 1.98 KB
/
init_multigrid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import tigre
from numpy.lib.stride_tricks import as_strided
from scipy.interpolate import RegularGridInterpolator
def init_multigrid(proj, geo, alpha, alg, start_size=16, niter=100):
    """Build an initial volume by solving on successively finer voxel grids.

    Starting from a coarse grid (``start_size`` voxels per axis, clamped to
    the requested size), the chosen iterative solver is run, the grid is
    doubled (clamped per axis to the final size), and the solver result is
    upsampled by trilinear interpolation to seed the next level.  The loop
    stops once every axis has reached its final size; no solver pass is run
    at the final resolution, the returned array is only the upsampled seed.

    WARNING: this takes a lot of memory.  Each upsampling step materialises
    a float64 coordinate array with three entries per voxel of the new grid.

    Parameters
    ----------
    proj
        Projection data; passed through to the solver unchanged.
    geo
        Geometry object exposing ``nVoxel``, ``sVoxel`` and ``dVoxel``.
        ``nVoxel`` and ``dVoxel`` are modified while the solver runs and are
        reset to the final-resolution values before returning or raising.
    alpha
        Passed through to the solver as its third positional argument
        (presumably the projection angles -- confirm against tigre).
    alg : {"SART", "SIRT"}
        Which ``tigre.algorithms`` solver to use at each level.
    start_size : int, optional
        Voxels per axis of the coarsest grid (raise this for larger images).
    niter : int, optional
        Solver iterations per grid level.

    Returns
    -------
    numpy.ndarray
        float32 array whose shape is the original ``geo.nVoxel``.  All zeros
        when the requested grid is not larger than the starting grid.

    Raises
    ------
    ValueError
        If ``alg`` is not "SART" or "SIRT", or ``start_size`` is below 1.
    """
    if alg not in ("SART", "SIRT"):
        raise ValueError(f"alg must be 'SART' or 'SIRT', got {alg!r}")
    if start_size < 1:
        raise ValueError(f"start_size must be >= 1, got {start_size!r}")

    original_nvoxel = geo.nVoxel
    finalsize = np.array(original_nvoxel)

    # Clamp per axis so a thin volume never starts above its final size.
    grid = np.minimum(start_size, finalsize)
    volume = np.zeros(grid, dtype=np.float32)
    if (grid == finalsize).all():
        # Nothing coarser to solve on; geo has not been touched.
        return volume

    # Resolved only once a solve is actually needed.
    solver = tigre.algorithms.sart if alg == "SART" else tigre.algorithms.sirt

    try:
        # any(): keep refining until *every* axis has reached its final size.
        while (grid != finalsize).any():
            geo.nVoxel = grid
            geo.dVoxel = geo.sVoxel / grid
            volume = solver(proj, geo, alpha, niter, init=volume, verbose=False)

            # Double the resolution, never overshooting the requested size.
            grid = np.minimum(grid * 2, finalsize)
            volume = _upsample_trilinear(volume, grid)
    finally:
        # Hand the geometry back at full resolution even if the solver fails.
        geo.nVoxel = original_nvoxel
        geo.dVoxel = geo.sVoxel / finalsize
    return volume


def _upsample_trilinear(volume, new_shape):
    """Resample a 3-D ``volume`` onto ``new_shape`` by linear interpolation in index space."""
    # Source samples sit at integer indices 0..n-1 along each axis.
    source_axes = tuple(np.arange(n, dtype=np.float64) for n in volume.shape)
    # Target samples span the same index range with the new point count;
    # linspace needs an integer count and keeps both endpoints in bounds.
    target_axes = [
        np.linspace(0.0, n_old - 1, int(n_new))
        for n_old, n_new in zip(volume.shape, new_shape)
    ]
    points = np.stack(np.meshgrid(*target_axes, indexing="ij"), axis=-1)
    resampled = RegularGridInterpolator(source_axes, volume)(points)
    return resampled.astype(np.float32)
def tile_array(mat, b1):
    """Repeat every element of a 1-D array ``b1`` times, keeping element order.

    ``tile_array([a, b], 2)`` gives ``[a, a, b, b]``.

    Parameters
    ----------
    mat
        One-dimensional array (or array-like).
    b1 : int
        Number of consecutive copies of each element.

    Returns
    -------
    numpy.ndarray
        New 1-D array of length ``len(mat) * b1`` with the dtype of ``mat``.

    Raises
    ------
    ValueError
        If ``mat`` is not one-dimensional.
    """
    mat = np.asarray(mat)
    if mat.ndim != 1:
        raise ValueError(f"tile_array expects a 1-D array, got {mat.ndim} dimensions")
    # np.repeat gives the same result as a zero-stride view plus reshape,
    # without going through as_strided.
    return np.repeat(mat, b1)