Skip to content

Commit 7cdaf67

Browse files
authored
FAI-917: Add error message capturing within Python models (#137)
* Added error catching wrapper to model * linting and black * fixed typo within predict_fun message * improved error log formatting * lazy logging
1 parent 65162d2 commit 7cdaf67

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ build-backend = "setuptools.build_meta"
6464
package-dir = { "" = "src" }
6565

6666
[tool.pytest.ini_options]
67+
log_cli = true
6768
addopts = '-m="not block_plots"'
6869
markers = [
6970
"block_plots: Test plots will block execution of subsequent tests until closed"

src/trustyai/model/__init__.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# pylint: disable = unused-import, wrong-import-order
33
# pylint: disable = consider-using-f-string
44
"""General model classes"""
5+
import logging
6+
import traceback
57
import uuid as _uuid
68
from typing import List, Optional, Union, Callable
79
import pandas as pd
@@ -322,7 +324,8 @@ def __init__(self, predict_fun, **kwargs):
322324
transfer between Java and Python. If false, Arrow will be automatically used in
323325
situations where it is advantageous to do so.
324326
"""
325-
self.predict_fun = predict_fun
327+
328+
self.predict_fun = self._error_catcher(predict_fun)
326329
self.kwargs = kwargs
327330

328331
self.prediction_provider_arrow = None
@@ -332,6 +335,24 @@ def __init__(self, predict_fun, **kwargs):
332335
# set model to use non-arrow by default, as this requires no dataset information
333336
self._set_nonarrow()
334337

338+
def _error_catcher(self, predict_fun):
339+
"""Wrapper for predict function to capture errors to Python logger before the JVM dies"""
340+
341+
def wrapper(x):
342+
try:
343+
return predict_fun(x)
344+
except Exception as e:
345+
logging.error(
346+
" Fatal runtime error within the `predict_fun` supplied to trustyai.Model"
347+
)
348+
logging.error(
349+
" The error message has been captured and reproduced below:"
350+
)
351+
logging.error(" %s", traceback.format_exc())
352+
raise e
353+
354+
return wrapper
355+
335356
@property
336357
def dataframe_input(self):
337358
"""Get dataframe_input kwarg value"""
@@ -483,7 +504,7 @@ def __enter__(self):
483504
self.previous_model_state = self.model.prediction_provider
484505
self.model._set_arrow(self.paradigm_input)
485506

486-
def __exit__(self, exit_type, value, traceback):
507+
def __exit__(self, exit_type, value, tb):
487508
if self.model_is_python:
488509
self.model.prediction_provider = self.previous_model_state
489510

@@ -502,7 +523,7 @@ def __enter__(self):
502523
self.previous_model_state = self.model.prediction_provider
503524
self.model._set_nonarrow()
504525

505-
def __exit__(self, exit_type, value, traceback):
526+
def __exit__(self, exit_type, value, tb):
506527
if self.model_is_python:
507528
self.model.prediction_provider = self.previous_model_state
508529

tests/general/test_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name
22
"""Test model provider interface"""
3+
from trustyai.explainers import LimeExplainer
34

45
from common import *
56
from trustyai.model import Model, Dataset, feature
@@ -46,3 +47,14 @@ def test_cast_output_arrow():
4647
output_val = m.predictAsync(pis).get()
4748
assert len(output_val) == 25
4849

50+
51+
def test_error_model(caplog):
52+
"""test that a broken model spits out useful debugging info"""
53+
m = Model(lambda x: str(x) - str(x))
54+
try:
55+
LimeExplainer().explain(0, 0, m)
56+
except Exception:
57+
pass
58+
59+
assert "Fatal runtime error" in caplog.text
60+
assert "TypeError: unsupported operand type(s) for -: 'str' and 'str'" in caplog.text

0 commit comments

Comments
 (0)