# test_wds.py (forked from mlfoundations/open_clip)

import collections
import io
import os
import tarfile

import pytest
from PIL import Image

import util_test
from training.data import get_wds_dataset
from training.params import parse_args
from training.main import random_seed

# Number of samples drawn per (resampled) training epoch.
TRAIN_NUM_SAMPLES = 10_000
# Relative tolerance for the pytest.approx count checks below.
RTOL = 0.2

# NOTE: we use two test tar files, which are created on the fly and saved to data/input.
# test_data_000.tar has 10 samples, with captions 000_0, 000_1, ..., 000_9.
# test_data_001.tar has 5 samples, with captions 001_0, 001_1, ..., 001_4.
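#
# Why the assertion constants below: with dataset_resampled, shards are drawn
# with replacement (weighted by any train_data_upsampling_factors) and every
# sample of a drawn shard is emitted, so a caption's expected count equals the
# expected number of draws of its shard. For TRAIN_NUM_SAMPLES = 10_000 that
# gives, roughly:
#   single source (000 only):  10_000 / 10 = 1_000 per caption
#   two sources, equal odds:   10_000 / 15 ~= 667 per caption (a draw of 000
#                              yields 10 samples, a draw of 001 yields 5)
#   upsampling 1::2:           10_000 / 20 = 500 per 000 caption,
#                              10_000 / 10 = 1_000 per 001 caption
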
def build_inputs(test_name):
    base_input_dir, _ = util_test.get_data_dirs()
    input_dir = os.path.join(base_input_dir, test_name)
    os.makedirs(input_dir, exist_ok=True)

    def save_tar(idx, num_samples):
        filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar')
        with tarfile.open(filename, 'w') as tar:
            for sample_idx in range(num_samples):
                # Image: a blank 32x32 RGB PNG.
                image = Image.new('RGB', (32, 32))
                info = tarfile.TarInfo(f'{sample_idx}.png')
                bio = io.BytesIO()
                image.save(bio, format='png')
                info.size = bio.tell()
                bio.seek(0)
                tar.addfile(info, bio)

                # Caption: '<shard>_<sample>', e.g. '000_3'.
                info = tarfile.TarInfo(f'{sample_idx}.txt')
                bio = io.BytesIO()
                bio.write(f'{idx:03d}_{sample_idx}'.encode('utf-8'))
                info.size = bio.tell()
                bio.seek(0)
                tar.addfile(info, bio)

    save_tar(0, 10)
    save_tar(1, 5)
    return input_dir

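# For reference, build_inputs('<name>') leaves the following under data/input/<name>/:
#   test_data_000.tar  (10 image/caption pairs, captions 000_0 ... 000_9)
#   test_data_001.tar  (5 image/caption pairs, captions 001_0 ... 001_4)
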
def build_params(input_shards, seed=0):
    args = parse_args([])
    args.train_data = input_shards
    args.train_num_samples = TRAIN_NUM_SAMPLES
    args.dataset_resampled = True
    args.seed = seed
    args.workers = 1
    args.world_size = 1
    args.batch_size = 1
    random_seed(seed)

    # Identity image transform and a pass-through tokenizer, so each batch
    # carries the raw caption strings and counting them is straightforward.
    preprocess_img = lambda x: x
    tokenizer = lambda x: [x.strip()]

    return args, preprocess_img, tokenizer

def get_dataloader(input_shards):
    args, preprocess_img, tokenizer = build_params(input_shards)
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader
    return dataloader

def test_single_source():
    """Test webdataset with a single tar file."""
    input_dir = build_inputs('single_source')
    input_shards = os.path.join(input_dir, 'test_data_000.tar')
    dataloader = get_dataloader(input_shards)

    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1

    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}'

def test_two_sources():
    """Test webdataset with two tar files, combined via brace notation."""
    input_dir = build_inputs('two_sources')
    input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar')
    dataloader = get_dataloader(input_shards)

    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1

    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}'

def test_two_sources_same_weights():
    """Test webdataset with two tar files, using --train-data-upsampling-factors=1::1."""
    input_dir = build_inputs('two_sources_same_weights')
    input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
    args, preprocess_img, tokenizer = build_params(input_shards)
    args.train_data_upsampling_factors = '1::1'
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader

    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1

    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}'

def test_two_sources_with_upsampling():
    """Test webdataset with two tar files, upsampling the second via --train-data-upsampling-factors=1::2."""
    input_dir = build_inputs('two_sources_with_upsampling')
    input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
    args, preprocess_img, tokenizer = build_params(input_shards)
    args.train_data_upsampling_factors = '1::2'
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader

    counts = collections.defaultdict(int)
    for sample in dataloader:
        txts = sample[1]
        for txt in txts:
            counts[txt] += 1

    for key, count in counts.items():
        if key.startswith('000'):
            assert count == pytest.approx(TRAIN_NUM_SAMPLES / 20, RTOL), f'{key}, {count}'
        else:
            assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}'
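
# These tests are meant to be run with pytest from the repository's tests
# directory (they import the local util_test helper and open_clip's training
# package), e.g.:
#   pytest test_wds.py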