Merge branch 'neuraloperator:main' into temp/colin
rtu715 authored Jun 28, 2023
2 parents e1acc95 + 1aefe93 commit 3c48850
Showing 8 changed files with 122 additions and 75 deletions.
7 changes: 4 additions & 3 deletions examples/plot_UNO_darcy.py
@@ -32,8 +32,9 @@



-model = UNO(3,1, hidden_channels=64, projection_channels=64,uno_out_channels = [32,64,64,64,32], uno_n_modes= [[16,16],[8,8],[8,8],[8,8],[16,16]], uno_scalings= [[1.0,1.0],[0.5,0.5],[1,1],[2,2],[1,1]],\
-horizontal_skips_map = None, n_layers = 5, domain_padding = 0.2)
+model = UNO(3,1, hidden_channels=64, projection_channels=64,uno_out_channels = [32,64,64,64,32], \
+uno_n_modes= [[16,16],[8,8],[8,8],[8,8],[16,16]], uno_scalings= [[1.0,1.0],[0.5,0.5],[1,1],[2,2],[1,1]],\
+horizontal_skips_map = None, n_layers = 5, domain_padding = 0.2)
model = model.to(device)

n_params = count_params(model)
@@ -116,7 +117,7 @@
# Ground-truth
y = data['y']
# Model prediction
-out = model(x.unsqueeze(0))
+out = model(x.unsqueeze(0).to(device)).cpu()

ax = fig.add_subplot(3, 3, index*3 + 1)
ax.imshow(x[0], cmap='gray')
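
A note on the device handling fixed above: the input must live on the model's device, and the prediction must come back to the CPU before plotting. A minimal, self-contained sketch of that round-trip (the Conv2d is a stand-in for the UNO model; all names here are illustrative, not taken from the example script):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Conv2d(3, 1, kernel_size=1).to(device)  # stand-in for the UNO model
x = torch.randn(3, 16, 16)                               # stand-in for one Darcy sample

# Batch the sample, run it on the model's device, then bring the
# prediction back to the CPU so matplotlib can consume it.
with torch.no_grad():
    out = model(x.unsqueeze(0).to(device)).cpu()
print(out.shape)  # torch.Size([1, 1, 16, 16])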
15 changes: 8 additions & 7 deletions neuralop/models/fno_block.py
@@ -120,7 +120,7 @@ def set_ada_in_embeddings(self, *embeddings):
for norm, embedding in zip(self.norm, embeddings):
norm.set_embedding(embedding)

-def forward(self, x, index=0):
+def forward(self, x, index=0, output_shape = None):

if self.preactivation:
x = self.non_linearity(x)
@@ -131,15 +131,17 @@ def forward(self, x, index=0):
x_skip_fno = self.fno_skips[index](x)
if self.convs.output_scaling_factor is not None:
# x_skip_fno = resample(x_skip_fno, self.convs.output_scaling_factor[index], list(range(-len(self.convs.output_scaling_factor[index]), 0)))
-x_skip_fno = resample(x_skip_fno, self.output_scaling_factor[index], list(range(-len(self.output_scaling_factor[index]), 0)))
+x_skip_fno = resample(x_skip_fno, self.output_scaling_factor[index]\
+             , list(range(-len(self.output_scaling_factor[index]), 0)), output_shape = output_shape )


if self.mlp is not None:
x_skip_mlp = self.mlp_skips[index](x)
if self.convs.output_scaling_factor is not None:
# x_skip_mlp = resample(x_skip_mlp, self.convs.output_scaling_factor[index], list(range(-len(self.convs.output_scaling_factor[index]), 0)))
-x_skip_mlp = resample(x_skip_mlp, self.output_scaling_factor[index], list(range(-len(self.output_scaling_factor[index]), 0)))
+x_skip_mlp = resample(x_skip_mlp, self.output_scaling_factor[index]\
+             , list(range(-len(self.output_scaling_factor[index]), 0)), output_shape = output_shape )

-x_fno = self.convs(x, index)
+x_fno = self.convs(x, index, output_shape = output_shape)

if not self.preactivation and self.norm is not None:
x_fno = self.norm[self.n_norms*index](x_fno)
@@ -206,5 +208,4 @@ def __init__(self, main_module, indices):
self.indices = indices

def forward(self, x):
-return self.main_module.forward(x, self.indices)
-
+return self.main_module.forward(x, self.indices)
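
The common thread in this file: forward() now accepts an optional output_shape that is passed down to resample and to the spectral convolution, taking precedence over the per-layer scaling factor. A minimal sketch of that override pattern in isolation (resample_to is a hypothetical helper, not the library's resample):

import torch
import torch.nn.functional as F

def resample_to(x, scale, output_shape=None):
    # An explicit output shape wins; otherwise derive the target
    # size from the scaling factor, mirroring the logic above.
    if output_shape is None:
        output_shape = tuple(int(round(s * scale)) for s in x.shape[2:])
    return F.interpolate(x, size=output_shape, mode='bicubic', align_corners=True)

x = torch.randn(2, 4, 32, 32)
print(resample_to(x, 0.5).shape)            # via scale factor:   (2, 4, 16, 16)
print(resample_to(x, 0.5, (20, 24)).shape)  # via explicit shape: (2, 4, 20, 24)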
29 changes: 16 additions & 13 deletions neuralop/models/padding.py
@@ -16,7 +16,7 @@ class DomainPadding(nn.Module):
This class works for any input resolution, as long as it is in the form
`(batch-size, channels, d1, ...., dN)`
"""
-def __init__(self, domain_padding, padding_mode='one-sided', output_scaling_factor = None):
+def __init__(self, domain_padding, padding_mode='one-sided', output_scaling_factor=None):
super().__init__()
self.domain_padding = domain_padding
self.padding_mode = padding_mode.lower()
@@ -54,36 +54,38 @@ def pad(self, x):
padding = [int(round(p*r)) for (p, r) in zip(self.domain_padding, resolution)]

print(f'Padding inputs of {resolution=} with {padding=}, {self.padding_mode}')

output_pad = padding

-for scale_factor in self.output_scaling_factor:
-    if isinstance(scale_factor, (float, int)):
-        scale_factor = [scale_factor]*len(resolution)
-    output_pad = [int(round(i*j)) for (i,j) in zip(scale_factor,output_pad)]
+output_pad = [int(round(i*j)) for (i,j) in zip(self.output_scaling_factor,output_pad)]


# the F.pad(x, padding) function pads the tensor 'x' in reverse order of the "padding" list, i.e. the last axis of tensor 'x' will be
# padded by the amount mentioned at the first position of the 'padding' vector.
# The details about F.pad can be found here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

if self.padding_mode == 'symmetric':
# Pad both sides
-unpad_indices = (Ellipsis, ) + tuple([slice(p, -p, None) for p in output_pad ])
+unpad_indices = (Ellipsis, ) + tuple([slice(p, -p, None) for p in output_pad[::-1] ])
padding = [i for p in padding for i in (p, p)]

elif self.padding_mode == 'one-sided':
# One-side padding
-unpad_indices = (Ellipsis, ) + tuple([slice(None, -p, None) for p in output_pad])
+unpad_indices = (Ellipsis, ) + tuple([slice(None, -p, None) for p in output_pad[::-1]])
padding = [i for p in padding for i in (0, p)]
else:
raise ValueError(f'Got {self.padding_mode=}')

self._padding[f'{resolution}'] = padding


padded = F.pad(x, padding, mode='constant')

out_put_shape = padded.shape[2:]
-for scale_factor in self.output_scaling_factor:
-    if isinstance(scale_factor, (float, int)):
-        scale_factor = [scale_factor]*len(resolution)
-    out_put_shape = [int(round(i*j)) for (i,j) in zip(scale_factor,out_put_shape)]
+out_put_shape = [int(round(i*j)) for (i,j) in zip(self.output_scaling_factor,out_put_shape)]

self._unpad_indices[f'{[i for i in out_put_shape]}'] = unpad_indices

return padded
@@ -93,3 +95,4 @@ def unpad(self, x):
"""
unpad_indices = self._unpad_indices[f'{list(x.shape[2:])}']
return x[unpad_indices]
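
The [::-1] fixes above follow from the F.pad convention restated in the comment: the padding list is consumed from the last axis backwards, so unpad slices built in axis order have to be reversed to line up. A small self-contained illustration of that behavior:

import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 4, 6)  # (batch, channels, d1, d2)

# The first pair (1, 1) pads the LAST axis (d2); the second
# pair (2, 2) pads the second-to-last axis (d1).
padded = F.pad(x, [1, 1, 2, 2], mode='constant')
print(padded.shape)  # torch.Size([1, 1, 8, 8]): d1 grew by 4, d2 by 2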

18 changes: 11 additions & 7 deletions neuralop/models/resample.py
@@ -4,7 +4,7 @@
import torch
import torch.nn.functional as F

-def resample(x, res_scale, axis):
+def resample(x, res_scale, axis, output_shape=None):
"""
A module for generic n-dimensional interpolation (Fourier resampling).
@@ -17,6 +17,7 @@ def resample(x, res_scale, axis):
scaling is performed
axis: axis or dimensions along which interpolation will be performed.
"""

if isinstance(res_scale, (float, int)):
if axis is None:
axis = list(range(2, x.ndim))
@@ -30,12 +31,15 @@ def resample(x, res_scale, axis):
assert len(res_scale) == len(axis), "length of res_scale and axis are not the same"

old_size = x.shape[-len(axis):]
-new_size = tuple([int(round(s*r)) for (s, r) in zip(old_size, res_scale)])
+if output_shape is None:
+    new_size = tuple([int(round(s*r)) for (s, r) in zip(old_size, res_scale)])
+else:
+    new_size = output_shape

if len(axis) == 1:
-return F.interpolate(x, size = new_size[0], mode = 'linear', align_corners = True)
+return F.interpolate(x, size=new_size[0], mode='linear', align_corners=True)
if len(axis) == 2:
-return F.interpolate(x, size = new_size, mode = 'bicubic', align_corners = True, antialias = True)
+return F.interpolate(x, size=new_size, mode='bicubic', align_corners=True)

X = torch.fft.rfftn(x.float(), norm='forward', dim=axis)

@@ -50,7 +54,7 @@ def resample(x, res_scale, axis):
idx_tuple = [slice(None), slice(None)] + [slice(*b) for b in boundaries]

out_fft[idx_tuple] = X[idx_tuple]
-y = torch.fft.irfftn(out_fft, s = new_size ,norm='forward', dim = axis)
+y = torch.fft.irfftn(out_fft, s= new_size ,norm='forward', dim=axis)

return y

@@ -71,7 +75,7 @@ def iterative_resample(x, res_scale, axis):
return x

old_res = x.shape[axis]
-X = torch.fft.rfft(x, dim=axis, norm = 'forward')
+X = torch.fft.rfft(x, dim=axis, norm='forward')
newshape = list(x.shape)
new_res = int(round(res_scale*newshape[axis]))
newshape[axis] = new_res // 2 + 1
@@ -82,6 +86,6 @@ def iterative_resample(x, res_scale, axis):
sl = [slice(None)] * x.ndim
sl[axis] = slice(0, modes // 2 + 1)
Y[tuple(sl)] = X[tuple(sl)]
-y = torch.fft.irfft(Y, n = new_res, dim=axis,norm = 'forward')
+y = torch.fft.irfft(Y, n=new_res, dim=axis,norm='forward')
return y
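
With the post-merge signature, output_shape lets a caller pin an exact target size instead of relying on per-axis int(round(s*r)) rounding. A usage sketch against the signature shown in this diff (assuming the module is importable as neuralop.models.resample, per this tree):

import torch
from neuralop.models.resample import resample

x = torch.randn(2, 3, 33, 47)  # odd sizes, where rounding matters

# Scale-factor path: target sizes come from int(round(s * r)) per axis.
y = resample(x, 2.0, axis=[-2, -1])
print(y.shape)  # torch.Size([2, 3, 66, 94])

# Explicit-shape path: output_shape bypasses the rounding entirely.
y = resample(x, 2.0, axis=[-2, -1], output_shape=(64, 96))
print(y.shape)  # torch.Size([2, 3, 64, 96])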

15 changes: 9 additions & 6 deletions neuralop/models/spectral_convolution.py
@@ -300,7 +300,7 @@ def incremental_n_modes(self, incremental_n_modes):
self.weight_slices = [slice(None)]*2 + [slice(None, n//2) for n in self._incremental_n_modes]
self.half_n_modes = [m//2 for m in self._incremental_n_modes]

-def forward(self, x, indices=0):
+def forward(self, x, indices=0, output_shape = None):
"""Generic forward pass for the Factorized Spectral Conv
Parameters
@@ -336,8 +336,12 @@ def forward(self, x, indices=0):
# For 2D: [:, :, :height, :width] and [:, :, -height:, width]
out_fft[idx_tuple] = self._contract(x[idx_tuple], self._get_weight(self.n_weights_per_layer*indices + i), separable=self.separable)

-if self.output_scaling_factor is not None:
+if self.output_scaling_factor is not None and output_shape is None:
    mode_sizes = tuple([int(round(s*r)) for (s, r) in zip(mode_sizes, self.output_scaling_factor[indices])])

+if output_shape is not None:
+    mode_sizes = output_shape
x = torch.fft.irfftn(out_fft, s=(mode_sizes), norm=self.fft_norm)

@@ -416,8 +420,8 @@ def forward(self, x, indices=0):
self._get_weight(2*indices + 1), separable=self.separable)

if self.output_scaling_factor is not None:
-width = int(round(width*self.output_scaling_factor[0]))
-height = int(round(height*self.output_scaling_factor[1]))
+width = int(round(width*self.output_scaling_factor[indices][0]))
+height = int(round(height*self.output_scaling_factor[indices][1]))

x = torch.fft.irfft2(out_fft, s=(height, width), dim=(-2, -1), norm=self.fft_norm)

@@ -453,5 +457,4 @@ def forward(self, x, indices=0):

if self.bias is not None:
x = x + self.bias[indices, ...]
-
-return x
+return x
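
The essence of this change is precedence when sizing the inverse FFT: an explicit output_shape overrides the per-layer output_scaling_factor. A minimal sketch of that decision logic in isolation (resolve_mode_sizes is illustrative, not the library code):

def resolve_mode_sizes(mode_sizes, output_scaling_factor=None, output_shape=None, indices=0):
    # The scaling factor applies only when no explicit shape was requested.
    if output_scaling_factor is not None and output_shape is None:
        mode_sizes = tuple(int(round(s * r))
                           for s, r in zip(mode_sizes, output_scaling_factor[indices]))
    # An explicit output_shape always wins.
    if output_shape is not None:
        mode_sizes = output_shape
    return mode_sizes

print(resolve_mode_sizes((32, 32), output_scaling_factor=[(0.5, 0.5)]))  # (16, 16)
print(resolve_mode_sizes((32, 32), output_scaling_factor=[(0.5, 0.5)], output_shape=(24, 24)))  # (24, 24)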
1 change: 1 addition & 0 deletions neuralop/models/tests/test_tfno.py
@@ -98,4 +98,5 @@ def test_fno_superresolution(output_scaling_factor):

# Check output size
factor = prod(output_scaling_factor)

assert list(out.shape) == [batch_size, 1] + [int(round(factor*s)) for s in size]
20 changes: 10 additions & 10 deletions neuralop/models/tests/test_uno.py
@@ -1,20 +1,22 @@
import time
from ..uno import UNO
import torch
+import pytest

-def test_UNO():
+@pytest.mark.parametrize('input_shape',
+                         [(32,3,64,55),(32,3,100,105),(32,3,133,95)])
+def test_UNO(input_shape):
horizontal_skips_map ={4:0,3:1}
model = UNO(3,3,5,uno_out_channels = [32,64,64,64,32], uno_n_modes= [[5,5],[5,5],[5,5],[5,5],[5,5]], uno_scalings= [[1.0,1.0],[0.5,0.5],[1,1],[1,1],[2,2]],\
-horizontal_skips_map = horizontal_skips_map, n_layers = 5, domain_padding = 0.2)
+horizontal_skips_map = horizontal_skips_map, n_layers = 5, domain_padding = 0.2, output_scaling_factor = 1)

t1 = time.time()
-in_data = torch.randn(32,3,64,64)
-out = model(in_data)
-out = model(in_data)
+in_data = torch.randn(input_shape)
+out = model(in_data)
t = time.time() - t1
print(f'Output of size {out.shape} in {t}.')

for i in range(len(out.shape)):
assert in_data.shape[i] == out.shape[i]
loss = out.sum()
t1 = time.time()
loss.backward()
@@ -28,12 +30,10 @@ def test_UNO():


model = UNO(3,3,5,uno_out_channels = [32,64,64,64,32], uno_n_modes= [[5,5],[5,5],[5,5],[5,5],[5,5]], uno_scalings= [[1.0,1.0],[0.5,0.5],[1,1],[1,1],[2,2]],\
-horizontal_skips_map = None, n_layers = 5, domain_padding = 0.2)
+horizontal_skips_map = None, n_layers = 5, domain_padding = 0.2, output_scaling_factor = 1)

t1 = time.time()
-in_data = torch.randn(32,3,64,64)
-out = model(in_data)
-out = model(in_data)
+in_data = torch.randn(input_shape)
+out = model(in_data)
t = time.time() - t1
print(f'Output of size {out.shape} in {t}.')