Skip to content

Commit

Permalink
format and some fix (hpcaitech#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengzangw authored Mar 30, 2024
1 parent 917a026 commit 9da2087
Show file tree
Hide file tree
Showing 25 changed files with 94 additions and 166 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ The Open-Sora project welcomes any constructive contribution from the community

## Development Environment Setup

To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation.
To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation.

You can refer to the [Installation Section](./README.md#installation) and replace `pip install -v .` with `pip install -v -e .`.

Expand Down
4 changes: 1 addition & 3 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
for the public:
wiki.creativecommons.org/Considerations_for_licensees

=======================================================================
Expand Down Expand Up @@ -677,5 +677,3 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ conda create -n opensora python=3.10
conda activate opensora

# install torch
# the command below is for CUDA 12.1, choose install commands from
# the command below is for CUDA 12.1, choose install commands from
# https://pytorch.org/get-started/locally/ based on your own CUDA version
pip install torch torchvision

Expand Down
2 changes: 1 addition & 1 deletion configs/opensora-v1-1/train/Vx360p.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
}

# Define acceleration
num_workers = 0
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
Expand Down
2 changes: 1 addition & 1 deletion configs/opensora/inference-long/16x512x512-extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
prompt_path = None
prompt = [
"Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.",
]

loop = 10
Expand Down
4 changes: 2 additions & 2 deletions configs/opensora/inference/16x256x256.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
type="iddpm",
num_sampling_steps=100,
cfg_scale=7.0,
cfg_channel=3, # or None
cfg_channel=3, # or None
)
dtype = "fp16"

# Condition
prompt_path = "./assets/texts/t2v_samples.txt"
prompt = None # prompt has higher priority than prompt_path
prompt = None # prompt has higher priority than prompt_path

# Others
batch_size = 1
Expand Down
2 changes: 1 addition & 1 deletion configs/opensora/inference/16x512x512.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
time_scale=1.0,
enable_flashattn=True,
enable_layernorm_kernel=True,
from_pretrained="PRETRAINED_MODEL"
from_pretrained="PRETRAINED_MODEL",
)
vae = dict(
type="VideoAutoencoderKL",
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_CN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
conda create -n opensora python=3.10

# install torch
# the command below is for CUDA 12.1, choose install commands from
# the command below is for CUDA 12.1, choose install commands from
# https://pytorch.org/get-started/locally/ based on your own CUDA version
pip3 install torch torchvision

Expand Down
1 change: 0 additions & 1 deletion opensora/datasets/aspect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math


# Ours


Expand Down
16 changes: 4 additions & 12 deletions opensora/datasets/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,11 @@ def __init__(self, bucket_config):
# wrap config with OrderedDict
bucket_probs = OrderedDict()
bucket_bs = OrderedDict()
bucket_names = sorted(
bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True
)
bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True)
for key in bucket_names:
bucket_time_names = sorted(
bucket_config[key].keys(), key=lambda x: x, reverse=True
)
bucket_probs[key] = OrderedDict(
{k: bucket_config[key][k][0] for k in bucket_time_names}
)
bucket_bs[key] = OrderedDict(
{k: bucket_config[key][k][1] for k in bucket_time_names}
)
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})

# first level: HW
num_bucket = 0
Expand Down
33 changes: 0 additions & 33 deletions opensora/datasets/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from .bucket import Bucket
from .sampler import DistributedVariableVideoSampler, VariableVideoBatchSampler


Expand Down Expand Up @@ -98,38 +97,6 @@ def seed_worker(worker_id):
)


class _VariableVideoBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that groups sampled indices into per-bucket batches.

    Every index drawn from ``sampler`` is routed to a bucket selected by
    ``Bucket.get_bucket_id`` from the sample's ``(T, H, W)`` info. A batch is
    yielded as soon as a bucket's buffer reaches that bucket's batch size.
    Once the sampler is exhausted, non-empty leftover buffers are yielded
    unless ``drop_last`` is set.

    Args:
        sampler: iterable producing dataset indices.
        batch_size: stored on the instance; the per-bucket batch size used
            while iterating comes from ``Bucket.get_batch_size``.
        drop_last: when true, leftover partial buffers are not yielded.
        dataset: object exposing ``frame_interval``, ``get_data_info(idx)``
            and ``set_data_info(idx, t, h, w)``.
        buckect_config: configuration forwarded to ``Bucket``.
    """

    def __init__(self, sampler, batch_size, drop_last, dataset, buckect_config):
        # The parent initializer is not invoked; attributes are set directly.
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.dataset = dataset
        self.frame_interval = self.dataset.frame_interval
        self.bucket = Bucket(buckect_config)
        self.bucket.info_bucket(self.dataset, self.frame_interval)

    def __iter__(self):
        """Yield lists of dataset indices, one list per completed bucket batch."""
        for index in self.sampler:
            num_frames, height, width = self.dataset.get_data_info(index)
            bucket_id = self.bucket.get_bucket_id(num_frames, height, width, self.frame_interval)
            if bucket_id is None:
                # No bucket accepts this sample: it is skipped.
                continue

            # Record the bucket's resolved (T, H, W) on the dataset for this index.
            real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
            self.dataset.set_data_info(index, real_t, real_h, real_w)

            pending = self.bucket[bucket_id]
            pending.append(index)
            if len(pending) < self.bucket.get_batch_size(bucket_id):
                continue
            # NOTE(review): the yielded list is the bucket's own buffer object;
            # whether set_empty() rebinds or clears it in place is not visible here.
            yield pending
            self.bucket.set_empty(bucket_id)

        # Flush whatever is still buffered once the sampler is exhausted.
        for outer_key, mid_level in self.bucket.bucket.items():
            for mid_key, inner_level in mid_level.items():
                for inner_key, leftover in inner_level.items():
                    if len(leftover) == 0 or self.drop_last:
                        continue
                    yield leftover
                    self.bucket.set_empty((outer_key, mid_key, inner_key))


def prepare_variable_dataloader(
dataset,
batch_size,
Expand Down
11 changes: 10 additions & 1 deletion opensora/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ def __init__(
"video": get_transforms_video(transform_name, image_size),
}

def _print_data_number(self):
    """Print how many entries of ``self.data["path"]`` are videos and images.

    A path counts as a video when ``self.get_type(path)`` returns ``"video"``;
    every other path is counted as an image.
    """
    # One flag per path, in iteration order: True for video, False otherwise.
    video_flags = [self.get_type(path) == "video" for path in self.data["path"]]
    num_videos = sum(video_flags)
    num_images = len(video_flags) - num_videos
    print(f"Dataset contains {num_videos} videos and {num_images} images.")

def get_type(self, path):
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
Expand Down Expand Up @@ -148,7 +158,6 @@ def getitem(self, index):
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}

def __getitem__(self, index):
return self.getitem(index)
for _ in range(10):
try:
return self.getitem(index)
Expand Down
39 changes: 18 additions & 21 deletions opensora/datasets/sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import warnings
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pprint import pprint
from typing import Iterator, List, Optional, Tuple

import torch
Expand Down Expand Up @@ -43,9 +44,7 @@ def __iter__(self) -> Iterator[Tuple[tuple, int]]:
# group by bucket
for i in range(len(self.dataset)):
t, h, w = self.dataset.get_data_info(i)
bucket_id = self.bucket.get_bucket_id(
t, h, w, self.dataset.frame_interval, g
)
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
if bucket_id is None:
continue
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
Expand All @@ -56,12 +55,8 @@ def __iter__(self) -> Iterator[Tuple[tuple, int]]:
# shuffle
if self.shuffle:
# sort buckets
bucket_indices = torch.randperm(
len(bucket_sample_dict), generator=g
).tolist()
bucket_order = {
k: bucket_indices[i] for i, k in enumerate(bucket_sample_dict)
}
bucket_indices = torch.randperm(len(bucket_sample_dict), generator=g).tolist()
bucket_order = {k: bucket_indices[i] for i, k in enumerate(bucket_sample_dict)}
# sort samples in each bucket
for k, v in bucket_sample_dict.items():
sample_indices = torch.randperm(len(v), generator=g).tolist()
Expand Down Expand Up @@ -90,11 +85,7 @@ def __iter__(self) -> Iterator[Tuple[tuple, int]]:
if self.verbose:
self._print_bucket_info(bucket_sample_dict)
if self.shuffle:
bucket_sample_dict = OrderedDict(
sorted(
bucket_sample_dict.items(), key=lambda item: bucket_order[item[0]]
)
)
bucket_sample_dict = OrderedDict(sorted(bucket_sample_dict.items(), key=lambda item: bucket_order[item[0]]))
# iterate
found_last_bucket = self.last_bucket_id is None
for k, v in bucket_sample_dict.items():
Expand Down Expand Up @@ -126,13 +117,21 @@ def __len__(self) -> int:
def _print_bucket_info(self, bucket_sample_dict: dict) -> None:
total_samples = 0
num_dict = {}
num_aspect_dict = defaultdict(int)
num_hwt_dict = defaultdict(int)
for k, v in bucket_sample_dict.items():
size = len(v) * self.num_replicas
total_samples += size
num_dict[k] = size
print(
f"Total training samples: {total_samples}, num buckets: {len(num_dict)}, bucket samples: {num_dict}"
)
num_aspect_dict[k[-1]] += size
num_hwt_dict[k[:-1]] += size
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
print("Bucket samples:")
pprint(num_dict)
print("Bucket samples by HxWxT:")
pprint(num_hwt_dict)
print("Bucket samples by aspect ratio:")
pprint(num_aspect_dict)

def state_dict(self) -> dict:
# users must ensure bucket config is the same
Expand Down Expand Up @@ -175,9 +174,7 @@ def __iter__(self) -> Iterator[List[int]]:
cur_sample_indices = [sample_idx]
else:
cur_sample_indices.append(sample_idx)
if len(cur_sample_indices) > 0 and (
not self.drop_last or len(cur_sample_indices) == cur_batch_size
):
if len(cur_sample_indices) > 0 and (not self.drop_last or len(cur_sample_indices) == cur_batch_size):
yield cur_sample_indices

def state_dict(self) -> dict:
Expand Down
6 changes: 2 additions & 4 deletions opensora/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import numbers

import numpy as np
import torch
import torchvision
Expand Down Expand Up @@ -146,8 +144,8 @@ def center_crop_arr(pil_image, image_size):


def resize_crop_to_fill(pil_image, image_size):
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, int(w * rh)
Expand Down
6 changes: 1 addition & 5 deletions opensora/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@
from einops import rearrange
from timm.models.vision_transformer import Mlp

from opensora.acceleration.communications import (
all_to_all,
split_forward_gather_backward,
)
from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group

approx_gelu = lambda: nn.GELU(approximate="tanh")
Expand Down Expand Up @@ -568,7 +565,6 @@ def __init__(
self.register_buffer(
"y_embedding",
torch.randn(token_num, in_channels) / in_channels**0.5,
persistent=False,
)
self.uncond_prob = uncond_prob

Expand Down
2 changes: 0 additions & 2 deletions opensora/models/text_encoder/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@


import html
import os
import re
import urllib.parse as ul

import ftfy
import torch
from bs4 import BeautifulSoup
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, T5EncoderModel

from opensora.registry import MODELS
Expand Down
8 changes: 6 additions & 2 deletions opensora/models/vae/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def decode(self, x):
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
assert (
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size

Expand Down Expand Up @@ -87,7 +89,9 @@ def decode(self, x):
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
assert (
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size

Expand Down
Loading

0 comments on commit 9da2087

Please sign in to comment.