Skip to content

Commit

Permalink
Set up example and exposed SFNO. Some cleanup in the naming
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Jun 8, 2023
1 parent 764f754 commit e085826
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 16 deletions.
8 changes: 4 additions & 4 deletions examples/plot_SFNO_swe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import TFNO
from neuralop.models import SFNO
from neuralop import Trainer
from neuralop.datasets import load_spherical_swe
from neuralop.utils import count_params
Expand All @@ -30,7 +30,7 @@
# %%
# We create a tensorized FNO model

model = TFNO(n_modes=(64, 64), hidden_channels=32, projection_channels=64, factorization='tucker', rank=0.42)
model = SFNO(n_modes=(64, 128), in_channels=3, out_channels=3, hidden_channels=32, projection_channels=64, factorization='dense')
model = model.to(device)

n_params = count_params(model)
Expand All @@ -41,8 +41,8 @@
# %%
#Create the optimizer
optimizer = torch.optim.Adam(model.parameters(),
lr=8e-3,
weight_decay=1e-4)
lr=8e-4,
weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)


Expand Down
1 change: 1 addition & 0 deletions neuralop/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .tfno import TFNO, TFNO1d, TFNO2d, TFNO3d
from .tfno import FNO, FNO1d, FNO2d, FNO3d
from .tfno import SFNO
from .uno import UNO
from .model_dispatcher import get_model
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
from neuralop.models import FNO


import torch
from math import floor, prod
from tltorch import FactorizedTensor
import torch.nn as nn
import torch.nn.functional as F
from torch_harmonics import RealSHT, InverseRealSHT

from torch import nn
import torch
import itertools

from torch_harmonics import RealSHT, InverseRealSHT

import tensorly as tl
from tensorly.plugins import use_opt_einsum
Expand Down Expand Up @@ -171,7 +162,7 @@ def get_contract_fun(weight, implementation='reconstructed', separable=False):
raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')


class FactorizedSHTDConv(nn.Module):
class FactorizedSphericalConv(nn.Module):
def __init__(self, in_channels, out_channels, n_modes, incremental_n_modes=None, bias=True,
n_layers=1, separable=False, output_scaling_factor=None,
rank=0.5, factorization='cp', implementation='reconstructed',
Expand Down
3 changes: 3 additions & 0 deletions neuralop/models/tfno.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partialmethod
import torch
from .spectral_convolution import FactorizedSpectralConv
from .spherical_convolution import FactorizedSphericalConv
from .padding import DomainPadding
from .fno_block import FNOBlocks, resample

Expand Down Expand Up @@ -613,3 +614,5 @@ def partialclass(new_name, cls, *args, **kwargs):
TFNO1d = partialclass('TFNO1d', FNO1d, factorization='Tucker')
TFNO2d = partialclass('TFNO2d', FNO2d, factorization='Tucker')
TFNO3d = partialclass('TFNO3d', FNO3d, factorization='Tucker')

SFNO = partialclass('SFNO', FNO, factorization='dense', SpectralConv=FactorizedSpectralConv)

0 comments on commit e085826

Please sign in to comment.