Skip to content

Commit 3a3c549

Browse files
authored
Convert arg strings to enum values (#644)
Fixes #641 as a side effect.
1 parent 63244ad commit 3a3c549

File tree

5 files changed

+391
-269
lines changed

5 files changed

+391
-269
lines changed

kaggle/api/kaggle_api_extended.py

Lines changed: 121 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
#!/usr/bin/python
2-
#
3-
# Copyright 2024 Kaggle Inc
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing, software
12-
# distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions and
15-
# limitations under the License.
16-
1+
#!/usr/bin/python
2+
#
3+
# Copyright 2024 Kaggle Inc
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
1717
#!/usr/bin/python
1818
#
1919
# Copyright 2019 Kaggle Inc
@@ -56,10 +56,18 @@
5656
from kaggle.configuration import Configuration
5757
from kagglesdk import KaggleClient, KaggleEnv
5858
from kagglesdk.competitions.types.competition_api_service import *
59-
from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, ApiListDatasetFilesRequest, \
60-
ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \
61-
ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, ApiDatasetNewFile
62-
from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, DatasetSortBy
59+
from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, \
60+
ApiListDatasetFilesRequest, \
61+
ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, \
62+
ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \
63+
ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, \
64+
ApiDatasetNewFile, ApiUpdateDatasetMetadataRequest, \
65+
ApiGetDatasetMetadataRequest
66+
from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, \
67+
DatasetSortBy, DatasetFileTypeGroup, DatasetLicenseGroup
68+
from kagglesdk.datasets.types.dataset_types import DatasetSettings, \
69+
SettingsLicense, UserRole, DatasetSettingsFile
70+
from kagglesdk.kernels.types.kernels_api_service import ApiListKernelsRequest
6371
from .kaggle_api import KaggleApi
6472
from ..api_client import ApiClient
6573
from ..models.api_blob_type import ApiBlobType
@@ -313,14 +321,17 @@ class KaggleApi(KaggleApi):
313321
]
314322

315323
# Competitions valid types
316-
valid_competition_groups = ['general', 'entered', 'inClass']
324+
valid_competition_groups = [
325+
'general', 'entered', 'community', 'hosted', 'unlaunched',
326+
'unlaunched_community'
327+
]
317328
valid_competition_categories = [
318329
'all', 'featured', 'research', 'recruitment', 'gettingStarted', 'masters',
319330
'playground'
320331
]
321332
valid_competition_sort_by = [
322-
'grouped', 'prize', 'earliestDeadline', 'latestDeadline', 'numberOfTeams',
323-
'recentlyCreated'
333+
'grouped', 'best', 'prize', 'earliestDeadline', 'latestDeadline',
334+
'numberOfTeams', 'relevance', 'recentlyCreated'
324335
]
325336

326337
# Datasets valid types
@@ -709,6 +720,10 @@ def camel_to_snake(self, name):
709720
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
710721
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
711722

723+
def lookup_enum(self, enum_class, item_name):
724+
prefix = self.camel_to_snake(enum_class.__name__).upper()
725+
return enum_class[f'{prefix}_{self.camel_to_snake(item_name).upper()}']
726+
712727
## Competitions
713728

714729
def competitions_list(self,
@@ -726,20 +741,29 @@ def competitions_list(self,
726741
page: the page to return (default is 1)
727742
search: a search term to use (default is empty string)
728743
sort_by: how to sort the result, see valid_competition_sort_by for options
729-
category: category to filter result to
744+
category: category to filter result to; use 'all' to get closed competitions
730745
group: group to filter result to
731746
"""
732-
if group and group not in self.valid_competition_groups:
733-
raise ValueError('Invalid group specified. Valid options are ' +
734-
str(self.valid_competition_groups))
747+
if group:
748+
if group not in self.valid_competition_groups:
749+
raise ValueError('Invalid group specified. Valid options are ' +
750+
str(self.valid_competition_groups))
751+
if group == 'all':
752+
group = CompetitionListTab.COMPETITION_LIST_TAB_DEFAULT
753+
else:
754+
group = self.lookup_enum(CompetitionListTab, group)
735755

736-
if category and category not in self.valid_competition_categories:
737-
raise ValueError('Invalid category specified. Valid options are ' +
738-
str(self.valid_competition_categories))
756+
if category:
757+
if category not in self.valid_competition_categories:
758+
raise ValueError('Invalid category specified. Valid options are ' +
759+
str(self.valid_competition_categories))
760+
category = self.lookup_enum(HostSegment, category)
739761

740-
if sort_by and sort_by not in self.valid_competition_sort_by:
741-
raise ValueError('Invalid sort_by specified. Valid options are ' +
742-
str(self.valid_competition_sort_by))
762+
if sort_by:
763+
if sort_by not in self.valid_competition_sort_by:
764+
raise ValueError('Invalid sort_by specified. Valid options are ' +
765+
str(self.valid_competition_sort_by))
766+
sort_by = self.lookup_enum(CompetitionSortBy, sort_by)
743767

744768
with self.build_kaggle_client() as kaggle:
745769
request = ApiListCompetitionsRequest()
@@ -1199,30 +1223,36 @@ def dataset_list(self,
11991223
raise ValueError('Invalid sort by specified. Valid options are ' +
12001224
str(self.valid_dataset_sort_bys))
12011225
else:
1202-
sort_by = DatasetSortBy[f"DATASET_SORT_BY_{sort_by.upper()}"]
1226+
sort_by = self.lookup_enum(DatasetSortBy, sort_by)
12031227

12041228
if size:
12051229
raise ValueError(
12061230
'The --size parameter has been deprecated. ' +
12071231
'Please use --max-size and --min-size to filter dataset sizes.')
12081232

1209-
if file_type and file_type not in self.valid_dataset_file_types:
1210-
raise ValueError('Invalid file type specified. Valid options are ' +
1211-
str(self.valid_dataset_file_types))
1233+
if file_type:
1234+
if file_type not in self.valid_dataset_file_types:
1235+
raise ValueError('Invalid file type specified. Valid options are ' +
1236+
str(self.valid_dataset_file_types))
1237+
else:
1238+
file_type = self.lookup_enum(DatasetFileTypeGroup, file_type)
12121239

1213-
if license_name and license_name not in self.valid_dataset_license_names:
1214-
raise ValueError('Invalid license specified. Valid options are ' +
1215-
str(self.valid_dataset_license_names))
1240+
if license_name:
1241+
if license_name not in self.valid_dataset_license_names:
1242+
raise ValueError('Invalid license specified. Valid options are ' +
1243+
str(self.valid_dataset_license_names))
1244+
else:
1245+
license_name = self.lookup_enum(DatasetLicenseGroup, license_name)
12161246

12171247
if int(page) <= 0:
12181248
raise ValueError('Page number must be >= 1')
12191249

12201250
if max_size and min_size:
1221-
if (int(max_size) < int(min_size)):
1251+
if int(max_size) < int(min_size):
12221252
raise ValueError('Max Size must be max_size >= min_size')
1223-
if (max_size and int(max_size) <= 0):
1253+
if max_size and int(max_size) <= 0:
12241254
raise ValueError('Max Size must be > 0')
1225-
elif (min_size and int(min_size) < 0):
1255+
elif min_size and int(min_size) < 0:
12261256
raise ValueError('Min Size must be >= 0')
12271257

12281258
group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_PUBLIC
@@ -1315,43 +1345,57 @@ def dataset_metadata_update(self, dataset, path):
13151345
effective_path) = self.dataset_metadata_prep(dataset, path)
13161346
meta_file = self.get_dataset_metadata_file(effective_path)
13171347
with open(meta_file, 'r') as f:
1318-
metadata = json.load(f)
1348+
s = json.load(f)
1349+
metadata = json.loads(s)
13191350
updateSettingsRequest = DatasetUpdateSettingsRequest(
1320-
title=metadata['title'],
1321-
subtitle=metadata['subtitle'],
1322-
description=metadata['description'],
1323-
is_private=metadata['isPrivate'],
1324-
licenses=[License(name=l['name']) for l in metadata['licenses']],
1325-
keywords=metadata['keywords'],
1351+
title=metadata.get('title') or '',
1352+
subtitle=metadata.get('subtitle') or '',
1353+
description=metadata.get('description') or '',
1354+
is_private=metadata.get('isPrivate') or False,
1355+
licenses=[License(name=l['name']) for l in metadata['licenses']] if metadata.get('licenses') else [],
1356+
keywords=metadata.get('keywords'),
13261357
collaborators=[
1327-
Collaborator(username=c['username'], role=c['role'])
1328-
for c in metadata['collaborators']
1329-
],
1330-
data=metadata['data'])
1358+
Collaborator(username=c['username'], role=c['role'])
1359+
for c in metadata['collaborators']
1360+
] if metadata.get('collaborators') else [],
1361+
data=metadata.get('data'))
13311362
result = self.process_response(
13321363
self.metadata_post_with_http_info(owner_slug, dataset_slug,
13331364
updateSettingsRequest))
13341365
if (len(result['errors']) > 0):
13351366
[print(e['message']) for e in result['errors']]
13361367
exit(1)
13371368

1369+
def new_license(self, name):
1370+
slicense = SettingsLicense()
1371+
slicense.name = name
1372+
return slicense
1373+
1374+
def new_collaborator(self, name, role):
1375+
collab = UserRole()
1376+
collab.username = name
1377+
collab.role = role
1378+
return collab
1379+
13381380
def dataset_metadata(self, dataset, path):
13391381
(owner_slug, dataset_slug,
13401382
effective_path) = self.dataset_metadata_prep(dataset, path)
13411383

13421384
if not os.path.exists(effective_path):
13431385
os.makedirs(effective_path)
13441386

1345-
result = self.process_response(
1346-
self.metadata_get_with_http_info(owner_slug, dataset_slug))
1347-
if (result['errorMessage']):
1348-
raise Exception(result['errorMessage'])
1349-
1350-
metadata = Metadata(result['info'])
1387+
with self.build_kaggle_client() as kaggle:
1388+
request = ApiGetDatasetMetadataRequest()
1389+
request.owner_slug = owner_slug
1390+
request.dataset_slug = dataset_slug
1391+
response = kaggle.datasets.dataset_api_client.get_dataset_metadata(
1392+
request)
1393+
if response.error_message:
1394+
raise Exception(response.error_message)
13511395

13521396
meta_file = os.path.join(effective_path, self.DATASET_METADATA_FILE)
13531397
with open(meta_file, 'w') as f:
1354-
json.dump(metadata, f, indent=2, default=lambda o: o.__dict__)
1398+
json.dump(response.to_json(response.info), f, indent=2, default=lambda o: o.__dict__)
13551399

13561400
return meta_file
13571401

@@ -2109,7 +2153,7 @@ def kernels_list(self,
21092153
kernel_type=None,
21102154
output_type=None,
21112155
sort_by=None):
2112-
""" list kernels based on a set of search criteria
2156+
""" List kernels based on a set of search criteria.
21132157
21142158
Parameters
21152159
==========
@@ -2161,6 +2205,21 @@ def kernels_list(self,
21612205
if mine:
21622206
group = 'profile'
21632207

2208+
with self.build_kaggle_client() as kaggle:
2209+
request = ApiListKernelsRequest()
2210+
request.page = page
2211+
page_size = page_size
2212+
group = group # req
2213+
user = user
2214+
language = language
2215+
kernel_type = kernel_type
2216+
output_type = output_type
2217+
sort_by = sort_by #req
2218+
dataset = dataset
2219+
competition = competition
2220+
parent_kernel = parent_kernel
2221+
search = search
2222+
21642223
kernels_list_result = self.process_response(
21652224
self.kernels_list_with_http_info(
21662225
page=page,

kagglesdk/competitions/types/competition_enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class CompetitionListTab(enum.Enum):
88
COMPETITION_LIST_TAB_HOSTED = 3
99
COMPETITION_LIST_TAB_UNLAUNCHED = 4
1010
COMPETITION_LIST_TAB_UNLAUNCHED_COMMUNITY = 5
11+
COMPETITION_LIST_TAB_EVERYTHING = 6
1112

1213
class CompetitionSortBy(enum.Enum):
1314
COMPETITION_SORT_BY_GROUPED = 0

kagglesdk/kaggle_http_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _init_session(self):
173173
})
174174

175175
self._try_fill_auth()
176-
self._fill_xsrf_token(iap_token)
176+
# self._fill_xsrf_token(iap_token) # TODO Make this align with original handler.
177177

178178
def _get_iap_token_if_required(self):
179179
if self._env not in (KaggleEnv.STAGING, KaggleEnv.ADMIN):

0 commit comments

Comments
 (0)