Skip to content

Commit fa2cd1b

Browse files
committed
feat(VisualRecognitionV4): add support for downloading a model file
1 parent c57f248 commit fa2cd1b

File tree

2 files changed

+164
-8
lines changed

2 files changed

+164
-8
lines changed

ibm_watson/visual_recognition_v4.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,58 @@ def delete_collection(self, collection_id: str,
365365
response = self.send(request)
366366
return response
367367

368+
def get_model_file(self, collection_id: str, feature: str,
369+
model_format: str, **kwargs) -> 'DetailedResponse':
370+
"""
371+
Get a model.
372+
373+
Download a model that you can deploy to detect objects in images. The collection
374+
must include a generated model, which is indicated in the response for the
375+
collection details as `"rscnn_ready": true`. If the value is `false`, train or
376+
retrain the collection to generate the model.
377+
Currently, the model format is specific to Android apps. For more information
378+
about how to deploy the model to your app, see the [Watson Visual Recognition on
379+
Android](https://github.com/matt-ny/rscnn) project in GitHub.
380+
381+
:param str collection_id: The identifier of the collection.
382+
:param str feature: The feature for the model.
383+
:param str model_format: The format of the returned model.
384+
:param dict headers: A `dict` containing the request headers
385+
:return: A `DetailedResponse` containing the result, headers and HTTP status code.
386+
:rtype: DetailedResponse
387+
"""
388+
389+
if collection_id is None:
390+
raise ValueError('collection_id must be provided')
391+
if feature is None:
392+
raise ValueError('feature must be provided')
393+
if model_format is None:
394+
raise ValueError('model_format must be provided')
395+
396+
headers = {}
397+
if 'headers' in kwargs:
398+
headers.update(kwargs.get('headers'))
399+
sdk_headers = get_sdk_headers(service_name=self.DEFAULT_SERVICE_NAME,
400+
service_version='V4',
401+
operation_id='get_model_file')
402+
headers.update(sdk_headers)
403+
404+
params = {
405+
'version': self.version,
406+
'feature': feature,
407+
'model_format': model_format
408+
}
409+
410+
url = '/v4/collections/{0}/model'.format(
411+
*self._encode_path_vars(collection_id))
412+
request = self.prepare_request(method='GET',
413+
url=url,
414+
headers=headers,
415+
params=params)
416+
417+
response = self.send(request)
418+
return response
419+
368420
#########################
369421
# Images
370422
#########################
@@ -973,6 +1025,21 @@ class Features(Enum):
9731025
OBJECTS = 'objects'
9741026

9751027

1028+
class GetModelFileEnums(object):
1029+
1030+
class Feature(Enum):
1031+
"""
1032+
The feature for the model.
1033+
"""
1034+
OBJECTS = 'objects'
1035+
1036+
class ModelFormat(Enum):
1037+
"""
1038+
The format of the returned model.
1039+
"""
1040+
RSCNN = 'rscnn'
1041+
1042+
9761043
class GetJpegImageEnums(object):
9771044

9781045
class Size(Enum):
@@ -1732,7 +1799,7 @@ class ImageDetails():
17321799
(UTC) that the image was created.
17331800
:attr ImageSource source: The source type of the image.
17341801
:attr ImageDimensions dimensions: (optional) Height and width of an image.
1735-
:attr List[Error] errors: (optional)
1802+
:attr List[Error] errors: (optional) Details about the errors.
17361803
:attr TrainingDataObjects training_data: (optional) Training data for all
17371804
objects.
17381805
"""
@@ -1756,7 +1823,7 @@ def __init__(self,
17561823
:param datetime created: (optional) Date and time in Coordinated Universal
17571824
Time (UTC) that the image was created.
17581825
:param ImageDimensions dimensions: (optional) Height and width of an image.
1759-
:param List[Error] errors: (optional)
1826+
:param List[Error] errors: (optional) Details about the errors.
17601827
:param TrainingDataObjects training_data: (optional) Training data for all
17611828
objects.
17621829
"""
@@ -2592,13 +2659,16 @@ class ObjectTrainingStatus():
25922659
:attr bool data_changed: Whether there are changes to the training data since
25932660
the most recent training.
25942661
:attr bool latest_failed: Whether the most recent training failed.
2662+
:attr bool rscnn_ready: Whether the model can be downloaded after the training
2663+
status is `ready`.
25952664
:attr str description: Details about the training. If training is in progress,
25962665
includes information about the status. If training is not in progress, includes
25972666
a success message or information about why training failed.
25982667
"""
25992668

26002669
def __init__(self, ready: bool, in_progress: bool, data_changed: bool,
2601-
latest_failed: bool, description: str) -> None:
2670+
latest_failed: bool, rscnn_ready: bool,
2671+
description: str) -> None:
26022672
"""
26032673
Initialize a ObjectTrainingStatus object.
26042674
@@ -2608,6 +2678,8 @@ def __init__(self, ready: bool, in_progress: bool, data_changed: bool,
26082678
:param bool data_changed: Whether there are changes to the training data
26092679
since the most recent training.
26102680
:param bool latest_failed: Whether the most recent training failed.
2681+
:param bool rscnn_ready: Whether the model can be downloaded after the
2682+
training status is `ready`.
26112683
:param str description: Details about the training. If training is in
26122684
progress, includes information about the status. If training is not in
26132685
progress, includes a success message or information about why training
@@ -2617,6 +2689,7 @@ def __init__(self, ready: bool, in_progress: bool, data_changed: bool,
26172689
self.in_progress = in_progress
26182690
self.data_changed = data_changed
26192691
self.latest_failed = latest_failed
2692+
self.rscnn_ready = rscnn_ready
26202693
self.description = description
26212694

26222695
@classmethod
@@ -2625,7 +2698,7 @@ def from_dict(cls, _dict: Dict) -> 'ObjectTrainingStatus':
26252698
args = {}
26262699
valid_keys = [
26272700
'ready', 'in_progress', 'data_changed', 'latest_failed',
2628-
'description'
2701+
'rscnn_ready', 'description'
26292702
]
26302703
bad_keys = set(_dict.keys()) - set(valid_keys)
26312704
if bad_keys:
@@ -2656,6 +2729,12 @@ def from_dict(cls, _dict: Dict) -> 'ObjectTrainingStatus':
26562729
raise ValueError(
26572730
'Required property \'latest_failed\' not present in ObjectTrainingStatus JSON'
26582731
)
2732+
if 'rscnn_ready' in _dict:
2733+
args['rscnn_ready'] = _dict.get('rscnn_ready')
2734+
else:
2735+
raise ValueError(
2736+
'Required property \'rscnn_ready\' not present in ObjectTrainingStatus JSON'
2737+
)
26592738
if 'description' in _dict:
26602739
args['description'] = _dict.get('description')
26612740
else:
@@ -2680,6 +2759,8 @@ def to_dict(self) -> Dict:
26802759
_dict['data_changed'] = self.data_changed
26812760
if hasattr(self, 'latest_failed') and self.latest_failed is not None:
26822761
_dict['latest_failed'] = self.latest_failed
2762+
if hasattr(self, 'rscnn_ready') and self.rscnn_ready is not None:
2763+
_dict['rscnn_ready'] = self.rscnn_ready
26832764
if hasattr(self, 'description') and self.description is not None:
26842765
_dict['description'] = self.description
26852766
return _dict

test/unit/test_visual_recognition_v4.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,80 @@ def construct_required_body(self):
462462
return body
463463

464464

465+
#-----------------------------------------------------------------------------
466+
# Test Class for get_model_file
467+
#-----------------------------------------------------------------------------
468+
class TestGetModelFile():
469+
470+
#--------------------------------------------------------
471+
# Test 1: Send fake data and check response
472+
#--------------------------------------------------------
473+
@responses.activate
474+
def test_get_model_file_response(self):
475+
body = self.construct_full_body()
476+
response = fake_response_BinaryIO_json
477+
send_request(self, body, response)
478+
assert len(responses.calls) == 1
479+
480+
#--------------------------------------------------------
481+
# Test 2: Send only required fake data and check response
482+
#--------------------------------------------------------
483+
@responses.activate
484+
def test_get_model_file_required_response(self):
485+
# Check response with required params
486+
body = self.construct_required_body()
487+
response = fake_response_BinaryIO_json
488+
send_request(self, body, response)
489+
assert len(responses.calls) == 1
490+
491+
#--------------------------------------------------------
492+
# Test 3: Send empty data and check response
493+
#--------------------------------------------------------
494+
@responses.activate
495+
def test_get_model_file_empty(self):
496+
check_empty_required_params(self, fake_response_BinaryIO_json)
497+
check_missing_required_params(self)
498+
assert len(responses.calls) == 0
499+
500+
#-----------
501+
#- Helpers -
502+
#-----------
503+
def make_url(self, body):
504+
endpoint = '/v4/collections/{0}/model'.format(body['collection_id'])
505+
url = '{0}{1}'.format(base_url, endpoint)
506+
return url
507+
508+
def add_mock_response(self, url, response):
509+
responses.add(responses.GET,
510+
url,
511+
body=json.dumps(response),
512+
status=200,
513+
content_type='')
514+
515+
def call_service(self, body):
516+
service = VisualRecognitionV4(
517+
authenticator=NoAuthAuthenticator(),
518+
version='2019-02-11',
519+
)
520+
service.set_service_url(base_url)
521+
output = service.get_model_file(**body)
522+
return output
523+
524+
def construct_full_body(self):
525+
body = dict()
526+
body['collection_id'] = "string1"
527+
body['feature'] = "string1"
528+
body['model_format'] = "string1"
529+
return body
530+
531+
def construct_required_body(self):
532+
body = dict()
533+
body['collection_id'] = "string1"
534+
body['feature'] = "string1"
535+
body['model_format'] = "string1"
536+
return body
537+
538+
465539
# endregion
466540
##############################################################################
467541
# End of Service: Collections
@@ -1504,17 +1578,18 @@ def send_request(obj, body, response, url=None):
15041578

15051579
fake_response__json = None
15061580
fake_response_AnalyzeResponse_json = """{"images": [], "warnings": [], "trace": "fake_trace"}"""
1507-
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "description": "fake_description"}}}"""
1581+
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "rscnn_ready": false, "description": "fake_description"}}}"""
15081582
fake_response_CollectionsList_json = """{"collections": []}"""
1509-
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "description": "fake_description"}}}"""
1510-
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "description": "fake_description"}}}"""
1583+
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "rscnn_ready": false, "description": "fake_description"}}}"""
1584+
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "rscnn_ready": false, "description": "fake_description"}}}"""
1585+
fake_response_BinaryIO_json = """Contents of response byte-stream..."""
15111586
fake_response_ImageDetailsList_json = """{"images": [], "warnings": [], "trace": "fake_trace"}"""
15121587
fake_response_ImageSummaryList_json = """{"images": []}"""
15131588
fake_response_ImageDetails_json = """{"image_id": "fake_image_id", "updated": "2017-05-16T13:56:54.957Z", "created": "2017-05-16T13:56:54.957Z", "source": {"type": "fake_type", "filename": "fake_filename", "archive_filename": "fake_archive_filename", "source_url": "fake_source_url", "resolved_url": "fake_resolved_url"}, "dimensions": {"height": 6, "width": 5}, "errors": [], "training_data": {"objects": []}}"""
15141589
fake_response_BinaryIO_json = """Contents of response byte-stream..."""
15151590
fake_response_ObjectMetadataList_json = """{"object_count": 12, "objects": []}"""
15161591
fake_response_UpdateObjectMetadata_json = """{"object": "fake_object", "count": 5}"""
15171592
fake_response_ObjectMetadata_json = """{"object": "fake_object", "count": 5}"""
1518-
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "description": "fake_description"}}}"""
1593+
fake_response_Collection_json = """{"collection_id": "fake_collection_id", "name": "fake_name", "description": "fake_description", "created": "2017-05-16T13:56:54.957Z", "updated": "2017-05-16T13:56:54.957Z", "image_count": 11, "training_status": {"objects": {"ready": false, "in_progress": false, "data_changed": true, "latest_failed": false, "rscnn_ready": false, "description": "fake_description"}}}"""
15191594
fake_response_TrainingDataObjects_json = """{"objects": []}"""
15201595
fake_response_TrainingEvents_json = """{"start_time": "2017-05-16T13:56:54.957Z", "end_time": "2017-05-16T13:56:54.957Z", "completed_events": 16, "trained_images": 14, "events": []}"""

0 commit comments

Comments
 (0)