Skip to content

Commit c9bc1c4

Browse files
authored
Fix submit by custom handling of requests that use forms (#681)
Note: Some of the changes (e.g., the new `ApiDownloadKernelOutputZipRequest`) are due to other teams modifying proto files. The list of requests that use forms should not be hard-coded like this, but fixing that is going to take more work than I have time for right now. I'm guessing the changes in `__init__.py` are due to line endings. The changes to `kaggle_http_client.py` have NOT yet been added to `kapigen`.
1 parent 7f674fe commit c9bc1c4

File tree

7 files changed

+131
-17
lines changed

7 files changed

+131
-17
lines changed

kaggle/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
# coding=utf-8
18-
from __future__ import absolute_import
19-
from kaggle.api.kaggle_api_extended import KaggleApi
20-
21-
api = KaggleApi()
22-
api.authenticate()
17+
# coding=utf-8
18+
from __future__ import absolute_import
19+
from kaggle.api.kaggle_api_extended import KaggleApi
20+
21+
api = KaggleApi()
22+
api.authenticate()

kagglesdk/competitions/types/competition_api_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def file_name(self, file_name: str):
771771

772772

773773
def endpoint(self):
774-
path = '/api/v1/competitions/{competition_name}/submissions/url/{content_length}/{last_modified_epoch_seconds}'
774+
path = '/api/v1/competitions/submission-url'
775775
return path.format_map(self.to_field_map(self))
776776

777777

kagglesdk/kaggle_http_client.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import binascii
2+
import codecs
13
import json
24
import os
35
import urllib.parse
6+
from io import BytesIO
47

58
import requests
9+
from urllib3.fields import RequestField
10+
611
from kagglesdk.kaggle_env import get_endpoint, get_env, KaggleEnv
712
from kagglesdk.kaggle_object import KaggleObject
813
from typing import Type
@@ -12,6 +17,8 @@
1217
# auth handling. The new client requires KAGGLE_API_TOKEN, so it is not
1318
# currently usable by the CLI.
1419

20+
# TODO: Extend kapigen to add a boolean to these requests indicating that they use forms.
21+
# Class names of requests whose POST body is sent as a multipart form
# instead of JSON (checked by name in `requires_form`).
REQUESTS_REQUIRING_FORMS = [
    'ApiUploadDatasetFileRequest',
    'ApiCreateSubmissionRequest',
    'ApiStartSubmissionUploadRequest',
    'ApiUploadModelFileRequest',
]
1522

1623
def _headers_to_str(headers):
1724
return '\n'.join(f'{k}: {v}' for k, v in headers.items())
@@ -117,18 +124,22 @@ def _prepare_request(self, service_name: str, request_name: str,
117124
'Content-Type': 'text/plain',
118125
})
119126
elif method == 'POST':
120-
self._session.headers.update({
121-
'Accept': 'application/json',
122-
'Content-Type': 'application/json',
123-
})
124-
data = request.to_field_map(request, ignore_defaults=False)
127+
data = request.to_field_map(request, ignore_defaults=True)
125128
if isinstance(data, dict):
126129
fields = request.body_fields()
127130
if fields is not None:
128131
if fields != '*':
129132
data = data[fields]
130133
data = clean_data(data)
131-
data = json.dumps(data)
134+
if self.requires_form(request):
135+
data, content_type = self.make_form(data)
136+
else:
137+
content_type = 'application/json'
138+
data = json.dumps(data)
139+
self._session.headers.update({
140+
'Accept': 'application/json',
141+
'Content-Type': content_type,
142+
})
132143
http_request = requests.Request(
133144
method=method,
134145
url=request_url,
@@ -270,3 +281,36 @@ def _try_fill_auth(self):
270281

271282
def _get_request_url(self, request):
272283
return f'{self._endpoint}{request.endpoint()}'
284+
285+
@staticmethod
def make_form(fields):
    """Encode ``fields`` as a ``multipart/form-data`` request body.

    Args:
      fields: Mapping of form field name to value. Each ``(name, value)``
        pair is turned into a part via ``RequestField.from_tuples``.

    Returns:
      A ``(body, content_type)`` tuple: the encoded body as ``bytes`` and
      the matching ``Content-Type`` header value, including the boundary.
    """
    # A random 128-bit hex boundary separates the parts of the form.
    boundary = os.urandom(16).hex()
    utf8_writer = codecs.getwriter("utf-8")
    buffer = BytesIO()

    for name, value in fields.items():
        part = RequestField.from_tuples(name, value)
        buffer.write(f"--{boundary}\r\n".encode("latin-1"))
        utf8_writer(buffer).write(part.render_headers())

        payload = part.data
        # Integers are sent as their decimal text.
        if isinstance(payload, int):
            payload = str(payload)

        if isinstance(payload, str):
            # Text goes through the UTF-8 stream writer.
            utf8_writer(buffer).write(payload)
        else:
            # Anything else is written as-is (assumes bytes-like -- TODO confirm).
            buffer.write(payload)

        buffer.write(b"\r\n")

    # Closing delimiter marks the end of the form.
    buffer.write(f"--{boundary}--\r\n".encode("latin-1"))

    return buffer.getvalue(), f"multipart/form-data; boundary={boundary}"
313+
314+
@staticmethod
def requires_form(request):
    """Tell whether ``request`` must be sent as a form.

    The decision is made purely by looking up the request's class name in
    ``REQUESTS_REQUIRING_FORMS``.
    """
    request_class_name = type(request).__name__
    return request_class_name in REQUESTS_REQUIRING_FORMS

kagglesdk/kernels/services/kernels_api_service.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from kagglesdk.common.types.file_download import FileDownload
12
from kagglesdk.common.types.http_redirect import HttpRedirect
23
from kagglesdk.kaggle_http_client import KaggleHttpClient
3-
from kagglesdk.kernels.types.kernels_api_service import ApiDownloadKernelOutputRequest, ApiGetKernelRequest, ApiGetKernelResponse, ApiGetKernelSessionStatusRequest, ApiGetKernelSessionStatusResponse, ApiListKernelFilesRequest, ApiListKernelFilesResponse, ApiListKernelSessionOutputRequest, ApiListKernelSessionOutputResponse, ApiListKernelsRequest, ApiListKernelsResponse, ApiSaveKernelRequest, ApiSaveKernelResponse
4+
from kagglesdk.kernels.types.kernels_api_service import ApiDownloadKernelOutputRequest, ApiDownloadKernelOutputZipRequest, ApiGetKernelRequest, ApiGetKernelResponse, ApiGetKernelSessionStatusRequest, ApiGetKernelSessionStatusResponse, ApiListKernelFilesRequest, ApiListKernelFilesResponse, ApiListKernelSessionOutputRequest, ApiListKernelSessionOutputResponse, ApiListKernelsRequest, ApiListKernelsResponse, ApiSaveKernelRequest, ApiSaveKernelResponse
45

56
class KernelsApiClient(object):
67

@@ -92,3 +93,17 @@ def download_kernel_output(self, request: ApiDownloadKernelOutputRequest = None)
9293
request = ApiDownloadKernelOutputRequest()
9394

9495
return self._client.call("kernels.KernelsApiService", "ApiDownloadKernelOutput", request, HttpRedirect)
96+
97+
def download_kernel_output_zip(self, request: ApiDownloadKernelOutputZipRequest = None) -> FileDownload:
    r"""
    Meant for use by Kaggle Hub (and DownloadKernelOutput above)

    Calls the "ApiDownloadKernelOutputZip" method of
    "kernels.KernelsApiService" through the underlying client.

    Args:
      request (ApiDownloadKernelOutputZipRequest):
        The request object; initialized to empty instance if not specified.

    Returns:
      Whatever the client call yields for a ``FileDownload`` response type.
    """

    zip_request = ApiDownloadKernelOutputZipRequest() if request is None else request
    return self._client.call(
        "kernels.KernelsApiService",
        "ApiDownloadKernelOutputZip",
        zip_request,
        FileDownload,
    )

kagglesdk/kernels/types/kernels_api_service.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,38 @@ def endpoint(self):
8585
def endpoint_path():
8686
return '/api/v1/kernels/output/download/{owner_slug}/{kernel_slug}'
8787

88+
class ApiDownloadKernelOutputZipRequest(KaggleObject):
    r"""
    Request object addressed to the kernel output ``download_zip`` endpoint.

    Attributes:
      kernel_session_id (int)
    """

    def __init__(self):
        self._kernel_session_id = 0
        self._freeze()

    @property
    def kernel_session_id(self) -> int:
        return self._kernel_session_id

    @kernel_session_id.setter
    def kernel_session_id(self, kernel_session_id: int):
        if kernel_session_id is None:
            # None is routed to attribute deletion; what that does is defined
            # by the base class (presumably a reset to default -- confirm).
            del self.kernel_session_id
        elif isinstance(kernel_session_id, int):
            self._kernel_session_id = kernel_session_id
        else:
            raise TypeError('kernel_session_id must be of type int')

    def endpoint(self):
        # Fill the URL template from this request's field map.
        field_values = self.to_field_map(self)
        template = '/api/v1/kernels/output/download_zip/{kernel_session_id}'
        return template.format_map(field_values)

    @staticmethod
    def endpoint_path():
        # Unformatted URL template, placeholders left in place.
        return '/api/v1/kernels/output/download_zip/{kernel_session_id}'
119+
88120
class ApiGetKernelRequest(KaggleObject):
89121
r"""
90122
Attributes:
@@ -1720,6 +1752,10 @@ def creation_date(self, creation_date: str):
17201752
FieldMetadata("versionNumber", "version_number", "_version_number", int, None, PredefinedSerializer(), optional=True),
17211753
]
17221754

1755+
# Field metadata for ApiDownloadKernelOutputZipRequest: a single int field
# with default 0.
ApiDownloadKernelOutputZipRequest._fields = [
    FieldMetadata(
        "kernelSessionId",
        "kernel_session_id",
        "_kernel_session_id",
        int,
        0,
        PredefinedSerializer(),
    ),
]
1758+
17231759
ApiGetKernelRequest._fields = [
17241760
FieldMetadata("userName", "user_name", "_user_name", str, "", PredefinedSerializer()),
17251761
FieldMetadata("kernelSlug", "kernel_slug", "_kernel_slug", str, "", PredefinedSerializer()),

kagglesdk/models/types/model_api_service.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,7 @@ class ApiModelInstance(KaggleObject):
17681768
model_instance_type (ModelInstanceType)
17691769
base_model_instance_information (BaseModelInstanceInformation)
17701770
external_base_model_url (str)
1771+
total_uncompressed_bytes (int)
17711772
"""
17721773

17731774
def __init__(self):
@@ -1786,6 +1787,7 @@ def __init__(self):
17861787
self._model_instance_type = ModelInstanceType.MODEL_INSTANCE_TYPE_UNSPECIFIED
17871788
self._base_model_instance_information = None
17881789
self._external_base_model_url = ""
1790+
self._total_uncompressed_bytes = 0
17891791
self._freeze()
17901792

17911793
@property
@@ -1985,6 +1987,19 @@ def external_base_model_url(self, external_base_model_url: str):
19851987
raise TypeError('external_base_model_url must be of type str')
19861988
self._external_base_model_url = external_base_model_url
19871989

1990+
@property
def total_uncompressed_bytes(self) -> int:
    """Value currently stored in ``_total_uncompressed_bytes``."""
    return self._total_uncompressed_bytes

@total_uncompressed_bytes.setter
def total_uncompressed_bytes(self, total_uncompressed_bytes: int):
    if total_uncompressed_bytes is None:
        # None is routed to attribute deletion; what that does is defined
        # outside this block (presumably a reset to default -- confirm).
        del self.total_uncompressed_bytes
    elif isinstance(total_uncompressed_bytes, int):
        self._total_uncompressed_bytes = total_uncompressed_bytes
    else:
        raise TypeError('total_uncompressed_bytes must be of type int')
2002+
19882003

19892004
class ApiUpdateModelInstanceRequest(KaggleObject):
19902005
r"""
@@ -3063,6 +3078,7 @@ def e(self, e: str):
30633078
FieldMetadata("modelInstanceType", "model_instance_type", "_model_instance_type", ModelInstanceType, ModelInstanceType.MODEL_INSTANCE_TYPE_UNSPECIFIED, EnumSerializer()),
30643079
FieldMetadata("baseModelInstanceInformation", "base_model_instance_information", "_base_model_instance_information", BaseModelInstanceInformation, None, KaggleObjectSerializer(), optional=True),
30653080
FieldMetadata("externalBaseModelUrl", "external_base_model_url", "_external_base_model_url", str, "", PredefinedSerializer()),
3081+
FieldMetadata("totalUncompressedBytes", "total_uncompressed_bytes", "_total_uncompressed_bytes", int, 0, PredefinedSerializer()),
30663082
]
30673083

30683084
ApiUpdateModelInstanceRequest._fields = [

tests/unit_tests.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,11 @@ def test_dataset_i_create_new(self):
541541
new_dataset = api.dataset_create_new(dataset_directory)
542542
self.assertIsNotNone(new_dataset)
543543
if new_dataset.error is not None:
544-
print(new_dataset.error) # This is likely to happen, and that's OK.
545-
# self.skip_create_version = True
544+
if 'already in use' in new_dataset.error:
545+
print(new_dataset.error) # This is likely to happen, and that's OK.
546+
self.skip_create_version = True
547+
else:
548+
self.fail(f"dataset_create_new failed: {new_dataset.error}")
546549
except ApiException as e:
547550
self.fail(f"dataset_create_new failed: {e}")
548551

@@ -783,7 +786,7 @@ def test_model_instance_x_delete(self):
783786
self.assertIsNotNone(inst_update_resp)
784787
if len(inst_update_resp.error):
785788
print(inst_update_resp.error)
786-
self.assertEquals(len(inst_update_resp.error), 0)
789+
self.assertEqual(len(inst_update_resp.error), 0)
787790
except ApiException as e:
788791
self.fail(f"model_instance_delete failed: {e}")
789792

0 commit comments

Comments
 (0)