@@ -295,7 +295,7 @@ class Model:
295
295
"""
296
296
297
297
def __init__ (
298
- self , predict_fun , dataframe_input = False , output_names = None , arrow = False
298
+ self , predict_fun , dataframe_input = False , output_names = None , disable_arrow = False
299
299
):
300
300
"""
301
301
Wrap the model as a TrustyAI :obj:`PredictionProvider` Java class.
@@ -311,39 +311,75 @@ def __init__(
311
311
output_names : List[String]:
312
312
If the model outputs a numpy array, you can specify the names of the model outputs
313
313
here.
314
- arrow : bool
315
- Whether to use Apache arrow to speed up data transfer between Java and Python.
316
- In general, set this to ``true`` whenever LIME or SHAP explanations are needed,
317
- and ``false`` for counterfactuals .
314
+ disable_arrow : bool
315
+ If true, Apache Arrow will not be used to accelerate data transfer between Java
316
+ and Python. If false, Arrow will be automatically used in situations where it is
317
+ advantageous to do so .
318
318
"""
319
- self .arrow = arrow
319
+ self .disable_arrow = disable_arrow
320
320
self .predict_fun = predict_fun
321
321
self .output_names = output_names
322
+ self .dataframe_input = dataframe_input
322
323
323
- if arrow :
324
- self .prediction_provider = None
325
- if not dataframe_input :
326
- self .prediction_provider_arrow = PredictionProviderArrow (
327
- lambda x : self ._cast_outputs_to_dataframe (predict_fun (x .values ))
328
- )
329
- else :
330
- self .prediction_provider_arrow = PredictionProviderArrow (
331
- lambda x : self ._cast_outputs_to_dataframe (predict_fun (x ))
324
+ self .prediction_provider_arrow = None
325
+ self .prediction_provider_normal = None
326
+ self .prediction_provider = None
327
+
328
+ # set model to use non-arrow by default, as this requires no dataset information
329
+ self ._set_nonarrow ()
330
+
331
+ def _set_arrow (self , paradigm_input : PredictionInput ):
332
+ """
333
+ Ready the model for arrow-based prediction communication.
334
+
335
+ Parameters
336
+ ----------
337
+ paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema.
338
+ All subsequent :obj:`PredictionInput`s communicated must have this schema.
339
+ """
340
+ if self .disable_arrow :
341
+ self ._set_nonarrow ()
342
+ else :
343
+ if self .prediction_provider_arrow is None :
344
+ raw_ppa = self ._get_arrow_prediction_provider ()
345
+ self .prediction_provider_arrow = raw_ppa .get_as_prediction_provider (
346
+ paradigm_input
332
347
)
348
+ self .prediction_provider = self .prediction_provider_arrow
349
+
350
+ def _set_nonarrow (self ):
351
+ """
352
+ Ready the model for non-arrow-prediction communication.
353
+ """
354
+ if self .prediction_provider_normal is None :
355
+ self .prediction_provider_normal = self ._get_nonarrow_prediction_provider ()
356
+ self .prediction_provider = self .prediction_provider_normal
357
+
358
+ def _get_arrow_prediction_provider (self ):
359
+ if not self .dataframe_input :
360
+ ppa = PredictionProviderArrow (
361
+ lambda x : self ._cast_outputs_to_dataframe (self .predict_fun (x .values ))
362
+ )
333
363
else :
334
- self .prediction_provider_arrow = None
335
- if dataframe_input :
336
- self .prediction_provider = PredictionProvider (
337
- lambda x : self ._cast_outputs (
338
- predict_fun (prediction_object_to_pandas (x ))
339
- )
364
+ ppa = PredictionProviderArrow (
365
+ lambda x : self ._cast_outputs_to_dataframe (self .predict_fun (x ))
366
+ )
367
+ return ppa
368
+
369
+ def _get_nonarrow_prediction_provider (self ):
370
+ if self .dataframe_input :
371
+ ppn = PredictionProvider (
372
+ lambda x : self ._cast_outputs (
373
+ self .predict_fun (prediction_object_to_pandas (x ))
340
374
)
341
- else :
342
- self . prediction_provider = PredictionProvider (
343
- lambda x : self . _cast_outputs (
344
- predict_fun ( prediction_object_to_numpy ( x ))
345
- )
375
+ )
376
+ else :
377
+ ppn = PredictionProvider (
378
+ lambda x : self . _cast_outputs (
379
+ self . predict_fun ( prediction_object_to_numpy ( x ) )
346
380
)
381
+ )
382
+ return ppn
347
383
348
384
def _cast_outputs (self , output_array ):
349
385
return df_to_prediction_object (
@@ -388,12 +424,8 @@ def predictAsync(self, inputs: List[PredictionInput]) -> CompletableFuture:
388
424
:obj:`CompletableFuture`
389
425
A Java :obj:`CompletableFuture` containing the model outputs.
390
426
"""
391
- if self .arrow and self .prediction_provider is None :
392
- self .prediction_provider = (
393
- self .prediction_provider_arrow .get_as_prediction_provider (inputs [0 ])
394
- )
395
- out = self .prediction_provider .predictAsync (inputs )
396
- return out
427
+
428
+ return self .prediction_provider .predictAsync (inputs )
397
429
398
430
def __call__ (self , inputs ):
399
431
"""
@@ -405,6 +437,51 @@ def __call__(self, inputs):
405
437
"""
406
438
return self .predict_fun (inputs )
407
439
440
+ class ArrowTransmission :
441
+ """
442
+ Context class to ensure all predictAsync calls within the context use arrow.
443
+
444
+ Parameters
445
+ ----------
446
+ model: The TrustyAI :obj:`Model` or PredictionProvider
447
+ paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema.
448
+ All subsequent :obj:`PredictionInput`s communicated must have this schema.
449
+ """
450
+
451
+ def __init__ (self , model , paradigm_input : OneInputUnionType ):
452
+ self .model = model
453
+ self .model_is_python = isinstance (model , Model )
454
+ self .paradigm_input = one_input_convert (paradigm_input )
455
+ self .previous_model_state = None
456
+
457
+ def __enter__ (self ):
458
+ if self .model_is_python :
459
+ self .previous_model_state = self .model .prediction_provider
460
+ self .model ._set_arrow (self .paradigm_input )
461
+
462
+ def __exit__ (self , exit_type , value , traceback ):
463
+ if self .model_is_python :
464
+ self .model .prediction_provider = self .previous_model_state
465
+
466
+ class NonArrowTransmission :
467
+ """
468
+ Context class to ensure all predictAsync calls within the context DO NOT use arrow.
469
+ """
470
+
471
+ def __init__ (self , model ):
472
+ self .model = model
473
+ self .model_is_python = isinstance (model , Model )
474
+ self .previous_model_state = None
475
+
476
+ def __enter__ (self ):
477
+ if self .model_is_python :
478
+ self .previous_model_state = self .model .prediction_provider
479
+ self .model ._set_nonarrow ()
480
+
481
+ def __exit__ (self , exit_type , value , traceback ):
482
+ if self .model_is_python :
483
+ self .model .prediction_provider = self .previous_model_state
484
+
408
485
409
486
@_jcustomizer .JImplementationFor ("org.kie.trustyai.explainability.model.Output" )
410
487
# pylint: disable=no-member
0 commit comments