forked from Kaggle/docker-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkaggle_secrets.py
81 lines (69 loc) · 3.34 KB
/
kaggle_secrets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""UserSecret client classes.
This library adds support for communicating with the UserSecrets service,
currently used for retrieving an access token for supported integrations
(ie. BigQuery).
"""
import json
import os
import urllib.request
from urllib.error import HTTPError
from typing import Tuple, Optional
from datetime import datetime, timedelta
_KAGGLE_DEFAULT_URL_BASE = "https://www.kaggle.com"
_KAGGLE_URL_BASE_ENV_VAR_NAME = "KAGGLE_URL_BASE"
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME = "KAGGLE_USER_SECRETS_TOKEN"
TIMEOUT_SECS = 40
class CredentialError(Exception):
pass
class BackendError(Exception):
pass
class UserSecretsClient():
GET_USER_SECRET_ENDPOINT = '/requests/GetUserSecretRequest'
BIGQUERY_TARGET_VALUE = 1
def __init__(self):
url_base_override = os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME)
self.url_base = url_base_override or _KAGGLE_DEFAULT_URL_BASE
# Follow the OAuth 2.0 Authorization standard (https://tools.ietf.org/html/rfc6750)
self.jwt_token = os.getenv(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME)
if self.jwt_token is None:
raise CredentialError(
'A JWT Token is required to use the UserSecretsClient, '
f'but none found in environment variable {_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME}')
self.headers = {'Content-type': 'application/json'}
def _make_post_request(self, data: dict) -> dict:
url = f'{self.url_base}{self.GET_USER_SECRET_ENDPOINT}'
request_body = dict(data)
request_body['JWE'] = self.jwt_token
req = urllib.request.Request(url, headers=self.headers, data=bytes(
json.dumps(request_body), encoding="utf-8"))
try:
with urllib.request.urlopen(req, timeout=TIMEOUT_SECS) as response:
response_json = json.loads(response.read())
if not response_json.get('wasSuccessful') or 'result' not in response_json:
raise BackendError(
'Unexpected response from the service.')
return response_json['result']
except HTTPError as e:
if e.code == 401 or e.code == 403:
raise CredentialError(f'Service responded with error code {e.code}.'
' Please ensure you have access to the resource.') from e
raise BackendError('Unexpected response from the service.') from e
def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
"""Retrieves BigQuery access token information from the UserSecrets service.
This returns the token for the current kernel as well as its expiry (abs time) if it
is present.
Example usage:
client = UserSecretsClient()
token, expiry = client.get_bigquery_access_token()
"""
request_body = {
'Target': self.BIGQUERY_TARGET_VALUE
}
response_json = self._make_post_request(request_body)
if 'secret' not in response_json:
raise BackendError(
'Unexpected response from the service.')
# Optionally return expiry if it is set.
expiresInSeconds = response_json.get('expiresInSeconds')
expiry = datetime.utcnow() + timedelta(seconds=expiresInSeconds) if expiresInSeconds else None
return response_json['secret'], expiry