Skip to content

Commit

Permalink
[feat] update bs search and loss eval
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengzangw committed May 15, 2024
1 parent fde97d6 commit 5657952
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 29 deletions.
53 changes: 37 additions & 16 deletions configs/opensora-v1-2/misc/bs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,46 @@
base = ("512", "408")
base_step_time = 12
bucket_config = {
# "144p": {1: (100, 50), 51: (30, 20), 102: (20, 10), 204: (8, 4), 408: (4, 4)},
# # ---
# "240p": {1: (100, 20), 51: (24, 5), 102: (12, 4), 204: (4, 2), 408: (2, 1)},
"144p": {
1: (100, 300),
51: (30, 100),
102: (20, 100),
204: (8, 20),
408: (4, 10),
},
# ---
"240p": {
1: (100, 100),
51: (24, 10),
102: (12, 10),
204: (4, 8),
408: (2, 8),
},
# ---
# "512": {
# 1: (141, 0),
# 51: (8, 0),
# 102: (4, 0),
# 204: (2, 0),
# 408: (1, 0),
# },
# ---
"512": {
# 1: (141, 0),
51: (9, 4),
102: (6, 2),
204: (2, 1),
# 408: (1, 0),
"480p": {
1: (50, 50),
51: (6, 6),
102: (3, 3),
204: (1, 2),
},
# ---
# "480p": {1: (40, 10), 51: (6, 2), 102: (3, 2), 204: (1, 1)},
# # ---
# "1024": {1: (20, 10), 51: (2, 1), 102: (1, 1)},
# # ---
# "1080p": {1: (10, 5)},
# # ---
# "2048": {1: (5, 2)},
"1024": {
1: (20, 20),
51: (2, 2),
102: (1, 1),
},
# ---
"1080p": {1: (10, 10)},
# ---
"2048": {1: (5, 5)},
}

# Acceleration settings
Expand Down
18 changes: 9 additions & 9 deletions configs/opensora-v1-2/misc/eval_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,20 @@
transform_name="resize_crop",
)

# Placeholder only: during evaluation, a separate dataset is created for each resolution.
bucket_config = { # 20s/it
"144p": {1: (1.0, 100), 51: (1.0, 30), 102: ((1.0, 0.33), 20), 204: ((1.0, 0.1), 8), 408: ((1.0, 0.1), 4)},
bucket_config = {
"144p": {1: (None, 100), 51: (None, 30), 102: (None, 20), 204: (None, 8), 408: (None, 4)},
# ---
"240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)},
"240p": {1: (None, 100), 51: (None, 24), 102: (None, 12), 204: (None, 4), 408: (None, 2)},
# ---
"360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
"360p": {1: (None, 60), 51: (None, 12), 102: (None, 6), 204: (None, 2), 408: (None, 1)},
# ---
"480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
"480p": {1: (None, 40), 51: (None, 6), 102: (None, 3), 204: (None, 1)},
# ---
"720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
"720p": {1: (None, 20), 51: (None, 2), 102: (None, 1)},
# ---
"1080p": {1: (0.1, 10)},
"1080p": {1: (None, 10)},
# ---
"2048": {1: (0.1, 5)},
"2048": {1: (None, 5)},
}

# Model settings
Expand All @@ -39,6 +38,7 @@
from_pretrained="pretrained_models/vae-pipeline",
micro_frame_size=17,
micro_batch_size=4,
local_files_only=True,
)
text_encoder = dict(
type="t5",
Expand Down
10 changes: 10 additions & 0 deletions eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ bash eval/sample.sh /path/to/ckpt -2a
bash eval/launch.sh /path/to/ckpt
```

## Rectified Flow Loss

```bash
CUDA_VISIBLE_DEVICES=2 torchrun --standalone --nproc_per_node 1 scripts/misc/eval_loss.py configs/opensora-v1-2/misc/eval_loss.py --data-path /mnt/nfs-207/sora_data/meta/img_1k.csv --ckpt-path /home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/

CUDA_VISIBLE_DEVICES=3 torchrun --standalone --nproc_per_node 1 scripts/misc/eval_loss.py configs/opensora-v1-2/misc/eval_loss.py --data-path /mnt/nfs-207/sora_data/meta/vid_100.csv --ckpt-path /home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/

CUDA_VISIBLE_DEVICES=3 torchrun --standalone --nproc_per_node 1 scripts/misc/eval_loss.py configs/opensora-v1-2/misc/eval_loss.py --data-path /mnt/nfs-207/sora_data/meta/vid_100.csv --ckpt-path /home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/ --resolution 720p
```

## VBench

[VBench](https://github.com/Vchitect/VBench) is a benchmark for short text-to-video generation. We provide a script for easily generating the samples required by VBench.
Expand Down
2 changes: 1 addition & 1 deletion opensora/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def parse_args(training=False):
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")

# ======================================================
# Inference
Expand All @@ -50,7 +51,6 @@ def parse_args(training=False):
parser.add_argument("--fps", default=None, type=int, help="fps")
parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size")
parser.add_argument("--frame-interval", default=None, type=int, help="frame interval")
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
parser.add_argument("--aspect-ratio", default=None, type=float, help="aspect ratio")

# hyperparameters
Expand Down
5 changes: 4 additions & 1 deletion scripts/misc/eval_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def main():
# ======================================================
# start evaluation, prepare a dataset every time in the loop
bucket_config = cfg.bucket_config
if cfg.get("resolution", None) is not None:
bucket_config = {cfg.resolution: bucket_config[cfg.resolution]}
assert bucket_config is not None, "bucket_config is required for evaluation"
logger.info("Evaluating bucket_config: %s", bucket_config)

def build_dataset(resolution, num_frames, batch_size):
bucket_config = {resolution: {num_frames: (1.0, batch_size)}}
Expand Down Expand Up @@ -118,7 +121,7 @@ def build_dataset(resolution, num_frames, batch_size):
continue

evaluation_t_losses = []
for t in torch.linspace(0, scheduler.num_timesteps, cfg.get("num_eval_timesteps", 10)):
for t in torch.linspace(0, scheduler.num_timesteps, cfg.get("num_eval_timesteps", 10) + 2)[1:-1]:
loss_t = 0.0
num_samples = 0
dataloader_iter = iter(dataloader)
Expand Down
7 changes: 5 additions & 2 deletions scripts/misc/search_bs.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,13 @@ def benchmark(resolution, num_frames, lower_bound, upper_bound, ref_step_time=No
return target_batch_size, target_step_time

# == build bucket ==
output_bucket_cfg = deepcopy(cfg.bucket_config)
bucket_config = cfg.bucket_config
output_bucket_cfg = deepcopy(bucket_config)
if cfg.get("resolution", None) is not None:
bucket_config = {cfg.resolution: bucket_config[cfg.resolution]}
buckets = {
(resolution, num_frames): (max(guess_bs - variance, 1), guess_bs + variance)
for resolution, t_bucket in cfg.bucket_config.items()
for resolution, t_bucket in bucket_config.items()
for num_frames, (guess_bs, variance) in t_bucket.items()
}

Expand Down

0 comments on commit 5657952

Please sign in to comment.