forked from Kaggle/docker-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_datasets.py
148 lines (126 loc) · 6.49 KB
/
test_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
import os
import threading
import unittest
from http.server import BaseHTTPRequestHandler, HTTPServer
from test.support import EnvironmentVarGuard
from urllib.parse import urlparse
from kaggle_web_client import (KaggleWebClient,
_KAGGLE_URL_BASE_ENV_VAR_NAME,
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME,
_KAGGLE_IAP_TOKEN_ENV_VAR_NAME,
CredentialError, BackendError)
from kaggle_datasets import KaggleDatasets, _KAGGLE_TPU_NAME_ENV_VAR_NAME
# Fake user-secrets token: exported through the user-secrets env var by the tests and
# expected back as "Bearer <token>" in the X-Kaggle-Authorization request header.
_TEST_JWT = 'test-secrets-key'
# Fake IAP token: exported only when a test opts in, and then expected back as
# "Bearer <token>" in the Authorization request header.
_TEST_IAP = 'IAP_TOKEN'
# Bucket the fake server reports as 'destinationBucket' when the test runs in TPU mode.
_TPU_GCS_BUCKET = 'gs://kds-tpu-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a'
# Bucket the fake server reports as 'destinationBucket' when the test is not in TPU mode.
_AUTOML_GCS_BUCKET = 'gs://kds-automl-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a'
class GcsDatasetsHTTPHandler(BaseHTTPRequestHandler):
    """Base request handler for a fake backend that answers POSTs with JSON.

    Subclasses provide two hooks:

    * ``set_request`` -- called at the start of every POST, before anything is
      written back, so the subclass can inspect the incoming request.
    * ``get_response`` -- returns the object that is JSON-encoded and sent as
      the POST response body.
    """

    def set_request(self):
        """Hook invoked at the start of each POST; subclasses must override it."""
        raise NotImplementedError()

    def get_response(self):
        """Return the JSON-serializable POST payload; subclasses must override it."""
        raise NotImplementedError()

    def do_HEAD(self):
        """Reply to a HEAD request with a bare 200 status and no body."""
        self.send_response(200)
        # send_response() only buffers the status line and default headers;
        # nothing reaches the client until end_headers() flushes that buffer.
        self.end_headers()

    def do_POST(self):
        """Run the ``set_request`` hook, then reply 200 with the JSON response body."""
        self.set_request()
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        payload = json.dumps(self.get_response())
        self.wfile.write(payload.encode("utf-8"))
class TestDatasets(unittest.TestCase):
    """Tests for ``KaggleDatasets`` run against a throwaway local HTTP server.

    Each request-based test starts an ``HTTPServer`` on ``SERVER_ADDRESS``,
    lets the client talk to it, and then checks the path, JSON body and
    headers that the server recorded.
    """

    # Address the fake server binds to; taken from the Kaggle URL-base env var
    # when it is set, otherwise a fixed localhost port.
    SERVER_ADDRESS = urlparse(os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME, default="http://127.0.0.1:8001"))

    def _test_client(self, client_func, expected_path, expected_body, is_tpu=True, success=True, iap_token=False):
        """Run ``client_func`` against a fake server and verify the request it sent.

        Args:
            client_func: Zero-argument callable that exercises the client.
            expected_path: Request path the fake server must have received.
            expected_body: Decoded JSON body the fake server must have received.
            is_tpu: When true, the TPU-name env var is set and the fake server
                answers with the TPU bucket; otherwise the variable is cleared
                and the AutoML bucket is returned.
            success: When false, the fake server answers with an unsuccessful
                payload instead of a bucket.
            iap_token: When true, the IAP token env var is set and an
                ``Authorization`` header is expected; otherwise the variable is
                cleared and the header must be absent.
        """
        _request = {}

        class GetGcsPathHandler(GcsDatasetsHTTPHandler):
            def set_request(self):
                content_len = int(self.headers.get('Content-Length'))
                body = json.loads(self.rfile.read(content_len))
                # Record everything in one step so a parsing failure can never
                # leave a half-populated request behind.
                _request.update(path=self.path, body=body, headers=self.headers)

            def get_response(self):
                if success:
                    gcs_path = _TPU_GCS_BUCKET if is_tpu else _AUTOML_GCS_BUCKET
                    return {'result': {
                        'destinationBucket': gcs_path,
                        'destinationPath': None}, 'wasSuccessful': "true"}
                else:
                    return {'wasSuccessful': "false"}

        env = EnvironmentVarGuard()
        env.set(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME, _TEST_JWT)
        # Set or clear the optional variables explicitly so values inherited
        # from the surrounding environment cannot change what the test sees.
        if is_tpu:
            env.set(_KAGGLE_TPU_NAME_ENV_VAR_NAME, 'FAKE_TPU')
        else:
            env.unset(_KAGGLE_TPU_NAME_ENV_VAR_NAME)
        if iap_token:
            env.set(_KAGGLE_IAP_TOKEN_ENV_VAR_NAME, _TEST_IAP)
        else:
            env.unset(_KAGGLE_IAP_TOKEN_ENV_VAR_NAME)
        with env:
            with HTTPServer((self.SERVER_ADDRESS.hostname, self.SERVER_ADDRESS.port), GetGcsPathHandler) as httpd:
                server_thread = threading.Thread(target=httpd.serve_forever)
                server_thread.start()
                try:
                    client_func()
                finally:
                    httpd.shutdown()
                    # Wait for the serving thread so it never outlives the test.
                    server_thread.join()

                # Fail with a readable message (not a KeyError) if nothing was recorded.
                self.assertTrue(
                    _request,
                    msg="Fake server did not receive any request from the KaggleDatasets client.")
                path, headers, body = _request['path'], _request['headers'], _request['body']
                self.assertEqual(
                    path,
                    expected_path,
                    msg="Fake server did not receive the right request from the KaggleDatasets client.")
                self.assertEqual(
                    body,
                    expected_body,
                    msg="Fake server did not receive the right body from the KaggleDatasets client.")
                # Membership is tested on the message object itself, which matches
                # header names without regard to case (unlike the list from .keys()).
                self.assertIn('Content-Type', headers,
                    msg="Fake server did not receive a Content-Type header from the KaggleDatasets client.")
                self.assertEqual('application/json', headers.get('Content-Type'),
                    msg="Fake server did not receive an application/json content type header from the KaggleDatasets client.")
                self.assertIn('X-Kaggle-Authorization', headers,
                    msg="Fake server did not receive an X-Kaggle-Authorization header from the KaggleDatasets client.")
                if iap_token:
                    self.assertEqual(f'Bearer {_TEST_IAP}', headers.get('Authorization'),
                        msg="Fake server did not receive an Authorization header from the KaggleDatasets client.")
                else:
                    self.assertNotIn('Authorization', headers,
                        msg="Fake server received an Authorization header from the KaggleDatasets client. It shouldn't.")
                self.assertEqual(f'Bearer {_TEST_JWT}', headers.get('X-Kaggle-Authorization'),
                    msg="Fake server did not receive the right X-Kaggle-Authorization header from the KaggleDatasets client.")

    def test_no_token_fails(self):
        """Constructing the client without the user-secrets token raises CredentialError."""
        env = EnvironmentVarGuard()
        env.unset(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME)
        with env:
            with self.assertRaises(CredentialError):
                KaggleDatasets()

    def test_get_gcs_path_tpu_succeeds(self):
        """In TPU mode the client sends IntegrationType 2 and returns the TPU bucket."""
        def call_get_gcs_path():
            client = KaggleDatasets()
            gcs_path = client.get_gcs_path()
            self.assertEqual(gcs_path, _TPU_GCS_BUCKET)
        self._test_client(call_get_gcs_path,
                          '/requests/CopyDatasetVersionToKnownGcsBucketRequest',
                          {'MountSlug': None, 'IntegrationType': 2},
                          is_tpu=True)

    def test_get_gcs_path_automl_succeeds(self):
        """Outside TPU mode the client sends IntegrationType 1 and returns the AutoML bucket."""
        def call_get_gcs_path():
            client = KaggleDatasets()
            gcs_path = client.get_gcs_path()
            self.assertEqual(gcs_path, _AUTOML_GCS_BUCKET)
        self._test_client(call_get_gcs_path,
                          '/requests/CopyDatasetVersionToKnownGcsBucketRequest',
                          {'MountSlug': None, 'IntegrationType': 1},
                          is_tpu=False)

    def test_get_gcs_path_handles_unsuccessful(self):
        """An unsuccessful backend answer surfaces as BackendError."""
        def call_get_gcs_path():
            client = KaggleDatasets()
            with self.assertRaises(BackendError):
                client.get_gcs_path()
        self._test_client(call_get_gcs_path,
                          '/requests/CopyDatasetVersionToKnownGcsBucketRequest',
                          {'MountSlug': None, 'IntegrationType': 2},
                          is_tpu=True,
                          success=False)

    def test_iap_token(self):
        """With the IAP token set, the request carries a matching Authorization header."""
        def call_get_gcs_path():
            client = KaggleDatasets()
            client.get_gcs_path()
        self._test_client(call_get_gcs_path,
                          '/requests/CopyDatasetVersionToKnownGcsBucketRequest',
                          {'MountSlug': None, 'IntegrationType': 1},
                          is_tpu=False, iap_token=True)