Skip to content

Commit 8e32390

Browse files
authored
Convert datasets to use kagglesdk (#636)
I have some `TODO`s that won't be done until Swagger is removed. Ignore everything in `kagglesdk`. The unit tests were not reformatted earlier. ```bash $ yapf --version yapf 0.40.2 ``` The version is the same as specified in #634, so I don't know why some indentation changed.
1 parent ded7a52 commit 8e32390

File tree

12 files changed

+1649
-1313
lines changed

12 files changed

+1649
-1313
lines changed

kaggle/api/kaggle_api_extended.py

Lines changed: 194 additions & 136 deletions
Large diffs are not rendered by default.

kaggle/models/kaggle_models_extended.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
#!/usr/bin/python
2-
#
3-
# Copyright 2024 Kaggle Inc
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing, software
12-
# distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions and
15-
# limitations under the License.
16-
1+
#!/usr/bin/python
2+
#
3+
# Copyright 2024 Kaggle Inc
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
1717
#!/usr/bin/python
1818
#
1919
# Copyright 2019 Kaggle Inc
@@ -131,9 +131,14 @@ def __repr__(self):
131131
class File(object):
132132

133133
def __init__(self, init_dict):
134-
parsed_dict = {k: parse(v) for k, v in init_dict.items()}
135-
self.__dict__.update(parsed_dict)
136-
self.size = File.get_size(self.totalBytes)
134+
try: # TODO Remove try-block
135+
parsed_dict = {k: parse(v) for k, v in init_dict.items()}
136+
self.__dict__.update(parsed_dict)
137+
self.size = File.get_size(self.totalBytes)
138+
except AttributeError:
139+
self.name = init_dict.name
140+
self.creation_date = init_dict.creation_date
141+
self.size = File.get_size(init_dict.total_bytes)
137142

138143
def __repr__(self):
139144
return self.name
@@ -181,13 +186,18 @@ def __repr__(self):
181186
class ListFilesResult(object):
182187

183188
def __init__(self, init_dict):
184-
self.error_message = init_dict['errorMessage']
185-
files = init_dict['datasetFiles']
189+
try: # TODO Remove try-block
190+
self.error_message = init_dict['errorMessage']
191+
files = init_dict['datasetFiles']
192+
token = init_dict['nextPageToken']
193+
except TypeError:
194+
self.error_message = init_dict.error_message
195+
files = init_dict.dataset_files
196+
token = init_dict.next_page_token
186197
if files:
187198
self.files = [File(f) for f in files]
188199
else:
189200
self.files = {}
190-
token = init_dict['nextPageToken']
191201
if token:
192202
self.nextPageToken = token
193203
else:

kagglesdk/datasets/types/dataset_api_service.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def endpoint(self):
176176
def method():
177177
return 'POST'
178178

179+
@staticmethod
180+
def body_fields():
181+
return '*'
182+
179183
class ApiCreateDatasetResponse(KaggleObject):
180184
r"""
181185
Attributes:
@@ -310,6 +314,10 @@ def endpoint(self):
310314
def method():
311315
return 'POST'
312316

317+
@staticmethod
318+
def body_fields():
319+
return 'body'
320+
313321
class ApiCreateDatasetVersionRequest(KaggleObject):
314322
r"""
315323
Attributes:
@@ -373,6 +381,10 @@ def endpoint(self):
373381
def method():
374382
return 'POST'
375383

384+
@staticmethod
385+
def body_fields():
386+
return 'body'
387+
376388
class ApiCreateDatasetVersionRequestBody(KaggleObject):
377389
r"""
378390
Attributes:
@@ -2080,6 +2092,10 @@ def endpoint(self):
20802092
def method():
20812093
return 'POST'
20822094

2095+
@staticmethod
2096+
def body_fields():
2097+
return 'settings'
2098+
20832099
class ApiUpdateDatasetMetadataResponse(KaggleObject):
20842100
r"""
20852101
Attributes:

kagglesdk/kaggle_http_client.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ def _get_apikey_creds():
3434
api_key = api_key_data['key']
3535
return username, api_key
3636

37+
def clean_data(data):
38+
if isinstance(data, dict):
39+
return {k: clean_data(v) for k, v in data.items() if v is not None}
40+
if isinstance(data, list):
41+
return [clean_data(v) for v in data if v is not None]
42+
if data is True:
43+
return 'true'
44+
if data is False:
45+
return 'false'
46+
return data
3747

3848
class KaggleHttpClient(object):
3949
_xsrf_cookie_name = 'XSRF-TOKEN'
@@ -75,6 +85,19 @@ def _prepare_request(self, service_name: str, request_name: str, request: Kaggle
7585
'Accept': 'application/json',
7686
'Content-Type': 'text/plain',
7787
})
88+
elif method == 'POST':
89+
self._session.headers.update({
90+
'Accept': 'application/json',
91+
'Content-Type': 'application/json',
92+
})
93+
if isinstance(data, dict):
94+
fields = request.body_fields()
95+
if fields is not None:
96+
if fields != '*':
97+
data = data[fields]
98+
data = clean_data(data)
99+
data = data.__str__().replace("'", '"')
100+
# TODO Remove quotes from numbers.
78101
http_request = requests.Request(
79102
method=method,
80103
url=request_url,

kagglesdk/kaggle_object.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ class KaggleObject(object):
205205
def endpoint(self):
206206
raise 'Error: endpoint must be defined by the request object'
207207

208+
@staticmethod
209+
def body_fields():
210+
return None
211+
208212
@classmethod
209213
def prepare_from(cls, http_response):
210214
return cls.from_json(http_response.text)
@@ -229,7 +233,7 @@ def to_dict(self, ignore_defaults=True):
229233
@staticmethod
230234
def to_field_map(self, ignore_defaults=True):
231235
kv_pairs = [(field.field_name, field.get_as_dict_item(self, ignore_defaults)) for field in self._fields]
232-
return {k: v for (k, v) in kv_pairs if not ignore_defaults or v is not None}
236+
return {k: str(v) for (k, v) in kv_pairs if not ignore_defaults or v is not None}
233237

234238
@classmethod
235239
def from_dict(cls, json_dict):

kagglesdk/kernels/types/kernels_api_service.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,10 @@ def endpoint(self):
13651365
def method():
13661366
return 'POST'
13671367

1368+
@staticmethod
1369+
def body_fields():
1370+
return '*'
1371+
13681372
class ApiSaveKernelResponse(KaggleObject):
13691373
r"""
13701374
Attributes:

kagglesdk/models/types/model_api_service.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kagglesdk.datasets.types.dataset_api_service import ApiCategory, ApiDatasetNewFile, ApiUploadDirectoryInfo
44
from kagglesdk.kaggle_object import *
55
from kagglesdk.models.types.model_enums import ListModelsOrderBy, ModelFramework, ModelInstanceType
6-
from kagglesdk.models.types.model_types import BaseModelInstanceInformation
6+
from kagglesdk.models.types.model_types import BaseModelInstanceInformation, ModelLink
77
from typing import Optional, List
88

99
class ApiCreateModelInstanceRequest(KaggleObject):
@@ -69,6 +69,10 @@ def endpoint(self):
6969
def method():
7070
return 'POST'
7171

72+
@staticmethod
73+
def body_fields():
74+
return 'body'
75+
7276
class ApiCreateModelInstanceRequestBody(KaggleObject):
7377
r"""
7478
Attributes:
@@ -357,6 +361,10 @@ def endpoint(self):
357361
def method():
358362
return 'POST'
359363

364+
@staticmethod
365+
def body_fields():
366+
return 'body'
367+
360368
class ApiCreateModelInstanceVersionRequestBody(KaggleObject):
361369
r"""
362370
Attributes:
@@ -553,6 +561,10 @@ def endpoint(self):
553561
def method():
554562
return 'POST'
555563

564+
@staticmethod
565+
def body_fields():
566+
return '*'
567+
556568
class ApiCreateModelResponse(KaggleObject):
557569
r"""
558570
Attributes:
@@ -1280,6 +1292,8 @@ class ApiListModelsRequest(KaggleObject):
12801292
Page size.
12811293
page_token (str)
12821294
Page token used for pagination.
1295+
only_vertex_models (bool)
1296+
Only list models that have Vertex URLs
12831297
"""
12841298

12851299
def __init__(self):
@@ -1288,6 +1302,7 @@ def __init__(self):
12881302
self._owner = None
12891303
self._page_size = None
12901304
self._page_token = None
1305+
self._only_vertex_models = None
12911306
self._freeze()
12921307

12931308
@property
@@ -1363,6 +1378,20 @@ def page_token(self, page_token: str):
13631378
raise TypeError('page_token must be of type str')
13641379
self._page_token = page_token
13651380

1381+
@property
1382+
def only_vertex_models(self) -> bool:
1383+
"""Only list models that have Vertex URLs"""
1384+
return self._only_vertex_models or False
1385+
1386+
@only_vertex_models.setter
1387+
def only_vertex_models(self, only_vertex_models: bool):
1388+
if only_vertex_models is None:
1389+
del self.only_vertex_models
1390+
return
1391+
if not isinstance(only_vertex_models, bool):
1392+
raise TypeError('only_vertex_models must be of type bool')
1393+
self._only_vertex_models = only_vertex_models
1394+
13661395

13671396
def endpoint(self):
13681397
path = '/api/v1/models/list'
@@ -1441,6 +1470,7 @@ class ApiModel(KaggleObject):
14411470
publish_time (datetime)
14421471
provenance_sources (str)
14431472
url (str)
1473+
model_version_links (ModelLink)
14441474
"""
14451475

14461476
def __init__(self):
@@ -1457,6 +1487,7 @@ def __init__(self):
14571487
self._publish_time = None
14581488
self._provenance_sources = ""
14591489
self._url = ""
1490+
self._model_version_links = []
14601491
self._freeze()
14611492

14621493
@property
@@ -1633,6 +1664,21 @@ def url(self, url: str):
16331664
raise TypeError('url must be of type str')
16341665
self._url = url
16351666

1667+
@property
1668+
def model_version_links(self) -> Optional[List[Optional['ModelLink']]]:
1669+
return self._model_version_links
1670+
1671+
@model_version_links.setter
1672+
def model_version_links(self, model_version_links: Optional[List[Optional['ModelLink']]]):
1673+
if model_version_links is None:
1674+
del self.model_version_links
1675+
return
1676+
if not isinstance(model_version_links, list):
1677+
raise TypeError('model_version_links must be of type list')
1678+
if not all([isinstance(t, ModelLink) for t in model_version_links]):
1679+
raise TypeError('model_version_links must contain only items of type ModelLink')
1680+
self._model_version_links = model_version_links
1681+
16361682

16371683
class ApiModelFile(KaggleObject):
16381684
r"""
@@ -2139,6 +2185,10 @@ def endpoint(self):
21392185
def method():
21402186
return 'POST'
21412187

2188+
@staticmethod
2189+
def body_fields():
2190+
return '*'
2191+
21422192
class ApiUpdateModelRequest(KaggleObject):
21432193
r"""
21442194
Attributes:
@@ -2292,6 +2342,10 @@ def endpoint(self):
22922342
def method():
22932343
return 'POST'
22942344

2345+
@staticmethod
2346+
def body_fields():
2347+
return '*'
2348+
22952349
class ApiUpdateModelResponse(KaggleObject):
22962350
r"""
22972351
Attributes:
@@ -2587,6 +2641,7 @@ def create_url(self, create_url: str):
25872641
FieldMetadata("owner", "owner", "_owner", str, None, PredefinedSerializer(), optional=True),
25882642
FieldMetadata("pageSize", "page_size", "_page_size", int, None, PredefinedSerializer(), optional=True),
25892643
FieldMetadata("pageToken", "page_token", "_page_token", str, None, PredefinedSerializer(), optional=True),
2644+
FieldMetadata("onlyVertexModels", "only_vertex_models", "_only_vertex_models", bool, None, PredefinedSerializer(), optional=True),
25902645
]
25912646

25922647
ApiListModelsResponse._fields = [
@@ -2609,6 +2664,7 @@ def create_url(self, create_url: str):
26092664
FieldMetadata("publishTime", "publish_time", "_publish_time", datetime, None, DateTimeSerializer()),
26102665
FieldMetadata("provenanceSources", "provenance_sources", "_provenance_sources", str, "", PredefinedSerializer()),
26112666
FieldMetadata("url", "url", "_url", str, "", PredefinedSerializer()),
2667+
FieldMetadata("modelVersionLinks", "model_version_links", "_model_version_links", ModelLink, [], ListSerializer(KaggleObjectSerializer())),
26122668
]
26132669

26142670
ApiModelFile._fields = [

kagglesdk/models/types/model_enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,8 @@ class ModelInstanceType(enum.Enum):
4242
MODEL_INSTANCE_TYPE_KAGGLE_VARIANT = 2
4343
MODEL_INSTANCE_TYPE_EXTERNAL_VARIANT = 3
4444

45+
class ModelVersionLinkType(enum.Enum):
46+
MODEL_VERSION_LINK_TYPE_UNSPECIFIED = 0
47+
MODEL_VERSION_LINK_TYPE_VERTEX_OPEN = 1
48+
MODEL_VERSION_LINK_TYPE_VERTEX_DEPLOY = 2
49+

0 commit comments

Comments
 (0)