Skip to content

Commit d209f25

Browse files
Use PROD if no environment is specified. (#726)
Changes implemented in PR #34468, copied here. http://b/379083750 Co-authored-by: Steve Messick <messick@google.com>
1 parent b58a85a commit d209f25

File tree

10 files changed

+398
-230
lines changed

10 files changed

+398
-230
lines changed

kagglesdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from kagglesdk.kaggle_client import KaggleClient
2-
from kagglesdk.kaggle_env import KaggleEnv
2+
from kagglesdk.kaggle_env import KaggleEnv

kagglesdk/kaggle_env.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@ class KaggleEnv(Enum):
66
LOCAL = 0 # localhost
77
STAGING = 1 # staging.kaggle.com
88
ADMIN = 2 # admin.kaggle.com
9-
QA = 3 # qa.kaggle.com
9+
QA = 3 # qa.kaggle.com
1010
# Direct prod access is not allowed to have IAP protection during testing, but we support basic auth.
1111
PROD = 4 # www.kaggle.com
1212

1313

1414
_env_to_endpoint = {
15-
KaggleEnv.LOCAL: 'http://localhost',
16-
KaggleEnv.STAGING: 'https://staging.kaggle.com',
17-
KaggleEnv.ADMIN: 'https://admin.kaggle.com',
18-
KaggleEnv.QA: 'https://qa.kaggle.com',
19-
# See the comment above in KaggleEnv enum.
20-
KaggleEnv.PROD: "https://www.kaggle.com",
15+
KaggleEnv.LOCAL: 'http://localhost',
16+
KaggleEnv.STAGING: 'https://staging.kaggle.com',
17+
KaggleEnv.ADMIN: 'https://admin.kaggle.com',
18+
KaggleEnv.QA: 'https://qa.kaggle.com',
19+
# See the comment above in KaggleEnv enum.
20+
KaggleEnv.PROD: 'https://www.kaggle.com',
2121
}
2222

2323

@@ -27,8 +27,8 @@ def get_endpoint(env: KaggleEnv):
2727

2828
def get_env():
2929
env = os.getenv('KAGGLE_API_ENVIRONMENT')
30-
if env is None:
31-
raise Exception('Must specify KaggleEnv or set KAGGLE_API_ENVIRONMENT env var')
30+
if env is None or env == 'PROD':
31+
return KaggleEnv.PROD
3232
if env == 'LOCALHOST':
3333
return KaggleEnv.LOCAL
3434
if env == 'ADMIN':
@@ -37,6 +37,4 @@ def get_env():
3737
return KaggleEnv.STAGING
3838
if env == 'QA':
3939
return KaggleEnv.QA
40-
if env == 'PROD':
41-
return KaggleEnv.PROD
4240
raise Exception(f'Unrecognized value in KAGGLE_API_ENVIRONMENT: "{env}"')

kagglesdk/kaggle_http_client.py

Lines changed: 88 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
# currently usable by the CLI.
1919

2020
# TODO: Extend kapigen to add a boolean to these requests indicating that they use forms.
21-
REQUESTS_REQUIRING_FORMS = ['ApiUploadDatasetFileRequest', 'ApiCreateSubmissionRequest', 'ApiCreateCodeSubmissionRequest', 'ApiStartSubmissionUploadRequest', 'ApiUploadModelFileRequest']
21+
REQUESTS_REQUIRING_FORMS = [
22+
'ApiUploadDatasetFileRequest',
23+
'ApiCreateSubmissionRequest',
24+
'ApiCreateCodeSubmissionRequest',
25+
'ApiStartSubmissionUploadRequest',
26+
'ApiUploadModelFileRequest',
27+
]
28+
2229

2330
def _headers_to_str(headers):
2431
return '\n'.join(f'{k}: {v}' for k, v in headers.items())
@@ -44,7 +51,9 @@ def _get_apikey_creds():
4451

4552
def clean_data(data):
4653
if isinstance(data, dict):
47-
return {to_lower_camel_case(k): clean_data(v) for k, v in data.items() if v is not None}
54+
return {
55+
to_lower_camel_case(k): clean_data(v) for k, v in data.items() if v is not None
56+
}
4857
if isinstance(data, list):
4958
return [clean_data(v) for v in data if v is not None]
5059
if data is True:
@@ -53,6 +62,7 @@ def clean_data(data):
5362
return 'false'
5463
return data
5564

65+
5666
def find_words(source, left='{', right='}'):
5767
words = []
5868
split_str = source.split(left)
@@ -64,8 +74,10 @@ def find_words(source, left='{', right='}'):
6474

6575
return words
6676

77+
6778
def to_camel_case(snake_str):
68-
return "".join(x.capitalize() for x in snake_str.lower().split("_"))
79+
return ''.join(x.capitalize() for x in snake_str.lower().split('_'))
80+
6981

7082
def to_lower_camel_case(snake_str):
7183
# https://stackoverflow.com/questions/19053707/converting-snake-case-to-lower-camel-case-lowercamelcase
@@ -74,18 +86,21 @@ def to_lower_camel_case(snake_str):
7486
camel_string = to_camel_case(snake_str)
7587
return snake_str[0].lower() + camel_string[1:]
7688

89+
7790
class KaggleHttpClient(object):
7891
_xsrf_cookie_name = 'XSRF-TOKEN'
79-
_csrf_cookie_name = "CSRF-TOKEN"
92+
_csrf_cookie_name = 'CSRF-TOKEN'
8093
_xsrf_cookies = (_xsrf_cookie_name, _csrf_cookie_name)
8194
_xsrf_header_name = 'X-XSRF-TOKEN'
8295

83-
def __init__(self,
84-
env: KaggleEnv = None,
85-
verbose: bool = False,
86-
renew_iap_token=None,
87-
username=None,
88-
password=None):
96+
def __init__(
97+
self,
98+
env: KaggleEnv = None,
99+
verbose: bool = False,
100+
renew_iap_token=None,
101+
username=None,
102+
password=None,
103+
):
89104
self._env = env or get_env()
90105
self._signed_in = None
91106
self._endpoint = get_endpoint(self._env)
@@ -94,8 +109,13 @@ def __init__(self,
94109
self._username = username
95110
self._password = password
96111

97-
def call(self, service_name: str, request_name: str, request: KaggleObject,
98-
response_type: Type[KaggleObject]):
112+
def call(
113+
self,
114+
service_name: str,
115+
request_name: str,
116+
request: KaggleObject,
117+
response_type: Type[KaggleObject],
118+
):
99119
self._init_session()
100120
http_request = self._prepare_request(service_name, request_name, request)
101121

@@ -104,11 +124,12 @@ def call(self, service_name: str, request_name: str, request: KaggleObject,
104124
response = self._prepare_response(response_type, http_response)
105125
return response
106126

107-
def _prepare_request(self, service_name: str, request_name: str,
108-
request: KaggleObject):
127+
def _prepare_request(
128+
self, service_name: str, request_name: str, request: KaggleObject
129+
):
109130
request_url = self._get_request_url(request)
110131
method = request.method()
111-
data= ''
132+
data = ''
112133
if method == 'GET':
113134
data = request.__class__.to_dict(request, ignore_defaults=False)
114135
if request.endpoint_path():
@@ -119,10 +140,12 @@ def _prepare_request(self, service_name: str, request_name: str,
119140
if data:
120141
request_url = f'{request_url}?{urllib.parse.urlencode(clean_data(data))}'
121142
data = ''
122-
self._session.headers.update({
123-
'Accept': 'application/json',
124-
'Content-Type': 'text/plain',
125-
})
143+
self._session.headers.update(
144+
{
145+
'Accept': 'application/json',
146+
'Content-Type': 'text/plain',
147+
}
148+
)
126149
elif method == 'POST':
127150
data = request.to_field_map(request, ignore_defaults=True)
128151
if isinstance(data, dict):
@@ -136,17 +159,20 @@ def _prepare_request(self, service_name: str, request_name: str,
136159
else:
137160
content_type = 'application/json'
138161
data = json.dumps(data)
139-
self._session.headers.update({
140-
'Accept': 'application/json',
141-
'Content-Type': content_type,
142-
})
162+
self._session.headers.update(
163+
{
164+
'Accept': 'application/json',
165+
'Content-Type': content_type,
166+
}
167+
)
143168
http_request = requests.Request(
144-
method=method,
145-
url=request_url,
146-
data=data,
147-
headers=self._session.headers,
148-
# cookies=self._get_xsrf_cookies(),
149-
auth=self._session.auth)
169+
method=method,
170+
url=request_url,
171+
data=data,
172+
headers=self._session.headers,
173+
# cookies=self._get_xsrf_cookies(),
174+
auth=self._session.auth,
175+
)
150176
prepared_request = http_request.prepare()
151177
self._print_request(prepared_request)
152178
return prepared_request
@@ -164,8 +190,7 @@ def _prepare_response(self, response_type, http_response):
164190
if 'application/json' in http_response.headers['Content-Type']:
165191
resp = http_response.json()
166192
if 'code' in resp and resp['code'] >= 400:
167-
raise requests.exceptions.HTTPError(
168-
resp['message'], response=http_response)
193+
raise requests.exceptions.HTTPError(resp['message'], response=http_response)
169194
if response_type is None: # Method doesn't have a return type
170195
return None
171196
return response_type.prepare_from(http_response)
@@ -175,8 +200,8 @@ def _print_request(self, request):
175200
return
176201
self._print('---------------------Request----------------------')
177202
self._print(
178-
f'{request.method} {request.url}\n{_headers_to_str(request.headers)}\n\n{request.body}'
179-
)
203+
f'{request.method} {request.url}\n{_headers_to_str(request.headers)}\n\n{request.body}'
204+
)
180205
self._print('--------------------------------------------------')
181206

182207
def _print_response(self, response, body=True):
@@ -205,17 +230,21 @@ def _init_session(self):
205230
return self._session
206231

207232
self._session = requests.Session()
208-
self._session.headers.update({
209-
'User-Agent': 'kaggle-api/v1.7.0', # Was: V2
210-
'Content-Type': 'application/x-www-form-urlencoded', # Was: /json
211-
})
233+
self._session.headers.update(
234+
{
235+
'User-Agent': 'kaggle-api/v1.7.0', # Was: V2
236+
'Content-Type': 'application/x-www-form-urlencoded', # Was: /json
237+
}
238+
)
212239

213240
iap_token = self._get_iap_token_if_required()
214241
if iap_token is not None:
215-
self._session.headers.update({
216-
# https://cloud.google.com/iap/docs/authentication-howto#authenticating_from_proxy-authorization_header
217-
'Proxy-Authorization': f'Bearer {iap_token}',
218-
})
242+
self._session.headers.update(
243+
{
244+
# https://cloud.google.com/iap/docs/authentication-howto#authenticating_from_proxy-authorization_header
245+
'Proxy-Authorization': f'Bearer {iap_token}',
246+
}
247+
)
219248

220249
self._try_fill_auth()
221250
# self._fill_xsrf_token(iap_token) # TODO Make this align with original handler.
@@ -230,10 +259,11 @@ def _get_iap_token_if_required(self):
230259

231260
def _fill_xsrf_token(self, iap_token):
232261
initial_get_request = requests.Request(
233-
method='GET',
234-
url=self._endpoint,
235-
headers=self._session.headers,
236-
auth=self._session.auth)
262+
method='GET',
263+
url=self._endpoint,
264+
headers=self._session.headers,
265+
auth=self._session.auth,
266+
)
237267
prepared_request = initial_get_request.prepare()
238268
self._print_request(prepared_request)
239269

@@ -244,18 +274,21 @@ def _fill_xsrf_token(self, iap_token):
244274
raise requests.exceptions.HTTPError('IAP token invalid or expired')
245275
http_response.raise_for_status()
246276

247-
self._session.headers.update({
248-
KaggleHttpClient._xsrf_header_name:
249-
self._session.cookies[KaggleHttpClient._xsrf_cookie_name],
250-
})
277+
self._session.headers.update(
278+
{
279+
KaggleHttpClient._xsrf_header_name: self._session.cookies[
280+
KaggleHttpClient._xsrf_cookie_name
281+
],
282+
}
283+
)
251284

252285
class BearerAuth(requests.auth.AuthBase):
253286

254287
def __init__(self, token):
255288
self.token = token
256289

257290
def __call__(self, r):
258-
r.headers["Authorization"] = f"Bearer {self.token}"
291+
r.headers['Authorization'] = f'Bearer {self.token}'
259292
return r
260293

261294
def _try_fill_auth(self):
@@ -286,11 +319,11 @@ def _get_request_url(self, request):
286319
def make_form(fields):
287320
body = BytesIO()
288321
boundary = binascii.hexlify(os.urandom(16)).decode()
289-
writer = codecs.lookup("utf-8")[3]
322+
writer = codecs.lookup('utf-8')[3]
290323

291324
for field in fields.items():
292325
field = RequestField.from_tuples(*field)
293-
body.write(f"--{boundary}\r\n".encode("latin-1"))
326+
body.write(f'--{boundary}\r\n'.encode('latin-1'))
294327

295328
writer(body).write(field.render_headers())
296329
data = field.data
@@ -303,11 +336,11 @@ def make_form(fields):
303336
else:
304337
body.write(data)
305338

306-
body.write(b"\r\n")
339+
body.write(b'\r\n')
307340

308-
body.write(f"--{boundary}--\r\n".encode("latin-1"))
341+
body.write(f'--{boundary}--\r\n'.encode('latin-1'))
309342

310-
content_type = f"multipart/form-data; boundary={boundary}"
343+
content_type = f'multipart/form-data; boundary={boundary}'
311344

312345
return body.getvalue(), content_type
313346

0 commit comments

Comments
 (0)