Commit 981ba1a: remove unused arguments
Authored by zijieli-Jlee on Nov 3, 2023 (1 parent: 45c03de)

Showing 4 changed files with 16 additions and 27 deletions.
examples/darcy2d_fact.py: 4 changes (2 additions, 2 deletions)
@@ -15,11 +15,11 @@
import logging, pickle, h5py

from libs.factorization_module import FABlock2D
from libs.positional_encoding_module import SirenNet, GaussianFourierFeatureTransform, Sine
from libs.positional_encoding_module import GaussianFourierFeatureTransform
from libs.basics import PreNorm, MLP, masked_instance_norm
from utils import Trainer, dict2namespace, index_points, load_checkpoint, save_checkpoint, ensure_dir
import yaml
from torch.optim.lr_scheduler import StepLR, OneCycleLR
from torch.optim.lr_scheduler import OneCycleLR
from loss_fn import rel_l2_loss

from matplotlib import pyplot as plt
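As a side note on verifying trims like the two import changes above: the standalone snippet below lists names a file imports but never references again. It is not part of this commit or the repository; the helper name, the heuristic, and the example invocation are illustrative assumptions.

# check_imports.py (hypothetical helper, not part of the repository)
# Lists names a module imports but never references again. This is a rough
# heuristic: names used only in __all__, decorators applied by string, or
# string annotations are not considered.
import ast
import sys

def unused_imports(path):
    tree = ast.parse(open(path, encoding="utf-8").read())
    imported = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imported[alias.asname or alias.name.split(".")[0]] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                imported[alias.asname or alias.name] = node.lineno
    used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
    return {name: line for name, line in imported.items() if name not in used}

if __name__ == "__main__":
    for name, line in sorted(unused_imports(sys.argv[1]).items(), key=lambda kv: kv[1]):
        print(f"line {line}: '{name}' imported but unused")

Running something like `python check_imports.py examples/darcy2d_fact.py` should flag names such as SirenNet, Sine, and StepLR before this commit, assuming they were genuinely unreferenced in the file.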
examples/threed_smokesim_fact_lm.py: 6 changes (3 additions, 3 deletions)
@@ -16,11 +16,11 @@

from libs.factorization_module import FABlock3D
from libs.basics import PreNorm, MLP
from nn_module.positional_encoding_module import SirenNet, GaussianFourierFeatureTransform, Sine
from nn_module.positional_encoding_module import GaussianFourierFeatureTransform
from utils import Trainer, dict2namespace, index_points, load_checkpoint, save_checkpoint, ensure_dir, Timer
import yaml
from torch.optim.lr_scheduler import StepLR, OneCycleLR
from loss_fn import rel_l2_loss, rel_l1_loss
from torch.optim.lr_scheduler import OneCycleLR
from loss_fn import rel_l2_loss

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
examples/turb_ns2d_fact_lm.py: 12 changes (3 additions, 9 deletions)
@@ -16,12 +16,12 @@


from libs.factorization_module import FABlock2D
from libs.positional_encoding_module import SirenNet, GaussianFourierFeatureTransform, Sine
from libs.positional_encoding_module import GaussianFourierFeatureTransform
from libs.basics import PreNorm, MLP, masked_instance_norm
from utils import Trainer, dict2namespace, index_points, load_checkpoint, save_checkpoint, ensure_dir
import yaml
from torch.optim.lr_scheduler import StepLR, OneCycleLR
from loss_fn import rel_l2_loss, rel_l1_loss
from torch.optim.lr_scheduler import OneCycleLR
from loss_fn import rel_l2_loss

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
@@ -77,18 +77,14 @@ def __init__(self,
self.out_dim = config.out_dim
self.out_tw = config.out_time_window

#self.num_latent = config.num_latent
#self.latent_dim = config.latent_dim # latent bottleneck dimension
self.dim = config.dim # dimension of the transformer
self.depth = config.depth # depth of the encoder transformer
self.dim_head = config.dim_head
self.reducer = config.reducer

self.heads = config.heads

self.pos_in_dim = config.pos_in_dim
self.pos_out_dim = config.pos_out_dim
self.positional_embedding = config.positional_embedding
self.kernel_multiplier = config.kernel_multiplier
self.latent_multiplier = config.latent_multiplier
self.latent_dim = int(self.dim * self.latent_multiplier)
@@ -99,7 +95,6 @@ def __init__(self,

# assume input is b c t h w d
self.encoder = FactorizedTransformer(self.dim, self.dim_head, self.heads, self.dim, self.depth,

kernel_multiplier=self.kernel_multiplier)
self.expand_latent = nn.Linear(self.dim, self.latent_dim, bias=False)
self.latent_time_emb = nn.Parameter(torch.randn(1, self.max_latent_steps,
@@ -118,7 +113,6 @@ def __init__(self,
nn.GELU(),
nn.Conv1d(self.dim // 2, self.out_dim, kernel_size=1, stride=1, padding=0, bias=True)
)
# self.decoder = Reconstruct3D(self.dim, self.dim_head, self.heads, self.out_dim)

def forward(self,
u,
examples/turb_ns3d_fact_lm.py: 21 changes (8 additions, 13 deletions)
@@ -14,13 +14,13 @@
from torch.utils.data import Dataset, DataLoader, TensorDataset
import logging, pickle, h5py

from libs.positional_encoding_module import GaussianFourierFeatureTransform, SirenNet
from libs.positional_encoding_module import GaussianFourierFeatureTransform
from libs.factorization_module import FABlock3D
from libs.basics import PreNorm, MLP
from utils import Trainer, dict2namespace, index_points, load_checkpoint, save_checkpoint, ensure_dir
import yaml
from torch.optim.lr_scheduler import StepLR, OneCycleLR
from loss_fn import rel_l2_loss, rel_l1_loss
from torch.optim.lr_scheduler import OneCycleLR
from loss_fn import rel_l2_loss

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
@@ -101,8 +101,6 @@ def __init__(self,
nn.Conv2d(self.dim // 2, self.dim, kernel_size=(self.in_tw, 1), stride=1, padding=0, bias=False),
)

self.time_embedding = SirenNet(1, self.dim, self.dim, 3, normalize_input=False)

# assume input is b c t h w d
self.encoder = FactorizedTransformer(self.dim, self.dim_head, self.heads, self.dim, self.depth,
kernel_multiplier=self.kernel_multiplier,)
@@ -129,7 +127,6 @@ def __init__(self,
def forward(self,
u,
pos_lst,
t_coord, # [b,]
latent_steps,
):
# u: b c t h w d
@@ -140,8 +137,6 @@ def forward(self,
u = self.to_in(u)
u = rearrange(u, 'b c 1 (nx ny nz) -> b nx ny nz c', nx=nx, ny=ny, nz=nz)

t_emb = self.time_embedding(t_coord.unsqueeze(-1))
u = u + t_emb.view(b, 1, 1, 1, -1)
u = self.encoder(u, pos_lst)
u = self.expand_latent(u)
u_lst = []
@@ -493,19 +488,19 @@ def step_fn(self, data,

# t_step = self.curriculum_scheduler.get_value()
if current_latent_steps is None:
y_hat = self.model(x, pos_lst, t_coord[:, 0], 1)
y_hat = self.model(x, pos_lst, 1)
y = y[:, 0:1]
elif not pushforward:
y_hat = self.model(x, pos_lst, t_coord[:, 0], current_latent_steps)
y_hat = self.model(x, pos_lst, current_latent_steps)
y = y[:, 0:current_latent_steps]
else:
with torch.no_grad():
x_hat = self.model(x, pos_lst, t_coord[:, 0], current_latent_steps)
x_hat = self.model(x, pos_lst, current_latent_steps)
x = torch.cat([x[:, :, current_latent_steps:],
rearrange(
x_hat,
'b t h w d c -> b c t h w d')], dim=2)
y_hat = self.model(x.detach(), pos_lst, t_coord[:, current_latent_steps], current_latent_steps)
y_hat = self.model(x.detach(), pos_lst, current_latent_steps)
y = y[:, current_latent_steps:current_latent_steps * 2]
# denormalize
y_hat = self.denormalize(y_hat)
@@ -521,7 +516,7 @@ def step_fn(self, data,
y_hat = torch.zeros_like(y) # b, t, h, w, c
for i in range(y.shape[1] // self.max_latent_steps):
y_hat[:, i * self.max_latent_steps:(i + 1) * self.max_latent_steps] =\
self.model.forward(x, pos_lst, t_coord[:, i], latent_steps=self.max_latent_steps)
self.model.forward(x, pos_lst, latent_steps=self.max_latent_steps)
x = torch.cat([x[:, :, self.max_latent_steps:],
rearrange(
y_hat[:, i * self.max_latent_steps:(i + 1) * self.max_latent_steps],
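For reference, a minimal sketch of how the trimmed 3D interface is called after this commit: forward now takes (u, pos_lst, latent_steps) with no time coordinate, and step_fn rolls the model out by shifting its own predictions back into the input window. The stub model, tensor sizes, and step count below are placeholder assumptions; only the call signature and the rollout pattern follow the diff.

# Standalone sketch of the post-commit call pattern (stub model, placeholder shapes).
import torch
from einops import rearrange

class StubModel(torch.nn.Module):
    # stands in for the 3D model above; only the trimmed signature matters here
    def forward(self, u, pos_lst, latent_steps):
        b, c, t, h, w, d = u.shape
        # predictions come back as b t h w d c, mirroring the rearrange in step_fn
        return torch.zeros(b, latent_steps, h, w, d, c)

model = StubModel()
K = 2                                   # stand-in for max_latent_steps
x = torch.randn(1, 3, 4, 8, 8, 8)       # b c t h w d, placeholder sizes
pos_lst = None                          # per-axis coordinates in the real model

# autoregressive rollout as in the evaluation branch of step_fn:
# predict K frames, shift them into the input window, and repeat
preds = []
for _ in range(3):
    y_hat = model(x, pos_lst, latent_steps=K)   # no t_coord argument anymore
    preds.append(y_hat)
    x = torch.cat([x[:, :, K:],
                   rearrange(y_hat, 'b t h w d c -> b c t h w d')], dim=2)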
