@@ -6,8 +6,11 @@ use numpy::{IntoPyArray, PyArray2};
6
6
use once_cell:: sync:: Lazy ; // Import Lazy
7
7
use pyo3:: exceptions:: PyValueError ;
8
8
use pyo3:: prelude:: * ;
9
+ use pyo3:: types:: PyList ;
10
+ use pythonize:: { depythonize, pythonize} ;
9
11
use reqwest:: Client ;
10
12
use serde:: { Deserialize , Serialize } ;
13
+ use serde_json:: Value as JsonValue ; // For handling untyped JSON
11
14
use std:: sync:: atomic:: { AtomicBool , Ordering } ; // Add this
12
15
use std:: sync:: Arc ;
13
16
use std:: time:: Duration ;
@@ -513,6 +516,87 @@ impl PerformanceClient {
513
516
514
517
Python :: with_gil ( |py| Ok ( result_from_async_task?. into_py ( py) ) )
515
518
}
519
+
520
+ #[ pyo3( signature = ( url_path, payloads, max_concurrent_requests = DEFAULT_CONCURRENCY , timeout_s = DEFAULT_REQUEST_TIMEOUT_S ) ) ]
521
+ fn batch_post (
522
+ & self ,
523
+ py : Python ,
524
+ url_path : String ,
525
+ payloads : Vec < PyObject > ,
526
+ max_concurrent_requests : usize ,
527
+ timeout_s : f64 ,
528
+ ) -> PyResult < PyObject > {
529
+ if payloads. is_empty ( ) {
530
+ return Err ( PyValueError :: new_err ( "Payloads list cannot be empty" ) ) ;
531
+ }
532
+ PerformanceClient :: validate_concurrency_parameters ( max_concurrent_requests, 1 ) ?; // Batch size is effectively 1
533
+ let timeout_duration = PerformanceClient :: validate_and_get_timeout_duration ( timeout_s) ?;
534
+
535
+ // Depythonize all payloads in the current thread (GIL is held)
536
+ let mut payloads_json: Vec < JsonValue > = Vec :: with_capacity ( payloads. len ( ) ) ;
537
+ for ( idx, py_obj) in payloads. into_iter ( ) . enumerate ( ) {
538
+ // Bind PyObject to current GIL lifetime to get a Bound object for depythonize
539
+ let bound_obj = py_obj. bind ( py) ;
540
+ let json_val = depythonize ( bound_obj) . map_err ( |e| {
541
+ PyValueError :: new_err ( format ! (
542
+ "Failed to depythonize payload at index {}: {}" ,
543
+ idx, e
544
+ ) )
545
+ } ) ?;
546
+ payloads_json. push ( json_val) ;
547
+ }
548
+
549
+ let client = self . client . clone ( ) ;
550
+ let api_key = self . api_key . clone ( ) ;
551
+ let api_base = self . api_base . clone ( ) ;
552
+ let rt = Arc :: clone ( & self . runtime ) ;
553
+
554
+ // The async task now receives Vec<JsonValue> and returns Result<Vec<JsonValue>, PyErr>
555
+ let result_from_async_task: Result < Vec < JsonValue > , PyErr > = py. allow_threads ( move || {
556
+ let ( tx, rx) = std:: sync:: mpsc:: channel :: < Result < Vec < JsonValue > , PyErr > > ( ) ;
557
+ rt. spawn ( async move {
558
+ let res = process_batch_post_requests (
559
+ client,
560
+ url_path,
561
+ payloads_json, // Pass depythonized JSON values
562
+ api_key,
563
+ api_base,
564
+ max_concurrent_requests,
565
+ timeout_duration,
566
+ )
567
+ . await ;
568
+ let _ = tx. send ( res) ;
569
+ } ) ;
570
+ rx. recv ( )
571
+ . map_err ( |e| {
572
+ PyValueError :: new_err ( format ! (
573
+ "Failed to receive result from async task (channel error): {}" ,
574
+ e
575
+ ) )
576
+ } )
577
+ . and_then ( |inner_result| inner_result)
578
+ } ) ;
579
+
580
+ let response_json_values = result_from_async_task?;
581
+
582
+ // Pythonize all results in the current thread (GIL is held)
583
+ let mut results_py: Vec < PyObject > = Vec :: with_capacity ( response_json_values. len ( ) ) ;
584
+ for ( idx, json_val) in response_json_values. into_iter ( ) . enumerate ( ) {
585
+ let py_obj_bound = pythonize ( py, & json_val) . map_err ( |e| {
586
+ PyValueError :: new_err ( format ! (
587
+ "Failed to pythonize response at index {}: {}" ,
588
+ idx, e
589
+ ) )
590
+ } ) ?;
591
+ // Convert Bound<'_, PyAny> to PyObject
592
+ results_py. push ( py_obj_bound. to_object ( py) ) ;
593
+ }
594
+
595
+ // Use the updated PyList::new_bound or PyList::new as per PyO3 v0.21+
596
+ // PyList::new_bound is suitable here for an iterable of PyObjects.
597
+ let py_object_list = PyList :: new_bound ( py, & results_py) ;
598
+ Ok ( py_object_list. into ( ) )
599
+ }
516
600
}
517
601
518
602
// --- Send Single Embedding Request ---
@@ -894,6 +978,121 @@ async fn process_classify_requests(
894
978
} )
895
979
}
896
980
981
+ // --- Send Single Batch Post Request ---
982
+ // Now takes JsonValue and returns JsonValue
983
+ async fn send_single_batch_post_request (
984
+ client : Client ,
985
+ full_url : String ,
986
+ payload_json : JsonValue ,
987
+ api_key : String ,
988
+ request_timeout : Duration ,
989
+ ) -> Result < JsonValue , PyErr > {
990
+ // No depythonize here
991
+
992
+ let response = client
993
+ . post ( & full_url)
994
+ . bearer_auth ( api_key)
995
+ . json ( & payload_json)
996
+ . timeout ( request_timeout)
997
+ . send ( )
998
+ . await
999
+ . map_err ( |e| PyValueError :: new_err ( format ! ( "Request failed: {}" , e) ) ) ?;
1000
+
1001
+ let successful_response = ensure_successful_response ( response) . await ?;
1002
+
1003
+ // Get response as serde_json::Value
1004
+ let response_json_value: JsonValue = successful_response
1005
+ . json :: < JsonValue > ( )
1006
+ . await
1007
+ . map_err ( |e| PyValueError :: new_err ( format ! ( "Failed to parse response JSON: {}" , e) ) ) ?;
1008
+
1009
+ // No pythonize here, return JsonValue
1010
+ Ok ( response_json_value)
1011
+ }
1012
+
1013
+ // --- Process Batch Post Requests ---
1014
+ // Now takes Vec<JsonValue> and returns Result<Vec<JsonValue>, PyErr>
1015
+ async fn process_batch_post_requests (
1016
+ client : Client ,
1017
+ url_path : String ,
1018
+ payloads_json : Vec < JsonValue > , // Takes Vec<JsonValue>
1019
+ api_key : String ,
1020
+ api_base : String ,
1021
+ max_concurrent_requests : usize ,
1022
+ request_timeout_duration : Duration ,
1023
+ ) -> Result < Vec < JsonValue > , PyErr > {
1024
+ // Returns Vec<JsonValue>
1025
+ let semaphore = Arc :: new ( Semaphore :: new ( max_concurrent_requests) ) ;
1026
+ let mut tasks = Vec :: new ( ) ;
1027
+ let cancel_token = Arc :: new ( AtomicBool :: new ( false ) ) ;
1028
+ let total_payloads = payloads_json. len ( ) ;
1029
+
1030
+ for ( index, payload_item_json) in payloads_json. into_iter ( ) . enumerate ( ) {
1031
+ // Iterate over JsonValue
1032
+ let client_clone = client. clone ( ) ;
1033
+ let api_key_clone = api_key. clone ( ) ;
1034
+ let api_base_clone = api_base. clone ( ) ;
1035
+ let url_path_clone = url_path. clone ( ) ;
1036
+ let semaphore_clone = Arc :: clone ( & semaphore) ;
1037
+ let cancel_token_clone = Arc :: clone ( & cancel_token) ;
1038
+ let individual_request_timeout = request_timeout_duration;
1039
+
1040
+ // payload_item_json is moved into its own task
1041
+ tasks. push ( tokio:: spawn ( async move {
1042
+ let permit_guard =
1043
+ acquire_permit_or_cancel ( semaphore_clone, cancel_token_clone. clone ( ) ) . await ?;
1044
+
1045
+ let full_url = format ! (
1046
+ "{}/{}" ,
1047
+ api_base_clone. trim_end_matches( '/' ) ,
1048
+ url_path_clone. trim_start_matches( '/' )
1049
+ ) ;
1050
+
1051
+ let result = send_single_batch_post_request (
1052
+ client_clone,
1053
+ full_url,
1054
+ payload_item_json, // Pass JsonValue
1055
+ api_key_clone,
1056
+ individual_request_timeout,
1057
+ )
1058
+ . await ;
1059
+
1060
+ drop ( permit_guard) ;
1061
+
1062
+ match result {
1063
+ Ok ( response_json_value) => Ok ( ( index, response_json_value) ) , // Return with original index and JsonValue
1064
+ Err ( e) => {
1065
+ cancel_token_clone. store ( true , Ordering :: SeqCst ) ;
1066
+ Err ( e)
1067
+ }
1068
+ }
1069
+ } ) ) ;
1070
+ }
1071
+
1072
+ let task_join_results = join_all ( tasks) . await ;
1073
+ let mut indexed_results: Vec < ( usize , JsonValue ) > = Vec :: with_capacity ( total_payloads) ; // Stores JsonValue
1074
+ let mut first_error: Option < PyErr > = None ;
1075
+
1076
+ for result in task_join_results {
1077
+ // D is (usize, JsonValue)
1078
+ if let Some ( indexed_data_part) =
1079
+ process_task_outcome ( result, & mut first_error, & cancel_token)
1080
+ {
1081
+ indexed_results. push ( indexed_data_part) ;
1082
+ }
1083
+ }
1084
+
1085
+ if let Some ( err) = first_error {
1086
+ return Err ( err) ;
1087
+ }
1088
+
1089
+ indexed_results. sort_by_key ( |& ( original_index, _) | original_index) ;
1090
+
1091
+ let final_results: Vec < JsonValue > = indexed_results. into_iter ( ) . map ( |( _, val) | val) . collect ( ) ; // Collect JsonValue
1092
+
1093
+ Ok ( final_results)
1094
+ }
1095
+
897
1096
// Helper function to process task results and manage errors
898
1097
fn process_task_outcome < D > (
899
1098
task_join_result : Result < Result < D , PyErr > , JoinError > , // Removed OwnedSemaphorePermit from here
0 commit comments