[feat] notebook for inference
zhengzangw committed May 14, 2024
1 parent 5929cae commit f795a09
Showing 6 changed files with 387 additions and 5 deletions.
352 changes: 352 additions & 0 deletions notebooks/inference.ipynb
@@ -0,0 +1,352 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference for OpenSora"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define global variables."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# global variables\n",
"ROOT = \"..\"\n",
"cfg_path = f\"{ROOT}/configs/opensora-v1-2/inference/sample.py\"\n",
"ckpt_path = \"/home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/\"\n",
"vae_path = f\"{ROOT}/pretrained_models/vae-pipeline\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import necessary libraries and load the models."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pprint import pformat\n",
"\n",
"import colossalai\n",
"import torch\n",
"import torch.distributed as dist\n",
"from colossalai.cluster import DistCoordinator\n",
"from mmengine.runner import set_random_seed\n",
"from tqdm.notebook import tqdm\n",
"\n",
"from opensora.acceleration.parallel_states import set_sequence_parallel_group\n",
"from opensora.datasets import save_sample, is_img\n",
"from opensora.datasets.aspect import get_image_size, get_num_frames\n",
"from opensora.models.text_encoder.t5 import text_preprocessing\n",
"from opensora.registry import MODELS, SCHEDULERS, build_module\n",
"from opensora.utils.config_utils import read_config\n",
"from opensora.utils.inference_utils import (\n",
" append_generated,\n",
" apply_mask_strategy,\n",
" collect_references_batch,\n",
" extract_json_from_prompts,\n",
" extract_prompts_loop,\n",
" get_save_path_name,\n",
" load_prompts,\n",
" prepare_multi_resolution_info,\n",
")\n",
"from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"42"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.set_grad_enabled(False)\n",
"\n",
"# == parse configs ==\n",
"cfg = read_config(cfg_path)\n",
"cfg.model.from_pretrained = ckpt_path\n",
"cfg.vae.from_pretrained = vae_path\n",
"\n",
"# == device and dtype ==\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"cfg_dtype = cfg.get(\"dtype\", \"fp32\")\n",
"assert cfg_dtype in [\"fp16\", \"bf16\", \"fp32\"], f\"Unknown mixed precision {cfg_dtype}\"\n",
"dtype = to_torch_dtype(cfg.get(\"dtype\", \"bf16\"))\n",
"torch.backends.cuda.matmul.allow_tf32 = True\n",
"torch.backends.cudnn.allow_tf32 = True\n",
"\n",
"set_random_seed(seed=cfg.get(\"seed\", 1024))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "29ca38c42a38453aa65784e1ee89a61a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# == build text-encoder and vae ==\n",
"text_encoder = build_module(cfg.text_encoder, MODELS, device=device)\n",
"vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()\n",
"\n",
"# == build diffusion model ==\n",
"input_size = (None, None, None)\n",
"latent_size = vae.get_latent_size(input_size)\n",
"model = (\n",
" build_module(\n",
" cfg.model,\n",
" MODELS,\n",
" input_size=latent_size,\n",
" in_channels=vae.out_channels,\n",
" caption_channels=text_encoder.output_dim,\n",
" model_max_length=text_encoder.model_max_length,\n",
" )\n",
" .to(device, dtype)\n",
" .eval()\n",
")\n",
"text_encoder.y_embedder = model.y_embedder # HACK: for classifier-free guidance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define inference function."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"start_idx = 0\n",
"multi_resolution = cfg.get(\"multi_resolution\", None)\n",
"batch_size = cfg.get(\"batch_size\", 1)\n",
"\n",
"\n",
"def inference(\n",
" prompts=cfg.get(\"prompt\", None),\n",
" image_size=None,\n",
" num_frames=None,\n",
" resolution=None,\n",
" aspect_ratio=None,\n",
" mask_strategy=None,\n",
" reference_path=None,\n",
" num_sampling_steps=None,\n",
" cfg_scale=None,\n",
" seed=None,\n",
" fps=cfg.fps,\n",
" num_sample=cfg.get(\"num_sample\", 1),\n",
" loop=cfg.get(\"loop\", 1),\n",
" condition_frame_length=cfg.get(\"condition_frame_length\", 5),\n",
" align=cfg.get(\"align\", None),\n",
" save_dir=os.path.join(ROOT, cfg.save_dir),\n",
" sample_name=cfg.get(\"sample_name\", None),\n",
" prompt_as_path=cfg.get(\"prompt_as_path\", False),\n",
"):\n",
" global start_idx\n",
" os.makedirs(save_dir, exist_ok=True)\n",
" if seed is not None:\n",
" set_random_seed(seed=seed)\n",
" if not isinstance(prompts, list):\n",
" prompts = [prompts]\n",
" if mask_strategy is None:\n",
" mask_strategy = [\"\"] * len(prompts)\n",
" if reference_path is None:\n",
" reference_path = [\"\"] * len(prompts)\n",
" save_fps = cfg.fps // cfg.get(\"frame_interval\", 1)\n",
" if num_sampling_steps is not None:\n",
" cfg.scheduler[\"num_sampling_steps\"] = num_sampling_steps\n",
" if cfg_scale is not None:\n",
" cfg.scheduler[\"scale\"] = cfg_scale\n",
" scheduler = build_module(cfg.scheduler, SCHEDULERS)\n",
" ret_path = []\n",
"\n",
" # == prepare video size ==\n",
" if image_size is None:\n",
" assert (\n",
" resolution is not None and aspect_ratio is not None\n",
" ), \"resolution and aspect_ratio must be provided if image_size is not provided\"\n",
" image_size = get_image_size(resolution, aspect_ratio)\n",
" num_frames = get_num_frames(cfg.num_frames)\n",
" input_size = (num_frames, *image_size)\n",
" latent_size = vae.get_latent_size(input_size)\n",
"\n",
" # == Iter over all samples ==\n",
" for i in tqdm(range(0, len(prompts), batch_size)):\n",
" # == prepare batch prompts ==\n",
" batch_prompts = prompts[i : i + batch_size]\n",
" ms = mask_strategy[i : i + batch_size]\n",
" refs = reference_path[i : i + batch_size]\n",
"\n",
" batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)\n",
" refs = collect_references_batch(refs, vae, image_size)\n",
"\n",
" # == multi-resolution info ==\n",
" model_args = prepare_multi_resolution_info(\n",
" multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype\n",
" )\n",
"\n",
" # == Iter over number of sampling for one prompt ==\n",
" for k in range(num_sample):\n",
" # == prepare save paths ==\n",
" save_paths = [\n",
" get_save_path_name(\n",
" save_dir,\n",
" sample_name=sample_name,\n",
" sample_idx=start_idx + idx,\n",
" prompt=batch_prompts[idx],\n",
" prompt_as_path=prompt_as_path,\n",
" num_sample=num_sample,\n",
" k=k,\n",
" )\n",
" for idx in range(len(batch_prompts))\n",
" ]\n",
"\n",
" # NOTE: Skip if the sample already exists\n",
" # This is useful for resuming sampling VBench\n",
" if prompt_as_path and all_exists(save_paths):\n",
" continue\n",
"\n",
" # == Iter over loop generation ==\n",
" video_clips = []\n",
" for loop_i in range(loop):\n",
" batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)\n",
" batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]\n",
"\n",
" # == loop ==\n",
" if loop_i > 0:\n",
" refs, ms = append_generated(vae, video_clips[-1], refs, ms, loop_i, condition_frame_length)\n",
"\n",
" # == sampling ==\n",
" z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)\n",
" masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)\n",
" samples = scheduler.sample(\n",
" model,\n",
" text_encoder,\n",
" z=z,\n",
" prompts=batch_prompts_cleaned,\n",
" device=device,\n",
" additional_args=model_args,\n",
" progress=False,\n",
" mask=masks,\n",
" )\n",
" samples = vae.decode(samples.to(dtype), num_frames=num_frames)\n",
" video_clips.append(samples)\n",
"\n",
" # == save samples ==\n",
" if is_main_process():\n",
" for idx, batch_prompt in enumerate(batch_prompts):\n",
" save_path = save_paths[idx]\n",
" video = [video_clips[i][idx] for i in range(loop)]\n",
" for i in range(1, loop):\n",
" video[i] = video[i][:, condition_frame_length:]\n",
" video = torch.cat(video, dim=1)\n",
" path = save_sample(\n",
" video,\n",
" fps=save_fps,\n",
" save_path=save_path,\n",
" verbose=False,\n",
" )\n",
" ret_path.append(path)\n",
" start_idx += len(batch_prompts)\n",
" return ret_path"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Video, Image, display\n",
"\n",
"def display_results(paths):\n",
" for path in paths:\n",
" if is_img(path):\n",
" display(Image(path))\n",
" else:\n",
" display(Video(path, embed=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"paths = inference(\n",
" [\"a man.\", \"a woman\"],\n",
" resolution=\"240p\",\n",
" aspect_ratio=\"1:1\",\n",
" num_frames=\"1x\",\n",
" num_sampling_steps=30,\n",
" cfg_scale=7.0,\n",
")\n",
"display_results(paths)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "opensora",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
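The final cell above only exercises plain text-to-video sampling; the same inference() helper also accepts reference_path, mask_strategy, loop, and condition_frame_length for image-conditioned and extended generation. A minimal sketch of such a call follows; the reference image path is a placeholder and the "0" mask-strategy string (condition the first frame of the first loop on the first reference) is assumed from Open-Sora's conventions rather than taken from this commit:

paths = inference(
    ["a wave crashing on the shore."],
    resolution="240p",
    aspect_ratio="16:9",
    num_frames="2x",
    reference_path=["assets/images/condition/wave.png"],  # hypothetical reference image
    mask_strategy=["0"],       # assumed format: condition the first frame on the reference
    loop=2,                    # generate a second clip conditioned on the tail of the first
    condition_frame_length=5,  # frames reused as conditioning between loops
    num_sampling_steps=30,
    cfg_scale=7.0,
)
display_results(paths)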
2 changes: 1 addition & 1 deletion opensora/datasets/__init__.py
@@ -1,3 +1,3 @@
from .dataloader import prepare_dataloader, prepare_variable_dataloader
from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset
from .utils import get_transforms_image, get_transforms_video, save_sample
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
15 changes: 15 additions & 0 deletions opensora/datasets/aspect.py
@@ -468,3 +468,18 @@ def get_image_size(resolution, ar_ratio):
rs_dict = ASPECT_RATIOS[resolution][1]
assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
return rs_dict[ar_key]


NUM_FRAMES_MAP = {
"1x": 51,
"2x": 102,
"4x": 204,
"8x": 408,
}


def get_num_frames(num_frames):
if num_frames in NUM_FRAMES_MAP:
return NUM_FRAMES_MAP[num_frames]
else:
return int(num_frames)
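get_num_frames() maps the shorthand multipliers used by the notebook ("1x", "2x", ...) to absolute frame counts and falls back to int() for anything else. An illustrative check, not part of the diff:

from opensora.datasets.aspect import get_num_frames

print(get_num_frames("2x"))  # 102, looked up in NUM_FRAMES_MAP
print(get_num_frames("51"))  # 51, any other value is parsed with int()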
10 changes: 10 additions & 0 deletions opensora/datasets/utils.py
@@ -27,6 +27,16 @@
)


def is_img(path):
ext = os.path.splitext(path)[-1].lower()
return ext in IMG_EXTENSIONS


def is_vid(path):
ext = os.path.splitext(path)[-1].lower()
return ext in VID_EXTENSIONS


def is_url(url):
return re.match(regex, url) is not None

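is_img() is what the notebook's display_results() uses to choose between Image and Video display; is_vid() is the matching predicate for videos. A quick illustrative check, assuming IMG_EXTENSIONS and VID_EXTENSIONS contain the usual suffixes such as ".png" and ".mp4":

from opensora.datasets import is_img, is_vid

print(is_img("outputs/sample_0000.png"))  # True
print(is_vid("outputs/sample_0000.mp4"))  # True
print(is_img("outputs/sample_0000.mp4"))  # False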