-
Notifications
You must be signed in to change notification settings - Fork 57
/
download_evalsets.py
135 lines (125 loc) · 5.04 KB
/
download_evalsets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import hashlib
import os
import subprocess
import sys

import datasets
import yaml
# Module-wide verbosity switch. main() overwrites it from the --verbose CLI
# flag, and wget() reads it to choose between wget's -v and -nv output modes.
VERBOSE = False
def main(args):
    """Run the downloader for a parsed command line.

    Publishes ``args.verbose`` to the module-level ``VERBOSE`` flag (consulted
    by ``wget``) and then downloads every dataset into ``args.data_dir``.

    Args:
        args: ``argparse.Namespace`` from this script's parser; must expose
            ``verbose`` and ``data_dir``.

    Returns:
        None.
    """
    global VERBOSE
    VERBOSE = args.verbose
    root_dir = args.data_dir
    download_datasets(root_dir)
def wget(src, dst, verbose=False):
    """Download ``src`` into the local file ``dst`` using the ``wget`` binary.

    The command is executed without a shell, so URLs and paths containing
    quotes, spaces or shell metacharacters are passed through verbatim rather
    than being interpreted (or broken) by the shell.

    Args:
        src: URL to fetch.
        dst: Local output path, passed to ``wget -O``.
        verbose: Force verbose wget output for this call even when the
            module-level ``VERBOSE`` flag is off.

    Returns:
        The exit status of the ``wget`` process (0 on success). A failing
        download is not raised, preserving the previous best-effort
        behaviour; callers may inspect the returned status.

    Raises:
        FileNotFoundError: if the ``wget`` executable cannot be found.
    """
    # "-v" prints full progress, "-nv" only a one-line summary per file.
    vflag = "-v" if VERBOSE or verbose else "-nv"
    completed = subprocess.run(["wget", vflag, src, "-O", dst], check=False)
    return completed.returncode
def _print_banner(task_name):
    """Print ``task_name`` as a centred ``=``-padded banner framed by blank lines."""
    title = f" Download '{task_name}' "
    print()
    print(f"{title:=^40s}")
    print()


def _report_failure(message, error):
    """Write ``message`` and the triggering exception to stderr.

    Both lines go to stderr so error details are not mixed into the regular
    progress output on stdout.
    """
    print(message, file=sys.stderr)
    print(error, file=sys.stderr)


def _download_hf_dataset(task, task_name, data_dir):
    """Fetch the test split of ``nlphuji/<task>`` into ``<data_dir>/hf_cache``.

    Errors are reported on stderr and not re-raised, so the caller can move
    on to the next task.
    """
    _print_banner(task_name)
    try:
        datasets.load_dataset(
            f"nlphuji/{task}",
            split="test",
            # NOTE(review): `ignore_verifications` is deprecated in newer
            # `datasets` releases in favour of `verification_mode`; revisit
            # when the installed `datasets` version is upgraded.
            ignore_verifications=False,
            cache_dir=os.path.join(data_dir, "hf_cache"),
        )
    except Exception as e:
        _report_failure(
            "Failed to download Huggingface dataset, check write permissions and Internet connection",
            e,
        )


def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks.

    Uses ``hashlib`` rather than shelling out to ``sha256sum`` so it works on
    systems without that tool and for paths containing quotes.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _fetch_lfs_pointer(url):
    """Return ``(sha256_hex, size_in_bytes)`` parsed from the LFS pointer at ``url``.

    The pointer text is fetched with ``curl -s`` (no shell involved) and is
    expected to contain ``oid <algo>:<hex>`` and ``size <bytes>`` lines.

    Raises:
        ValueError: if either field is missing or malformed, e.g. because the
            request failed or the response is not an LFS pointer.
    """
    result = subprocess.run(
        ["curl", "-s", url], capture_output=True, text=True, check=False
    )
    fields = {}
    for line in result.stdout.splitlines():
        parts = line.split(maxsplit=1)
        # Skip blank or single-token lines instead of crashing dict().
        if len(parts) == 2:
            fields[parts[0]] = parts[1].strip()
    try:
        checksum = fields["oid"].split(":", 1)[1]
        size = int(fields["size"])
    except (KeyError, IndexError, ValueError) as err:
        raise ValueError(f"Could not parse Git LFS pointer at {url}") from err
    return checksum, size


def _tar_is_valid(local_tar_path, pointer_url):
    """Return True if the local TAR matches the size and SHA-256 of its LFS pointer."""
    exp_checksum, exp_size = _fetch_lfs_pointer(pointer_url)
    # Compare the cheap size first and only hash when it already matches.
    if os.path.getsize(local_tar_path) != exp_size:
        return False
    return _sha256_of_file(local_tar_path) == exp_checksum


def _download_wds_metadata(source_url, target_path):
    """Create ``<target_path>/test``, fetch the metadata files, return the shard count."""
    os.makedirs(os.path.join(target_path, "test"), exist_ok=True)
    for rel_path in (
        "classnames.txt",
        "zeroshot_classification_templates.txt",
        "test/nshards.txt",
    ):
        wget(f"{source_url}/raw/main/{rel_path}", os.path.join(target_path, rel_path))
    with open(os.path.join(target_path, "test/nshards.txt")) as f:
        return int(f.read())


def _sync_tar_shards(source_url, target_path, nshards):
    """Verify existing ``test/<i>.tar`` shards and download missing or corrupt ones.

    Returns:
        List of shard indices for which no usable file exists after download.
    """
    failed = []
    for index in range(nshards):
        rel_path = f"test/{index}.tar"
        local_tar_path = os.path.join(target_path, rel_path)
        if os.path.exists(local_tar_path) and _tar_is_valid(
            local_tar_path, f"{source_url}/raw/main/{rel_path}"
        ):
            print(f"Verified {rel_path}")
            continue
        # TAR is corrupt or does not exist, download it.
        wget(f"{source_url}/resolve/main/{rel_path}", local_tar_path, verbose=True)
        # Check the result on disk: a missing or empty file means no usable
        # shard was written (a valid TAR archive is never empty).
        if not os.path.isfile(local_tar_path) or os.path.getsize(local_tar_path) == 0:
            failed.append(index)
    return failed


def download_datasets(data_dir):
    """Download every evaluation dataset listed in ``tasklist.yml``.

    ``tasklist.yml`` is read from the current working directory. Each
    top-level key is a task id; its optional ``name`` entry is the display
    name.

    * ``retrieval/*`` and ``misc/*`` tasks are fetched with the Hugging Face
      ``datasets`` loader from ``nlphuji/<task>`` into ``<data_dir>/hf_cache``.
    * All other tasks (with a leading ``fairness/`` stripped) are
      webdatasets mirrored from
      ``https://huggingface.co/datasets/djghosh/wds_<task>_test`` into
      ``<data_dir>/wds_<task>_test``. TAR shards already on disk are checked
      against the size and SHA-256 in their Git LFS pointer and are only
      re-downloaded when they do not match.

    Failures are reported on stderr and the next task is attempted. Finally
    the brace-pattern path of each webdataset's TAR shards is printed.

    Args:
        data_dir: Root directory into which all datasets are downloaded.
    """
    local_urls = []
    # Get list of datasets
    with open("tasklist.yml") as f:
        tasks = yaml.safe_load(f)
    for task, task_info in tasks.items():
        # Tolerate entries with no mapping (e.g. "task:" with an empty value).
        task_name = (task_info or {}).get("name", task)
        if task.startswith(("retrieval/", "misc/")):
            # Huggingface dataset loader, download those differently
            _download_hf_dataset(task.split("/", 1)[1], task_name, data_dir)
            continue
        if task.startswith("fairness/"):
            task = task.split("/", 1)[1]
        # Download webdataset from Huggingface. URLs are joined with "/"
        # explicitly; os.path.join would insert "\\" on Windows.
        dir_name = f"wds_{task.replace('/', '-')}_test"
        source_url = f"https://huggingface.co/datasets/djghosh/{dir_name}"
        target_path = os.path.join(data_dir, dir_name)
        _print_banner(task_name)
        try:
            nshards = _download_wds_metadata(source_url, target_path)
            local_urls.append(
                os.path.join(target_path, f"test/{{0..{nshards - 1}}}.tar")
            )
            failed = _sync_tar_shards(source_url, target_path, nshards)
        except Exception as e:
            _report_failure(
                "Failed to download dataset, check write permissions and Internet connection",
                e,
            )
        else:
            if failed:
                print(
                    f"Failed to download shard(s) {failed} of '{task_name}'",
                    file=sys.stderr,
                )
            else:
                print("Successfully downloaded dataset")
        print()
    # Print all local URLs
    print("Paths to all downloaded TAR files:")
    print(*local_urls, sep="\n")
if __name__ == "__main__":
    # Command-line interface: one positional output root plus a verbosity flag.
    cli = argparse.ArgumentParser(
        description="Download all data comp evaluation datasets"
    )
    cli.add_argument(
        "data_dir",
        help="Root directory into which all datasets will be downloaded",
    )
    cli.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print verbose download status",
    )
    parsed = cli.parse_args()
    sys.exit(main(parsed))