Skip to content

Commit

Permalink
[feat] update eval
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengzangw committed Apr 17, 2024
1 parent 23f6dd8 commit 1e6df55
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 59 deletions.
7 changes: 6 additions & 1 deletion opensora/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ def parse_args(training=False):
# Inference
# ======================================================
if not training:
# output
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
parser.add_argument("--sample-name", default=None, type=str, help="sample name, default is sample_idx")

# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")

# image/video
parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
Expand Down Expand Up @@ -71,6 +74,8 @@ def merge_args(cfg, args, training=False):
if "prompt" not in cfg or cfg["prompt"] is None:
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
cfg["prompt"] = load_prompts(cfg["prompt_path"])
if "sample_name" not in cfg:
cfg["sample_name"] = None
else:
# Training only
if args.data_path is not None:
Expand Down
22 changes: 14 additions & 8 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ def main():
print(cfg)

# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()

if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
if os.environ.get("WORLD_SIZE", None):
use_dist = True
colossalai.launch_from_torch({})
coordinator = DistCoordinator()

if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
else:
use_dist = False
enable_sequence_parallelism = False

# ======================================================
Expand Down Expand Up @@ -91,6 +96,7 @@ def main():
# 4. inference
# ======================================================
sample_idx = 0
sample_name = cfg.sample_name if cfg.sample_name is not None else "sample"
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)

Expand All @@ -112,10 +118,10 @@ def main():
)
samples = vae.decode(samples.to(dtype))

if coordinator.is_master():
if not use_dist or coordinator.is_master():
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts[idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
sample_idx += 1

Expand Down
16 changes: 16 additions & 0 deletions scripts/misc/sample.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
set -x;

CUDA_VISIBLE_DEVICES=7
CMD="python scripts/inference.py configs/opensora-v1-1/inference/sample.py"
CKPT="~/lishenggui/epoch0-global_step8500"
OUTPUT="./outputs/samples_s1_8500"

# 1. image
# 1.1 1024x1024
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --image-size 1024 1024 --sample-name pixart_1024x1024_1

# 1.2 512x512

# 1.3 240x426

# 1.4 720p multi-resolution
65 changes: 15 additions & 50 deletions scripts/search_bs.py → scripts/misc/search_bs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import model_sharding
from opensora.utils.config_utils import merge_args, parse_configs
from opensora.utils.misc import (
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from opensora.utils.misc import format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, update_ema


Expand Down Expand Up @@ -66,34 +61,22 @@ class BColors:
def parse_configs():
parser = argparse.ArgumentParser()
parser.add_argument("config", help="model config file path")
parser.add_argument(
"-o", "--output", help="output config file path", default="output_config.py"
)
parser.add_argument("-o", "--output", help="output config file path", default="output_config.py")

parser.add_argument("--seed", default=42, type=int, help="generation seed")
parser.add_argument(
"--ckpt-path",
type=str,
help="path to model ckpt; will overwrite cfg.ckpt_path if specified",
)
parser.add_argument(
"--data-path", default=None, type=str, help="path to data csv", required=True
)
parser.add_argument("--data-path", default=None, type=str, help="path to data csv", required=True)
parser.add_argument("--warmup-steps", default=1, type=int, help="warmup steps")
parser.add_argument("--active-steps", default=1, type=int, help="active steps")
parser.add_argument(
"--base-resolution", default="240p", type=str, help="base resolution"
)
parser.add_argument("--base-resolution", default="240p", type=str, help="base resolution")
parser.add_argument("--base-frames", default=128, type=int, help="base frames")
parser.add_argument(
"--batch-size-start", default=2, type=int, help="batch size start"
)
parser.add_argument(
"--batch-size-end", default=256, type=int, help="batch size end"
)
parser.add_argument(
"--batch-size-step", default=2, type=int, help="batch size step"
)
parser.add_argument("--batch-size-start", default=2, type=int, help="batch size start")
parser.add_argument("--batch-size-end", default=256, type=int, help="batch size end")
parser.add_argument("--batch-size-step", default=2, type=int, help="batch size step")
args = parser.parse_args()
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args, training=True)
Expand All @@ -116,9 +99,7 @@ def main():
# ======================================================
cfg, args = parse_configs()
print(cfg)
assert (
cfg.dataset.type == "VariableVideoTextDataset"
), "Only VariableVideoTextDataset is supported"
assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported"

# ======================================================
# 2. runtime variables & colossalai launch
Expand Down Expand Up @@ -223,10 +204,7 @@ def main():
model_sharding(ema)

buckets = [
(res, f)
for res, d in cfg.bucket_config.items()
for f, (p, bs) in d.items()
if bs is not None and p > 0.0
(res, f) for res, d in cfg.bucket_config.items() for f, (p, bs) in d.items() if bs is not None and p > 0.0
]
output_bucket_cfg = deepcopy(cfg.bucket_config)
# find the base batch size
Expand All @@ -248,15 +226,11 @@ def main():
optimizer,
ema,
)
update_bucket_config_bs(
output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size
)
update_bucket_config_bs(output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size)
coordinator.print_on_master(
f"{BColors.OKBLUE}Base resolution: {args.base_resolution}, Base frames: {args.base_frames}, Batch size: {base_batch_size}, Base step time: {base_step_time}{BColors.ENDC}"
)
result_table = [
f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"
]
result_table = [f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"]
for resolution, frames in buckets:
try:
batch_size, step_time = benchmark(
Expand All @@ -280,9 +254,7 @@ def main():
f"{BColors.OKBLUE}Resolution: {resolution}, Frames: {frames}, Batch size: {batch_size}, Step time: {step_time}{BColors.ENDC}"
)
update_bucket_config_bs(output_bucket_cfg, resolution, frames, batch_size)
result_table.append(
f"{resolution}, {frames}, {batch_size}, {step_time:.2f}"
)
result_table.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}")
except RuntimeError:
pass
result_table = "\n".join(result_table)
Expand Down Expand Up @@ -367,10 +339,7 @@ def run_step(bs) -> float:
raise RuntimeError("No valid batch size found")
if target_step_time is None:
# find the fastest batch size
throughputs = [
batch_size / step_time
for step_time, batch_size in zip(step_times, batch_sizes)
]
throughputs = [batch_size / step_time for step_time, batch_size in zip(step_times, batch_sizes)]
max_throughput = max(throughputs)
target_batch_size = batch_sizes[throughputs.index(max_throughput)]
step_time = step_times[throughputs.index(max_throughput)]
Expand Down Expand Up @@ -419,13 +388,9 @@ def train(
**dataloader_args,
)
dataloader_iter = iter(dataloader)
num_steps_per_epoch = (
dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
)
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()

assert (
num_steps_per_epoch >= total_steps
), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
assert num_steps_per_epoch >= total_steps, f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
duration = 0
# this is essential for the first iteration after OOM
optimizer._grad_store.reset_all_gradients()
Expand Down

0 comments on commit 1e6df55

Please sign in to comment.