Skip to content

Commit

Permalink
(webdataset) fix KeyError for C@H (lucidrains#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
afiaka87 authored Sep 16, 2021
1 parent 499b4c9 commit 9a2a1b6
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions train_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,15 +361,17 @@ def tokenize(s):
image_mapping = {
myimg: imagepreproc
}

def filter_dataset(item):
    """Tell whether a sample carries both its caption and its image entry.

    Used as a sample predicate so that incomplete records are dropped
    before decoding, e.g. for C@H, which (rarely) has no caption available.

    Args:
        item: A sample supporting ``in`` membership tests on its keys.

    Returns:
        True when both ``mycap`` and ``myimg`` are present in ``item``,
        otherwise False.
    """
    # The caption key is tested first; the image key is only checked
    # when the caption is present.
    return mycap in item and myimg in item

ds = (
wds.WebDataset(DATASET)
# .shuffle(is_shuffle) # Commented out for WebDataset as the behaviour cannot be predicted yet
.map_dict(**image_text_mapping)
.map_dict(**image_mapping)
.to_tuple(mycap, myimg)
.batched(BATCH_SIZE, partial=False) # It is good to avoid partial batches when using Distributed training
)
w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
filtered_dataset = w_dataset.select(filter_dataset)
ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
else:
ds = TextImageDataset(
args.image_text_folder,
Expand Down

0 comments on commit 9a2a1b6

Please sign in to comment.