Skip to content

Commit

Permalink
Upgrade bigquery verison to 2.2.0 for aiplatform
Browse files Browse the repository at this point in the history
  • Loading branch information
bovard committed Mar 1, 2021
1 parent 7d96dee commit 662ec7e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 36 deletions.
5 changes: 1 addition & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,7 @@ RUN pip install --upgrade cython && \
pip install category_encoders && \
# google-cloud-automl 2.0.0 introduced incompatible API changes, need to pin to 1.0.1
pip install google-cloud-automl==1.0.1 && \
# Newer version crashes (latest = 1.14.0) when running tensorflow.
# python -c "from google.cloud import bigquery; import tensorflow". This flow is common because bigquery is imported in kaggle_gcp.py
# which is loaded at startup.
pip install google-cloud-bigquery==1.12.1 && \
pip install google-cloud-bigquery==2.2.0 && \
pip install google-cloud-storage && \
pip install google-cloud-translate==3.* && \
pip install google-cloud-language==2.* && \
Expand Down
9 changes: 4 additions & 5 deletions patches/kaggle_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def __init__(self, parentCredential=None, quota_project_id=None):
class _DataProxyConnection(Connection):
"""Custom Connection class used to proxy the BigQuery client to Kaggle's data proxy."""

API_BASE_URL = os.getenv("KAGGLE_DATA_PROXY_URL")

def __init__(self, client):
super().__init__(client)
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self.extra_headers["X-KAGGLE-PROXY-DATA"] = os.getenv(
"KAGGLE_DATA_PROXY_TOKEN")

Expand All @@ -117,13 +115,14 @@ class PublicBigqueryClient(bigquery.client.Client):

def __init__(self, *args, **kwargs):
data_proxy_project = os.getenv("KAGGLE_DATA_PROXY_PROJECT")
default_api_endpoint = os.getenv("KAGGLE_DATA_PROXY_URL")
anon_credentials = credentials.AnonymousCredentials()
anon_credentials.refresh = lambda *args: None
super().__init__(
project=data_proxy_project, credentials=anon_credentials, *args, **kwargs
)
# TODO: Remove this once https://github.com/googleapis/google-cloud-python/issues/7122 is implemented.
self._connection = _DataProxyConnection(self)
self._connection = _DataProxyConnection(self, api_endpoint=default_api_endpoint)

def has_been_monkeypatched(method):
return "kaggle_gcp" in inspect.getsourcefile(method)
Expand Down
41 changes: 14 additions & 27 deletions tests/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from kaggle_gcp import KaggleKernelCredentials, PublicBigqueryClient, _DataProxyConnection, init_bigquery
import kaggle_secrets


class TestBigQuery(unittest.TestCase):

API_BASE_URL = "http://127.0.0.1:2121"
Expand Down Expand Up @@ -59,75 +58,63 @@ def do_GET(self):
def _setup_mocks(self, api_url_mock):
api_url_mock.__str__.return_value = self.API_BASE_URL

@patch.object(Connection, 'API_BASE_URL')
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
def test_project_with_connected_account(self, mock_access_token, ApiUrlMock):
self._setup_mocks(ApiUrlMock)
def test_project_with_connected_account(self, mock_access_token):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
with env:
client = bigquery.Client(
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self._test_integration(client)

@patch.object(Connection, 'API_BASE_URL')
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
def test_project_with_empty_integrations(self, mock_access_token, ApiUrlMock):
self._setup_mocks(ApiUrlMock)
def test_project_with_empty_integrations(self, mock_access_token):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', '')
with env:
client = bigquery.Client(
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self._test_integration(client)

@patch.object(Connection, 'API_BASE_URL')
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
def test_project_with_connected_account_unrelated_integrations(self, mock_access_token, ApiUrlMock):
self._setup_mocks(ApiUrlMock)
def test_project_with_connected_account_unrelated_integrations(self, mock_access_token):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'GCS:ANOTHER_ONE')
with env:
client = bigquery.Client(
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self._test_integration(client)

@patch.object(Connection, 'API_BASE_URL')
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
def test_project_with_connected_account_default_credentials(self, mock_access_token, ApiUrlMock):
self._setup_mocks(ApiUrlMock)
def test_project_with_connected_account_default_credentials(self, mock_access_token):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
with env:
client = bigquery.Client(project='ANOTHER_PROJECT')
client = bigquery.Client(project='ANOTHER_PROJECT', client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self.assertTrue(client._connection.user_agent.startswith("kaggle-gcp-client/1.0"))
self._test_integration(client)

@patch.object(Connection, 'API_BASE_URL')
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
def test_project_with_env_var_project_default_credentials(self, mock_access_token, ApiUrlMock):
self._setup_mocks(ApiUrlMock)
def test_project_with_env_var_project_default_credentials(self, mock_access_token):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
env.set('GOOGLE_CLOUD_PROJECT', 'ANOTHER_PROJECT')
with env:
client = bigquery.Client()
client = bigquery.Client(client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self._test_integration(client)

@patch.object(Connection, 'API_BASE_URL')
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
def test_simultaneous_clients(self, mock_access_token, ApiUrlMock):
self._setup_mocks(ApiUrlMock)
def test_simultaneous_clients(self, mock_access_token):
env = EnvironmentVarGuard()
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
with env:
proxy_client = bigquery.Client()
proxy_client = bigquery.Client(client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
bq_client = bigquery.Client(
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self._test_integration(bq_client)
# Verify that proxy client is still going to proxy to ensure global Connection
# isn't being modified.
Expand All @@ -142,7 +129,7 @@ def test_no_project_with_connected_account(self):
with self.assertRaises(DefaultCredentialsError):
# TODO(vimota): Handle this case, either default to Kaggle Proxy or use some default project
# by the user or throw a custom exception.
client = bigquery.Client()
client = bigquery.Client(client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
self._test_integration(client)

def test_magics_with_connected_account_default_credentials(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tensorflow_bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import unittest

from google.cloud import bigquery
import tensorflow as tf


class TestTensorflowBigQuery(unittest.TestCase):

# Some versions of bigquery crashed tensorflow, add this test to make sure that doesn't happen.
# python -c "from google.cloud import bigquery; import tensorflow". This flow is common because bigquery is imported in kaggle_gcp.py
# which is loaded at startup.
def test_addition(self):
result = tf.add([1, 2], [3, 4])
self.assertEqual([2], result.shape)

0 comments on commit 662ec7e

Please sign in to comment.