Skip to content

Commit

Permalink
Addresses nits from PR. Mostly naming and spaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
vimota committed Feb 1, 2019
1 parent d52c6f9 commit 5438d79
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
11 changes: 6 additions & 5 deletions patches/kaggle_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ class UserSecretsClient():

def __init__(self):
url_base_override = os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME)
self.url_base = (url_base_override or _KAGGLE_DEFAULT_URL_BASE)
self.url_base = url_base_override or _KAGGLE_DEFAULT_URL_BASE
# Follow the OAuth 2.0 Authorization standard (https://tools.ietf.org/html/rfc6750)
jwt_token = os.getenv(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME)
if jwt_token is None:
raise CredentialError(
f'A JWT Token is required to use the UserSecretsClient, but none found in environment variable {_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME}')
'A JWT Token is required to use the UserSecretsClient, '
f'but none found in environment variable {_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME}')
self.headers = {'Content-type': 'application/json',
'Authorization': 'Bearer {}'.format(jwt_token)}

Expand All @@ -38,17 +39,17 @@ def _make_get_request(self, request_body):
response_json = json.loads(response.read())
return response_json

def getUserSecret(self, secret_label: str):
def get_user_secret(self, secret_label: str):
request_body = {
'SecretLabel': secret_label
}
response_json = self._make_get_request(request_body)
if 'Secret' not in response_json:
raise BackendError(
'Unexpected response from UserSecrets service.')
'Unexpected response from the service.')
return response_json['Secret']

def getBigQueryAccessToken(self):
def get_bigquery_access_token(self):
request_body = {
'Purpose': self.BIGQUERY_PURPOSE_VALUE
}
Expand Down
29 changes: 20 additions & 9 deletions tests/test_user_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class UserSecretsHTTPHandler(BaseHTTPRequestHandler):

def set_request(self):
raise NotImplementedError()

def get_response(self):
raise NotImplementedError()

Expand All @@ -35,34 +35,41 @@ def do_GET(s):


class TestUserSecrets(unittest.TestCase):
server_address = urlparse(os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME))
SERVER_ADDRESS = urlparse(os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME))

def _test_client(self, client_func, expected_path, secret):
_request = {}

class AccessTokenHandler(UserSecretsHTTPHandler):

def set_request(self):
_request['path'] = self.path
_request['headers'] = self.headers

def get_response(self):
return {"Secret": secret}

env = EnvironmentVarGuard()
env.set(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME, _TEST_JWT)
with env:
with HTTPServer((self.server_address.hostname, self.server_address.port), AccessTokenHandler) as httpd:
with HTTPServer((self.SERVER_ADDRESS.hostname, self.SERVER_ADDRESS.port), AccessTokenHandler) as httpd:
threading.Thread(target=httpd.serve_forever).start()

try:
client_func()
finally:
httpd.shutdown()

path, headers = _request['path'], _request['headers']
self.assertEqual(
_request['path'], expected_path, msg="Fake server did not receive the right request from the UserSecrets client.")
path,
expected_path,
msg="Fake server did not receive the right request from the UserSecrets client.")
self.assertTrue(
any(
k for k in _request['headers'] if k == "Authorization" and _request['headers'][k] == f'Bearer {_TEST_JWT}'), msg="Authorization header was missing from the UserSecrets request.")
k for k in headers
if k == "Authorization" and headers[k] == f'Bearer {_TEST_JWT}'),
msg="Authorization header was missing from the UserSecrets request.")

def test_no_token_fails(self):
env = EnvironmentVarGuard()
Expand All @@ -73,16 +80,20 @@ def test_no_token_fails(self):

def test_get_access_token_succeeds(self):
secret = '12345'

def call_get_access_token():
client = UserSecretsClient()
secret_response = client.getBigQueryAccessToken()
secret_response = client.get_bigquery_access_token()
self.assertEqual(secret_response, secret)
self._test_client(call_get_access_token, '/requests/GetUserSecretRequest?Purpose=1', secret)
self._test_client(call_get_access_token,
'/requests/GetUserSecretRequest?Purpose=1', secret)

def test_get_user_secret_succeeds(self):
secret = '5678'

def call_get_access_token():
client = UserSecretsClient()
secret_response = client.getUserSecret('MY_SECRET')
secret_response = client.get_user_secret('MY_SECRET')
self.assertEqual(secret_response, secret)
self._test_client(call_get_access_token, '/requests/GetUserSecretRequest?SecretLabel=MY_SECRET', secret)
self._test_client(
call_get_access_token, '/requests/GetUserSecretRequest?SecretLabel=MY_SECRET', secret)

0 comments on commit 5438d79

Please sign in to comment.