Skip to content

Commit

Permalink
handled erroneous data (hpcaitech#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankLeeeee authored May 14, 2024
1 parent 5929cae commit f73f756
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,14 @@ cache/

# Secret files
hostfile
run.sh
gradio_cached_examples/
wandb/

# vae weights
eval/vae/flolpips/weights/

# npm
node_modules/
package-lock.json
package.json
15 changes: 6 additions & 9 deletions opensora/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,9 @@ def getitem(self, index):
return ret

def __getitem__(self, index):
    """Return the sample at *index*, or None if its underlying data is bad.

    Delegates to ``self.getitem``; any exception raised while loading the
    sample is treated as erroneous data. Instead of retrying with a random
    index, we return None and let the dataloader's collate function
    (``collate_fn_ignore_none``) drop the entry from the batch.

    Args:
        index: dataset index (format is whatever ``self.getitem`` expects —
            presumably the "index-num_frames-height-width" string used by the
            variable dataloader; TODO confirm against caller).

    Returns:
        The loaded sample, or None when loading fails.
    """
    try:
        return self.getitem(index)
    except Exception:
        # We return None here in case of erroneous data;
        # the collate function will handle it.
        return None
7 changes: 7 additions & 0 deletions opensora/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,10 @@ def resize_crop_to_fill(pil_image, image_size):
arr = np.array(image)
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
return Image.fromarray(arr[i : i + th, j : j + tw])


def collate_fn_ignore_none(batch):
    """Collate *batch* with the default collator after dropping None entries.

    The dataset's ``__getitem__`` yields None when a sample's underlying
    data is erroneous; those placeholders are filtered out before
    ``torch.utils.data.default_collate`` assembles the batch.
    """
    valid_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.default_collate(valid_samples)
4 changes: 3 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.datasets.utils import collate_fn_ignore_none
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
Expand Down Expand Up @@ -97,6 +98,7 @@ def main():
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
collate_fn=collate_fn_ignore_none,
)
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader = prepare_dataloader(**dataloader_args)
Expand Down

0 comments on commit f73f756

Please sign in to comment.