Skip to content

Commit a019d1a

Browse files
authored
Add PNG support and error handling to image read (#19)
* Add PNG support and error handling to image read * Placate flake8 * Placate flake8 take 2 * Placate flake8 take 3 * Placate flake8 take 4 * Add Pillow dep for travis tests * Uncomment image conversion code
1 parent 87bbd31 commit a019d1a

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ install:
1111
- sleep 45
1212

1313
before_script:
14-
- pip install pytest requests flake8
14+
- pip install pytest requests flake8 Pillow
1515
- flake8 . --max-line-length=127
1616

1717
script:

core/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io
77
import numpy as np
88
import logging
9+
from flask import abort
910
from config import DEFAULT_MODEL_PATH, MODEL_INPUT_IMG_SIZE, MODEL_META_DATA as model_meta
1011
from maxfw.model import MAXModelWrapper
1112

@@ -26,8 +27,12 @@ def __init__(self, path=DEFAULT_MODEL_PATH):
2627
logger.info('Loaded model: {}'.format(self.model.name))
2728

2829
def _read_image(self, image_data):
29-
image = Image.open(io.BytesIO(image_data))
30-
return image
30+
try:
31+
image = Image.open(io.BytesIO(image_data)).convert('RGB')
32+
return image
33+
except IOError as e:
34+
logger.error(str(e))
35+
abort(400, "The provided input is not a valid image (PNG or JPG required).")
3136

3237
def _pre_process(self, image, target, mode='tf'):
3338
image = image.resize(target)

tests/test.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22
import requests
3+
from PIL import Image
4+
import tempfile
35

46

57
def test_swagger():
@@ -29,20 +31,47 @@ def test_metadata():
2931
assert metadata['license'] == 'Apache v2'
3032

3133

34+
def _check_predict(r):
35+
36+
assert r.status_code == 200
37+
response = r.json()
38+
assert response['status'] == 'ok'
39+
assert response['predictions'][0]['label_id'] == 'n02123045'
40+
assert response['predictions'][0]['label'] == 'tabby'
41+
assert response['predictions'][0]['probability'] > 0.6
42+
43+
3244
def test_predict():
45+
46+
formats = ['JPEG', 'PNG']
3347
model_endpoint = 'http://localhost:5000/model/predict'
3448
file_path = 'assets/cat.jpg'
49+
jpg = Image.open(file_path)
50+
51+
for f in formats:
52+
temp = tempfile.TemporaryFile()
53+
if f == 'PNG':
54+
jpg.convert('RGBA').save(temp, f)
55+
else:
56+
jpg.save(temp, f)
57+
temp.seek(0)
58+
file_form = {'image': (file_path, temp, 'image/{}'.format(f.lower()))}
59+
r = requests.post(url=model_endpoint, files=file_form)
60+
_check_predict(r)
61+
62+
63+
def test_invalid_input():
64+
65+
model_endpoint = 'http://localhost:5000/model/predict'
66+
file_path = 'assets/README.md'
3567

3668
with open(file_path, 'rb') as file:
3769
file_form = {'image': (file_path, file, 'image/jpeg')}
3870
r = requests.post(url=model_endpoint, files=file_form)
3971

40-
assert r.status_code == 200
72+
assert r.status_code == 400
4173
response = r.json()
42-
assert response['status'] == 'ok'
43-
assert response['predictions'][0]['label_id'] == 'n02123045'
44-
assert response['predictions'][0]['label'] == 'tabby'
45-
assert response['predictions'][0]['probability'] > 0.6
74+
assert 'input is not a valid image' in response['message']
4675

4776

4877
if __name__ == '__main__':

0 commit comments

Comments
 (0)