|
1 | 1 | import pytest
|
2 | 2 | import requests
|
| 3 | +from PIL import Image |
| 4 | +import tempfile |
3 | 5 |
|
4 | 6 |
|
5 | 7 | def test_swagger():
|
@@ -29,20 +31,47 @@ def test_metadata():
|
29 | 31 | assert metadata['license'] == 'Apache v2'
|
30 | 32 |
|
31 | 33 |
|
| 34 | +def _check_predict(r): |
| 35 | + |
| 36 | + assert r.status_code == 200 |
| 37 | + response = r.json() |
| 38 | + assert response['status'] == 'ok' |
| 39 | + assert response['predictions'][0]['label_id'] == 'n02123045' |
| 40 | + assert response['predictions'][0]['label'] == 'tabby' |
| 41 | + assert response['predictions'][0]['probability'] > 0.6 |
| 42 | + |
| 43 | + |
32 | 44 | def test_predict():
|
| 45 | + |
| 46 | + formats = ['JPEG', 'PNG'] |
33 | 47 | model_endpoint = 'http://localhost:5000/model/predict'
|
34 | 48 | file_path = 'assets/cat.jpg'
|
| 49 | + jpg = Image.open(file_path) |
| 50 | + |
| 51 | + for f in formats: |
| 52 | + temp = tempfile.TemporaryFile() |
| 53 | + if f == 'PNG': |
| 54 | + jpg.convert('RGBA').save(temp, f) |
| 55 | + else: |
| 56 | + jpg.save(temp, f) |
| 57 | + temp.seek(0) |
| 58 | + file_form = {'image': (file_path, temp, 'image/{}'.format(f.lower()))} |
| 59 | + r = requests.post(url=model_endpoint, files=file_form) |
| 60 | + _check_predict(r) |
| 61 | + |
| 62 | + |
| 63 | +def test_invalid_input(): |
| 64 | + |
| 65 | + model_endpoint = 'http://localhost:5000/model/predict' |
| 66 | + file_path = 'assets/README.md' |
35 | 67 |
|
36 | 68 | with open(file_path, 'rb') as file:
|
37 | 69 | file_form = {'image': (file_path, file, 'image/jpeg')}
|
38 | 70 | r = requests.post(url=model_endpoint, files=file_form)
|
39 | 71 |
|
40 |
| - assert r.status_code == 200 |
| 72 | + assert r.status_code == 400 |
41 | 73 | response = r.json()
|
42 |
| - assert response['status'] == 'ok' |
43 |
| - assert response['predictions'][0]['label_id'] == 'n02123045' |
44 |
| - assert response['predictions'][0]['label'] == 'tabby' |
45 |
| - assert response['predictions'][0]['probability'] > 0.6 |
| 74 | + assert 'input is not a valid image' in response['message'] |
46 | 75 |
|
47 | 76 |
|
48 | 77 | if __name__ == '__main__':
|
|
0 commit comments