# train_darcy.py
# Train a Tensorized Fourier Neural Operator (TFNO) on the small Darcy flow
# dataset. Forked from neuraloperator/neuraloperator.

import sys

import torch
import wandb
from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig
from torch.nn.parallel import DistributedDataParallel as DDP

from neuralop import get_model, Trainer, LpLoss, H1Loss
from neuralop.training import setup
from neuralop.datasets import load_darcy_flow_small
from neuralop.utils import get_wandb_api_key, count_params
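# Third-party dependencies: torch, wandb, configmypy, and the `neuralop`
# package provided by the neuraloperator project.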

# Read the configuration
config_name = 'default'
pipe = ConfigPipeline([YamlConfig('./darcy_config.yaml', config_name='default', config_folder='../config'),
                       ArgparseConfig(infer_types=True, config_name=None, config_file=None),
                       YamlConfig(config_folder='../config')
                       ])
config = pipe.read_conf()
config_name = pipe.steps[-1].config_name
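# Individual config entries can be overridden from the command line (assuming
# configmypy's dotted-key overrides; adjust flags to your config schema), e.g.:
#
#     python train_darcy.py --opt.n_epochs 300 --data.batch_size 16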

# Set up distributed communication, if using
device, is_logger = setup(config)

# Set up WandB logging
if config.wandb.log and is_logger:
    wandb.login(key=get_wandb_api_key())
    if config.wandb.name:
        wandb_name = config.wandb.name
    else:
        # Default run name: concatenate the main architecture and
        # patching hyper-parameters
        wandb_name = '_'.join(
            f'{var}' for var in [config_name, config.tfno2d.n_layers,
                                 config.tfno2d.hidden_channels,
                                 config.tfno2d.n_modes_width,
                                 config.tfno2d.n_modes_height,
                                 config.tfno2d.factorization,
                                 config.tfno2d.rank,
                                 config.patching.levels,
                                 config.patching.padding])
    wandb.init(config=config, name=wandb_name, group=config.wandb.group,
               project=config.wandb.project, entity=config.wandb.entity)
    if config.wandb.sweep:
        for key in wandb.config.keys():
            config.params[key] = wandb.config[key]
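# Note: when running inside a W&B sweep, the swept values are copied back into
# the local config above, so the sweep controls the hyper-parameters.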

# Make sure we only print information when needed
config.verbose = config.verbose and is_logger

# Print config to screen
if config.verbose and is_logger:
    pipe.log()
    sys.stdout.flush()

# Load the Darcy flow dataset
train_loader, test_loaders, output_encoder = load_darcy_flow_small(
    n_train=config.data.n_train, batch_size=config.data.batch_size,
    positional_encoding=config.data.positional_encoding,
    test_resolutions=config.data.test_resolutions, n_tests=config.data.n_tests,
    test_batch_sizes=config.data.test_batch_sizes,
    encode_input=config.data.encode_input, encode_output=config.data.encode_output,
)
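# test_loaders is a dict keyed by resolution, so the trained model can be
# evaluated at resolutions it was never trained on -- the discretization
# invariance that neural operators are designed for.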

model = get_model(config)
model = model.to(device)

# Use distributed data parallel
if config.distributed.use_distributed:
    model = DDP(model,
                device_ids=[device.index],
                output_device=device.index,
                static_graph=True)
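# Note (PyTorch DDP): static_graph=True assumes the set of parameters used in
# forward/backward is the same at every iteration, which lets DDP apply
# additional communication optimizations.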

# Log parameter count
if is_logger:
    n_params = count_params(model)

    if config.verbose:
        print(f'\nn_params: {n_params}')
        sys.stdout.flush()

    if config.wandb.log:
        to_log = {'n_params': n_params}
        if config.n_params_baseline is not None:
            to_log['n_params_baseline'] = config.n_params_baseline
            to_log['compression_ratio'] = config.n_params_baseline / n_params
            to_log['space_savings'] = 1 - (n_params / config.n_params_baseline)
        wandb.log(to_log)
        wandb.watch(model)

# Create the optimizer
optimizer = torch.optim.Adam(model.parameters(),
                             lr=config.opt.learning_rate,
                             weight_decay=config.opt.weight_decay)

if config.opt.scheduler == 'ReduceLROnPlateau':
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=config.opt.gamma,
        patience=config.opt.scheduler_patience, mode='min')
elif config.opt.scheduler == 'CosineAnnealingLR':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.opt.scheduler_T_max)
elif config.opt.scheduler == 'StepLR':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=config.opt.step_size,
                                                gamma=config.opt.gamma)
else:
    raise ValueError(f'Got {config.opt.scheduler=}, expected one of '
                     "['ReduceLROnPlateau', 'CosineAnnealingLR', 'StepLR']")

# Create the losses
l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
if config.opt.training_loss == 'l2':
    train_loss = l2loss
elif config.opt.training_loss == 'h1':
    train_loss = h1loss
else:
    raise ValueError(f'Got training_loss={config.opt.training_loss} '
                     'but expected one of ["l2", "h1"]')
eval_losses = {'h1': h1loss, 'l2': l2loss}
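# LpLoss(d=2, p=2) measures the L2 error over 2D fields (relative by default
# in neuralop); H1Loss additionally penalizes the error in first derivatives
# (a Sobolev H1-type norm), which weights fine-scale structure more heavily.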

if config.verbose and is_logger:
    print('\n### MODEL ###\n', model)
    print('\n### OPTIMIZER ###\n', optimizer)
    print('\n### SCHEDULER ###\n', scheduler)
    print('\n### LOSSES ###')
    print(f'\n * Train: {train_loss}')
    print(f'\n * Test: {eval_losses}')
    print('\n### Beginning Training...\n')
    sys.stdout.flush()

trainer = Trainer(model, n_epochs=config.opt.n_epochs,
                  device=device,
                  amp_autocast=config.opt.amp_autocast,
                  mg_patching_levels=config.patching.levels,
                  mg_patching_padding=config.patching.padding,
                  mg_patching_stitching=config.patching.stitching,
                  wandb_log=config.wandb.log,
                  log_test_interval=config.wandb.log_test_interval,
                  log_output=config.wandb.log_output,
                  use_distributed=config.distributed.use_distributed,
                  verbose=config.verbose and is_logger)
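# Roughly, the mg_patching_* arguments control multi-grid domain-decomposition
# patching: inputs are split into padded, overlapping patches across several
# levels so that larger problems fit in memory (levels=0 disables patching);
# see the MG-TFNO work from the neuraloperator authors for details.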

trainer.train(train_loader, test_loaders,
              output_encoder,
              model,
              optimizer,
              scheduler,
              regularizer=False,
              training_loss=train_loss,
              eval_losses=eval_losses)

if config.wandb.log and is_logger:
    wandb.finish()