-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
Copy pathutils.py
64 lines (47 loc) · 2.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import functools
import timeit

import numpy as np

import datasets
from datasets.arrow_writer import ArrowWriter
from datasets.features.features import _ArrayXD
def get_duration(func):
    """Wrap ``func`` so that calling it returns its elapsed run time.

    The wrapped function's own return value is discarded; the wrapper instead
    returns the time taken by a single call, measured with
    ``timeit.default_timer``.

    Args:
        func: The callable to time.

    Returns:
        A wrapper accepting the same arguments as ``func`` whose return value is
        the elapsed time of one call to ``func``.
    """

    # functools.wraps copies __name__ (which the original code did by hand) and
    # also __doc__, __qualname__, __module__ and sets __wrapped__, so benchmark
    # reports and introspection see the original function's identity.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        starttime = timeit.default_timer()
        func(*args, **kwargs)  # result intentionally ignored; only timing matters
        return timeit.default_timer() - starttime

    return wrapper
def generate_examples(features: dict, num_examples=100, seq_shapes=None):
    """Build ``num_examples`` random dummy examples matching ``features``.

    Args:
        features: Mapping of column name to feature type. Supported types are
            ``_ArrayXD`` instances, ``datasets.Value`` and (possibly nested)
            ``datasets.Sequence``.
        num_examples: Number of examples to generate.
        seq_shapes: Mapping of column name to the shape of the random array to
            generate for each ``datasets.Sequence`` column. Every Sequence
            column must have an entry.

    Returns:
        A list of ``(index, example)`` tuples, where ``index`` runs from 0 to
        ``num_examples - 1`` and ``example`` maps each column name to its
        generated value.

    Raises:
        KeyError: If a ``datasets.Sequence`` column has no entry in ``seq_shapes``.
        TypeError: If a column's feature type is not one of the supported types.
    """
    dummy_data = []
    seq_shapes = seq_shapes or {}
    for i in range(num_examples):
        example = {}
        for k, v in features.items():
            if isinstance(v, _ArrayXD):
                data = np.random.rand(*v.shape).astype(v.dtype)
            elif isinstance(v, datasets.Value):
                if v.dtype == "string":
                    data = "The small grey turtle was surprisingly fast when challenged."
                else:
                    # Single random integer in [0, 10), cast to the column dtype and
                    # converted to a Python scalar.
                    data = np.random.randint(10, size=1).astype(v.dtype).item()
            elif isinstance(v, datasets.Sequence):
                # Unwrap nested Sequences down to the innermost element feature,
                # whose dtype is used for the generated array.
                while isinstance(v, datasets.Sequence):
                    v = v.feature
                try:
                    shape = seq_shapes[k]
                except KeyError:
                    raise KeyError(f"seq_shapes has no shape for Sequence column {k!r}") from None
                data = np.random.rand(*shape).astype(v.dtype)
            else:
                # Fail loudly: without this branch an unsupported type would silently
                # reuse the previous column's data (or hit UnboundLocalError).
                raise TypeError(f"Unsupported feature type for column {k!r}: {type(v).__name__}")
            example[k] = data
        dummy_data.append((i, example))
    return dummy_data
def generate_example_dataset(dataset_path, features, num_examples=100, seq_shapes=None):
    """Write random dummy examples to an Arrow file and load it as a Dataset.

    Args:
        dataset_path: Path of the Arrow file to write and then read back.
        features: Feature description used to generate the examples (via
            ``generate_examples``) and to encode each one before writing
            (``features.encode_example``).
        num_examples: Number of examples to generate and write.
        seq_shapes: Per-column shapes for ``datasets.Sequence`` columns,
            forwarded to ``generate_examples``.

    Returns:
        A ``datasets.Dataset`` loaded from ``dataset_path`` with ``features``.

    Raises:
        ValueError: If the writer reports a different number of written
            examples than ``num_examples``.
    """
    dummy_data = generate_examples(features, num_examples=num_examples, seq_shapes=seq_shapes)

    with ArrowWriter(features=features, path=dataset_path) as writer:
        for _, record in dummy_data:
            writer.write(features.encode_example(record))

        # finalize() returns (num_examples, num_bytes); only the count is checked.
        num_final_examples, _ = writer.finalize()

        # Checked inside the `with` so the writer is torn down while the error propagates.
        if num_final_examples != num_examples:
            raise ValueError(
                f"Error writing the dataset, wrote {num_final_examples} examples but should have written {num_examples}."
            )

    dataset = datasets.Dataset.from_file(filename=dataset_path, info=datasets.DatasetInfo(features=features))
    return dataset