2
2
# pylint: disable = unused-import, wrong-import-order
3
3
# pylint: disable = consider-using-f-string
4
4
"""General model classes"""
5
+ import logging
6
+ import traceback
5
7
import uuid as _uuid
6
8
from typing import List , Optional , Union , Callable
7
9
import pandas as pd
@@ -322,7 +324,8 @@ def __init__(self, predict_fun, **kwargs):
322
324
transfer between Java and Python. If false, Arrow will be automatically used in
323
325
situations where it is advantageous to do so.
324
326
"""
325
- self .predict_fun = predict_fun
327
+
328
+ self .predict_fun = self ._error_catcher (predict_fun )
326
329
self .kwargs = kwargs
327
330
328
331
self .prediction_provider_arrow = None
@@ -332,6 +335,24 @@ def __init__(self, predict_fun, **kwargs):
332
335
# set model to use non-arrow by default, as this requires no dataset information
333
336
self ._set_nonarrow ()
334
337
338
+ def _error_catcher (self , predict_fun ):
339
+ """Wrapper for predict function to capture errors to Python logger before the JVM dies"""
340
+
341
+ def wrapper (x ):
342
+ try :
343
+ return predict_fun (x )
344
+ except Exception as e :
345
+ logging .error (
346
+ " Fatal runtime error within the `predict_fun` supplied to trustyai.Model"
347
+ )
348
+ logging .error (
349
+ " The error message has been captured and reproduced below:"
350
+ )
351
+ logging .error (" %s" , traceback .format_exc ())
352
+ raise e
353
+
354
+ return wrapper
355
+
335
356
@property
336
357
def dataframe_input (self ):
337
358
"""Get dataframe_input kwarg value"""
@@ -483,7 +504,7 @@ def __enter__(self):
483
504
self .previous_model_state = self .model .prediction_provider
484
505
self .model ._set_arrow (self .paradigm_input )
485
506
486
- def __exit__ (self , exit_type , value , traceback ):
507
+ def __exit__ (self , exit_type , value , tb ):
487
508
if self .model_is_python :
488
509
self .model .prediction_provider = self .previous_model_state
489
510
@@ -502,7 +523,7 @@ def __enter__(self):
502
523
self .previous_model_state = self .model .prediction_provider
503
524
self .model ._set_nonarrow ()
504
525
505
- def __exit__ (self , exit_type , value , traceback ):
526
+ def __exit__ (self , exit_type , value , tb ):
506
527
if self .model_is_python :
507
528
self .model .prediction_provider = self .previous_model_state
508
529
0 commit comments