forked from Kaggle/docker-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_bigquery.py
151 lines (131 loc) · 7.34 KB
/
test_bigquery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import unittest
import os
import json
from unittest.mock import patch
import threading
from test.support import EnvironmentVarGuard
from urllib.parse import urlparse
from http.server import BaseHTTPRequestHandler, HTTPServer
from google.cloud import bigquery
from google.auth.exceptions import DefaultCredentialsError
from google.cloud.bigquery._http import Connection
from kaggle_gcp import KaggleKernelCredentials, PublicBigqueryClient, _DataProxyConnection, init_bigquery
import kaggle_secrets
class TestBigQuery(unittest.TestCase):
    """Tests for Kaggle's BigQuery integration (kaggle_gcp).

    Each test points a ``bigquery.Client`` at a local fake BigQuery HTTP
    endpoint and verifies that requests are routed there and carry the
    expected ``Bearer`` token obtained from the user-secrets service.
    """

    API_BASE_URL = "http://127.0.0.1:2121"

    def _make_client(self, **kwargs):
        """Build a bigquery.Client aimed at the local fake server.

        Any keyword arguments (``project``, ``credentials``, ...) are
        forwarded to the client constructor; ``client_options`` defaults
        to the fake endpoint so individual tests don't repeat it.
        """
        kwargs.setdefault("client_options", {"api_endpoint": self.API_BASE_URL})
        return bigquery.Client(**kwargs)

    def _test_integration(self, client):
        """Run ``client.list_datasets()`` against a fake BQ server.

        Asserts that the server was actually hit and that the request
        carried the mocked ``Bearer secret`` authorization header.
        """
        class HTTPHandler(BaseHTTPRequestHandler):
            # Class-level flags so the outer test can inspect what the
            # handler observed after the request completes.
            called = False
            bearer_header_found = False

            def do_HEAD(self):
                self.send_response(200)

            def do_GET(self):
                HTTPHandler.called = True
                # self.headers is an email.message.Message, whose lookup is
                # case-insensitive — more robust than comparing raw keys,
                # since HTTP header names are case-insensitive by spec.
                HTTPHandler.bearer_header_found = (
                    self.headers.get("authorization") == "Bearer secret")
                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.end_headers()
                sample_dataset = {
                    "id": "bigqueryproject:datasetname",
                    "datasetReference": {
                        "datasetId": "datasetname",
                        "projectId": "bigqueryproject"
                    }
                }
                self.wfile.write(json.dumps(
                    {"kind": "bigquery#datasetList",
                     "datasets": [sample_dataset]}).encode("utf-8"))

        server_address = urlparse(self.API_BASE_URL)
        with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd:
            # daemon=True so a failure below can never leave a live
            # non-daemon thread hanging the whole test process.
            server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
            server_thread.start()
            try:
                for dataset in client.list_datasets():
                    self.assertEqual(dataset.dataset_id, "datasetname")
            finally:
                # Always stop the server, even if list_datasets() raised;
                # the original leaked the server thread on failure.
                httpd.shutdown()
                server_thread.join()
        self.assertTrue(
            HTTPHandler.called, msg="Fake server was not called from the BQ client, but should have been.")
        self.assertTrue(
            HTTPHandler.bearer_header_found, msg="authorization header was missing from the BQ request.")

    def _setup_mocks(self, api_url_mock):
        api_url_mock.__str__.return_value = self.API_BASE_URL

    @patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret', 1000))
    def test_project_with_connected_account(self, mock_access_token):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        with env:
            client = self._make_client(
                project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
            self._test_integration(client)

    @patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret', 1000))
    def test_project_with_empty_integrations(self, mock_access_token):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        env.set('KAGGLE_KERNEL_INTEGRATIONS', '')
        with env:
            client = self._make_client(
                project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
            self._test_integration(client)

    @patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret', 1000))
    def test_project_with_connected_account_unrelated_integrations(self, mock_access_token):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        # An integration list that does NOT mention BIGQUERY must not
        # affect explicit-credential clients.
        env.set('KAGGLE_KERNEL_INTEGRATIONS', 'GCS:ANOTHER_ONE')
        with env:
            client = self._make_client(
                project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
            self._test_integration(client)

    @patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret', 1000))
    def test_project_with_connected_account_default_credentials(self, mock_access_token):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
        with env:
            # No explicit credentials: kaggle_gcp's default-credential hook
            # should kick in and tag the connection's user agent.
            client = self._make_client(project='ANOTHER_PROJECT')
            self.assertTrue(client._connection.user_agent.startswith("kaggle-gcp-client/1.0"))
            self._test_integration(client)

    @patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret', 1000))
    def test_project_with_env_var_project_default_credentials(self, mock_access_token):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
        env.set('GOOGLE_CLOUD_PROJECT', 'ANOTHER_PROJECT')
        with env:
            client = self._make_client()
            self._test_integration(client)

    @patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret', 1000))
    def test_simultaneous_clients(self, mock_access_token):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        with env:
            proxy_client = self._make_client()
            bq_client = self._make_client(
                project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
            self._test_integration(bq_client)
            # Verify that proxy client is still going to proxy to ensure global Connection
            # isn't being modified.
            # BUGFIX: the original compared against KaggleKernelCredentials,
            # which a connection's type can never equal, making the assertion
            # vacuous. The (previously unused) Connection import indicates the
            # intended check: the proxy client's connection must not have been
            # downgraded to a plain Connection.
            self.assertNotEqual(type(proxy_client._connection), Connection)
            self.assertEqual(type(proxy_client._connection), _DataProxyConnection)

    def test_no_project_with_connected_account(self):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
        with env:
            with self.assertRaises(DefaultCredentialsError):
                # TODO(vimota): Handle this case, either default to Kaggle Proxy or use some default project
                # by the user or throw a custom exception.
                client = self._make_client()
                self._test_integration(client)

    def test_magics_with_connected_account_default_credentials(self):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
        with env:
            init_bigquery()
            from google.cloud.bigquery import magics
            self.assertEqual(type(magics.context._credentials), KaggleKernelCredentials)
            # Reset global magics state so later tests see a clean context.
            magics.context.credentials = None

    def test_magics_without_connected_account(self):
        env = EnvironmentVarGuard()
        env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
        with env:
            init_bigquery()
            from google.cloud.bigquery import magics
            self.assertIsNone(magics.context._credentials)