[wip] inference refactor
zhengzangw committed May 10, 2024
1 parent 2b6a897 commit 38b18ae
Showing 6 changed files with 29 additions and 17 deletions.
5 changes: 3 additions & 2 deletions configs/opensora-v1-2/inference/sample.py
@@ -1,5 +1,5 @@
 image_size = (240, 426)
-num_frames = 34
+num_frames = 204
 fps = 24
 frame_interval = 1

@@ -20,6 +20,7 @@
 vae = dict(
     type="VideoAutoencoderPipeline",
     from_pretrained="pretrained_models/vae-v3",
+    scale=2.5,
     micro_frame_size=17,
     vae_2d=dict(
         type="VideoAutoencoderKL",

@@ -44,5 +45,5 @@
     use_discrete_timesteps=False,
     use_timestep_transform=True,
     num_sampling_steps=30,
-    cfg_scale=4.5,
+    cfg_scale=7.0,
 )
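The two value changes in this sampling config raise the default clip length from 34 to 204 frames (204 is a multiple of the VAE's micro_frame_size of 17) and strengthen classifier-free guidance from 4.5 to 7.0. As a reminder of what cfg_scale controls, here is a minimal sketch of a classifier-free guidance step; the function and argument names are illustrative, not the scheduler code in this repository.

def guided_prediction(model, x_t, t, text_cond, null_cond, cfg_scale=7.0):
    """Illustrative classifier-free guidance: blend conditional and
    unconditional predictions with weight cfg_scale."""
    pred_cond = model(x_t, t, text_cond)      # prompt-conditioned pass
    pred_uncond = model(x_t, t, null_cond)    # null/empty-prompt pass
    # cfg_scale > 1 pushes the sample toward the text condition; raising it
    # from 4.5 to 7.0 increases prompt adherence at some cost in diversity.
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)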
20 changes: 10 additions & 10 deletions configs/opensora-v1-2/train/stage1-gc.py
@@ -32,16 +32,16 @@
     "240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)},
     # ---
     "360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
-    "512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)},
-    # ---
-    "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
-    # ---
-    "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
-    "1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
-    # ---
-    "1080p": {1: (0.1, 10)},
-    # ---
-    "2048": {1: (0.1, 5)},
+    # "512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)},
+    # # ---
+    # "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
+    # # ---
+    # "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
+    # "1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
+    # # ---
+    # "1080p": {1: (0.1, 10)},
+    # # ---
+    # "2048": {1: (0.1, 5)},
 }

 grad_checkpoint = True
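In this training config, each bucket entry pairs a frame count with (sampling probability, local batch size) for a given resolution; the commit comments out every bucket above 360p, so stage-1 training with gradient checkpointing only draws 240p and 360p clips. The hypothetical helper below shows how such a table can be read; it is a sketch, not the repository's Bucket implementation, and its truncated dictionary only repeats the first few entries above.

# Hypothetical reader for a bucket table like the one above.
bucket_config = {
    "240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12)},
    "360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6)},
}

def lookup_bucket(resolution, num_frames):
    """Return (keep_probability, batch_size) for the largest frame bucket
    not exceeding num_frames, or None if the resolution has no bucket."""
    frame_buckets = bucket_config.get(resolution)
    if not frame_buckets:
        return None
    eligible = [k for k in frame_buckets if k <= num_frames]
    if not eligible:
        return None
    prob, batch_size = frame_buckets[max(eligible)]
    if isinstance(prob, tuple):
        # Tuple probabilities such as (0.4, 0.33) carry a second value in
        # Open-Sora's bucket logic; this sketch keeps only the first.
        prob = prob[0]
    return prob, batch_size

print(lookup_bucket("240p", 68))   # -> (0.4, 24): falls into the 51-frame bucket
print(lookup_bucket("720p", 51))   # -> None: bucket disabled in this commit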
2 changes: 2 additions & 0 deletions requirements.txt
@@ -15,3 +15,5 @@ transformers
 wandb
 rotary_embedding_torch
 pandarallel
+gradio
+spaces
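gradio and spaces are new runtime dependencies, which suggests a Gradio demo (likely targeting Hugging Face Spaces) accompanying the inference refactor. A minimal sketch of such a demo follows; run_opensora is a hypothetical wrapper, not a function in this repository.

import gradio as gr

def run_opensora(prompt):
    # Hypothetical glue: a real demo would invoke the Open-Sora sampling
    # pipeline and return the path of the generated video file.
    raise NotImplementedError("wire this to scripts/inference.py")

demo = gr.Interface(
    fn=run_opensora,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Video(label="Generated video"),
    title="Open-Sora text-to-video (sketch)",
)

if __name__ == "__main__":
    demo.launch()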
17 changes: 12 additions & 5 deletions scripts/inference.py
@@ -16,10 +16,21 @@


 def main():
+    torch.set_grad_enabled(False)
+    # ======================================================
+    # 1. configs & runtime variables
+    # ======================================================
+    # == parse configs ==
     cfg = parse_configs(training=False)

+    # == device and dtype ==
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    cfg_dtype = cfg.get("dtype", "fp32")
+    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
+    dtype = to_torch_dtype(cfg.dtype)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True

     verbose = cfg.get("verbose", 2)
     print(cfg)

@@ -41,11 +52,7 @@ def main():
     # ======================================================
     # 2. runtime variables
     # ======================================================
-    torch.set_grad_enabled(False)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = to_torch_dtype(cfg.dtype)

     set_random_seed(seed=cfg.seed)
     prompts = cfg.prompt
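The refactor moves grad-disabling, the TF32 flags, and device/dtype resolution to the top of main(), directly after config parsing, and validates the dtype string before converting it. Enabling allow_tf32 lets float32 matmuls run in TF32 on Ampere-class GPUs, trading a little precision for speed. Below is a plausible sketch of the to_torch_dtype helper this setup relies on; the repository's own implementation in its utils may differ.

import torch

_DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}

def to_torch_dtype(dtype):
    """Accept a torch.dtype or one of the string aliases used in the configs."""
    if isinstance(dtype, torch.dtype):
        return dtype
    try:
        return _DTYPE_MAP[dtype]
    except KeyError:
        raise ValueError(f"Unsupported dtype string: {dtype!r}") from None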
1 change: 1 addition & 0 deletions scripts/train.py
@@ -43,6 +43,7 @@ def main():
     # == parse configs ==
     cfg = parse_configs(training=True)

+    # == device and dtype ==
     assert torch.cuda.is_available(), "Training currently requires at least one GPU."
     cfg_dtype = cfg.get("dtype", "bf16")
     assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
1 change: 1 addition & 0 deletions scripts/train_vae.py
@@ -38,6 +38,7 @@ def main():
     # == parse configs ==
     cfg = parse_configs(training=True)

+    # == device and dtype ==
     assert torch.cuda.is_available(), "Training currently requires at least one GPU."
     cfg_dtype = cfg.get("dtype", "bf16")
     assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
