# test_dataset.py
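# Unit tests for tinygpt.dataset: TextDataset (per-index input/target token windows over a
# text file) and DatasetHandler (batching, indexing, and shuffling over a TextDataset).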
import pytest

from tinygpt.dataset import TextDataset, DatasetHandler
from tinygpt.tokenizer import BPETokenizer, RegexPatterns


def test_TextDataset(tmp_path):
    # Create a folder for files
    folder_path = tmp_path / 'dataset'
    folder_path.mkdir()
    text_data = "Hello world! How are you?"

    # Create a tokenizer
    tokenizer = BPETokenizer(regex_pattern=RegexPatterns.GPT4)
    tokenizer.train(text_corpus=text_data, vocab_size=512, verbose=False)

    # Try to create an object with a wrong argument (a folder instead of a file)
    with pytest.raises(ValueError):
        _ = TextDataset(data_file_path=folder_path, tokenizer=tokenizer, max_seq_length=2)

    # Create the dataset file
    data_path = folder_path / 'data.txt'

    # Try with a file that doesn't exist
    with pytest.raises(RuntimeError):
        _ = TextDataset(data_file_path=data_path, tokenizer=tokenizer, max_seq_length=2)

    # Populate the file with the text data
    with data_path.open(mode='w') as file:
        file.write(text_data)

    # Try a wrong value for max_seq_length
    with pytest.raises(ValueError):
        _ = TextDataset(data_file_path=folder_path, tokenizer=tokenizer, max_seq_length=1)

    # Try without a tokenizer
    with pytest.raises(RuntimeError):
        _ = TextDataset(data_file_path=data_path, tokenizer=None, max_seq_length=2)

    # Try a max_seq_length larger than the data allows
    with pytest.raises(ValueError):
        _ = TextDataset(data_file_path=data_path, tokenizer=tokenizer, max_seq_length=7)

    # Create a valid dataset
    dataset = TextDataset(data_file_path=data_path, tokenizer=tokenizer, max_seq_length=3)

    # Test len() method
    assert len(dataset) == 4
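    # Note: the sample text appears to encode to 7 tokens with this tokenizer (see the decoded
    # windows below), so max_seq_length=3 leaves 7 - 3 = 4 (input, target) pairs.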
# Test getitem() method
expected_inputs_and_targets = [
("Hello world!", " world! How"),
(" world! How", "! How are"),
("! How are", " How are you"),
(" How are you", " are you?"),
]
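    # Each dataset[i] is expected to pair the token window starting at position i with the same
    # window shifted one token to the right (i.e. inputs tokens[i:i+3], targets tokens[i+1:i+4]).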
    for i in range(len(dataset)):
        input_ids, target_ids = dataset[i]
        expected_input_ids, expected_target_ids = expected_inputs_and_targets[i]
        assert input_ids == tokenizer.encode(expected_input_ids, allowed_special="all")
        assert target_ids == tokenizer.encode(expected_target_ids, allowed_special="all")

    # Max seq length
    dataset = TextDataset(data_file_path=data_path, tokenizer=tokenizer, max_seq_length=6)
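    # With max_seq_length=6 there is a single window: the input spans the first six tokens and
    # the target the last six, which is what the decoded strings below check.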
    input_ids, target_ids = dataset[0]
    assert tokenizer.decode(input_ids) == 'Hello world! How are you'
    assert tokenizer.decode(target_ids) == ' world! How are you?'


def test_DatasetHandler(tmp_path):
    # Create a folder for files
    folder_path = tmp_path / 'dataset'
    folder_path.mkdir()

    # Create the data file
    data_path = folder_path / 'data.txt'
    text_data = "Hello world! How are you?"
    with data_path.open(mode='w') as file:
        file.write(text_data)

    # Create a tokenizer
    tokenizer = BPETokenizer(regex_pattern=RegexPatterns.GPT4)
    tokenizer.train(text_corpus=text_data, vocab_size=512, verbose=False)

    # Create a dataset from the test data
    dataset = TextDataset(data_file_path=data_path, tokenizer=tokenizer, max_seq_length=3)

    # Test invalid combinations of parameters
    for batch_size in [-1, len(dataset) + 1]:
        with pytest.raises(ValueError):
            _ = DatasetHandler(dataset=dataset, batch_size=batch_size, drop_last=False, shuffle=False)

    with pytest.raises(RuntimeError):
        _ = DatasetHandler(dataset=None, batch_size=2, drop_last=False, shuffle=False)

    # Test valid combinations of parameters
    num_sequences = len(dataset)
    for batch_size in range(1, num_sequences + 1):
        for drop_last in [True, False]:
            dataset_handler = DatasetHandler(dataset=dataset, batch_size=batch_size, drop_last=drop_last, shuffle=False)
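            # Expected number of batches (asserted below): num_sequences // batch_size when
            # drop_last is True, plus one extra partial batch otherwise if there is a remainder.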
            # Check the size is correct
            if drop_last:
                assert len(dataset_handler) == (len(dataset) // batch_size)
            else:
                assert len(dataset_handler) == (len(dataset) // batch_size) + int(num_sequences % batch_size > 0)

            # Check batches are correct
            for idx_batch, (input_batch, target_batch) in enumerate(dataset_handler):
                # Check the size of the batch
                assert len(input_batch) == len(target_batch)
                if drop_last or idx_batch < len(dataset_handler) - 1:
                    assert len(input_batch) == batch_size
                else:
                    assert len(input_batch) in {num_sequences % batch_size, batch_size}

                # Check each element of the batch
                for idx_element in range(len(input_batch)):
                    input_seq, target_seq = dataset[idx_batch * batch_size + idx_element]
                    assert input_seq == input_batch[idx_element]
                    assert target_seq == target_batch[idx_element]
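            # Indexing with dataset_handler[index] is expected to return the same batch as
            # position `index` of the iteration above (no shuffling here).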
            # Test indexing
            for index in [0, len(dataset_handler) - 1]:
                input_batch, target_batch = dataset_handler[index]

                # Check the size of the batch
                assert len(input_batch) == len(target_batch)
                if drop_last or index < len(dataset_handler) - 1:
                    assert len(input_batch) == batch_size
                else:
                    assert len(input_batch) in {num_sequences % batch_size, batch_size}

                # Check each element of the batch
                for idx_element in range(len(input_batch)):
                    input_seq, target_seq = dataset[index * batch_size + idx_element]
                    assert input_seq == input_batch[idx_element]
                    assert target_seq == target_batch[idx_element]

    # Test shuffle
    num_sequences = len(dataset)
    for batch_size in range(1, num_sequences + 1):
        for drop_last in [True, False]:
            # First create one without shuffling to obtain the expected batches
            dataset_handler = DatasetHandler(dataset=dataset, batch_size=batch_size, drop_last=drop_last, shuffle=False)
            expected_input_sequences = []
            expected_target_sequences = []
            for idx_batch, (input_batch, target_batch) in enumerate(dataset_handler):
                for seq in input_batch:
                    expected_input_sequences.append(seq)
                for seq in target_batch:
                    expected_target_sequences.append(seq)

            # Now create one with shuffling
            dataset_handler = DatasetHandler(dataset=dataset, batch_size=batch_size, drop_last=drop_last, shuffle=True)
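            # Shuffling should only reorder the sequences: when nothing is dropped, the shuffled
            # run must yield exactly the same sequences as the unshuffled one (checked below);
            # with drop_last and a remainder, a different subset may be discarded.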
            # Check the size is correct
            if drop_last:
                assert len(dataset_handler) == (len(dataset) // batch_size)
            else:
                assert len(dataset_handler) == (len(dataset) // batch_size) + int(num_sequences % batch_size > 0)

            # Check batches are correct
            input_sequences = []
            target_sequences = []
            for idx_batch, (input_batch, target_batch) in enumerate(dataset_handler):
                # Check the size of the batch
                assert len(input_batch) == len(target_batch)
                if drop_last or idx_batch < len(dataset_handler) - 1:
                    assert len(input_batch) == batch_size
                else:
                    assert len(input_batch) in {num_sequences % batch_size, batch_size}

                # Save the sequences
                for seq in input_batch:
                    input_sequences.append(seq)
                for seq in target_batch:
                    target_sequences.append(seq)

            # Check we have the same number of sequences as expected
            assert len(input_sequences) == len(expected_input_sequences)
            assert len(target_sequences) == len(expected_target_sequences)

            # Check that sequences are not repeated and that all expected sequences exist
            idx_sequences_in_expected_input = set()
            for seq in input_sequences:
                if seq in expected_input_sequences:
                    idx_sequences_in_expected_input.add(expected_input_sequences.index(seq))

            idx_sequences_in_expected_target = set()
            for seq in target_sequences:
                if seq in expected_target_sequences:
                    idx_sequences_in_expected_target.add(expected_target_sequences.index(seq))

            num_seq_dropped = num_sequences % batch_size
            if drop_last and num_seq_dropped != 0:
                assert len(idx_sequences_in_expected_input) == len(idx_sequences_in_expected_target)
                assert batch_size - num_seq_dropped <= len(idx_sequences_in_expected_input) < num_sequences
            else:
                assert len(idx_sequences_in_expected_input) == len(expected_input_sequences)
                assert len(idx_sequences_in_expected_target) == len(expected_target_sequences)