forked from yzslab/gaussian-splatting-lightning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
38 lines (33 loc) · 1.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# main.py
from internal.cli import CLI
from jsonargparse import lazy_instance
from internal.gaussian_splatting import GaussianSplatting
from internal.dataset import DataModule
from internal.callbacks import SaveGaussian
import lightning.pytorch.loggers
def cli_main():
    """Build the Gaussian Splatting command-line interface and run it.

    Wires ``GaussianSplatting`` (the model) and ``DataModule`` (the data)
    into ``internal.cli.CLI`` with project-wide defaults:

    * global seed 42,
    * a single GPU with ``strategy="auto"``,
    * one sanity-validation step, then training bounded by 30 000 steps,
    * Lightning's own checkpointing disabled, with a ``SaveGaussian``
      callback registered (presumably to persist the trained model in its
      own format -- confirm in ``internal/callbacks.py``),
    * Lightning's distributed sampler disabled (a custom DDP sampler is
      used instead, per the original comment).

    NOTE(review): ``CLI`` is presumably a ``LightningCLI`` subclass (its
    keyword arguments match ``LightningCLI``'s). With ``LightningCLI``'s
    default ``run=True``, constructing it parses the command line and
    executes the requested subcommand. That is why nothing such as
    ``trainer.fit()`` is called afterwards. Confirm in ``internal/cli.py``.

    Returns:
        The constructed ``CLI`` instance, so programmatic callers can
        inspect it after the run. Callers that ignore the return value
        behave exactly as before.
    """
    return CLI(
        GaussianSplatting,
        DataModule,
        seed_everything_default=42,
        # The model is expected to set up its own optimizers; the CLI should
        # not inject default optimizer/scheduler options (hedged: LightningCLI
        # semantics of this flag -- confirm against internal.cli.CLI).
        auto_configure_optimizers=False,
        trainer_defaults={
            "accelerator": "gpu",
            "strategy": "auto",
            "devices": 1,
            # "logger": "TensorBoardLogger",
            "num_sanity_val_steps": 1,
            # "max_epochs": -1,
            "max_steps": 30_000,
            # Use a custom DDP sampler instead of Lightning's.
            "use_distributed_sampler": False,
            # Checkpointing is disabled; model output is left to the
            # SaveGaussian callback below.
            "enable_checkpointing": False,
            "callbacks": [
                lazy_instance(SaveGaussian),
            ],
        },
        # Allow re-running into an existing output directory without the
        # saved-config step refusing to overwrite (hedged: LightningCLI
        # SaveConfigCallback semantics).
        save_config_kwargs={"overwrite": True},
    )
# Script entry point. The CLI is built inside cli_main() rather than at
# module level, so importing this module has no side effects. Constructing
# the CLI is expected to run the chosen subcommand itself, so do not call
# fit() (or any other trainer method) here.
if __name__ == "__main__":
    cli_main()