Skip to content

Commit 9e4c83c

Browse files
authored
misc(bei-client): follow up pythonize (#1670)
* Fix the release pipeline for bei-client
* Rename the workflow YAML
* Add HTTP error handling
* Document the HTTP error exception
* Another bei-client change
* Update truss-transfer
* Add pythonize
* Add batch requests
* Add a test client
* Bump truss-transfer
1 parent 26dfbb5 commit 9e4c83c

File tree

8 files changed

+306
-30
lines changed

8 files changed

+306
-30
lines changed

bei-client/Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bei-client/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ futures = "0.3"
1414
once_cell = "1.21"
1515
numpy = "0.24.0" # Or a version compatible with your PyO3 version
1616
ndarray = "*"
17+
pythonize = "*"
1718

1819
[lib]
1920
name = "bei_client"

bei-client/bei_client.pyi

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,5 +340,49 @@ class PerformanceClient:
340340
"""
341341
...
342342

343+
def batch_post(
    self,
    url_path: builtins.str,
    payloads: builtins.list[typing.Any],
    max_concurrent_requests: builtins.int = 32,  # DEFAULT_CONCURRENCY
    timeout_s: builtins.float = 3600.0,  # DEFAULT_REQUEST_TIMEOUT_S
) -> builtins.list[typing.Any]:
    """
    Sends a list of generic JSON payloads to a specified URL path concurrently.

    Each payload is sent as an individual POST request. The responses are
    returned as a list of Python objects, corresponding to the JSON responses
    from the server.

    Args:
        url_path: The specific API path to post to (e.g., "/v1/custom_endpoint").
        payloads: A list of Python objects that are JSON-serializable.
            Each object will be the body of a POST request.
        max_concurrent_requests: Maximum number of parallel requests.
        timeout_s: Total timeout in seconds for the entire batch operation,
            also used as the timeout for each individual request.

    Returns:
        A list of Python objects, where each object is the deserialized
        JSON response from the server for the corresponding request payload.
        The order of responses matches the order of input payloads.

    Raises:
        ValueError: If the payloads list is empty or parameters are invalid.
        requests.exceptions.HTTPError: If any of the underlying HTTP requests fail.
        # NOTE(review): the Rust implementation wraps request failures itself —
        # confirm the exact exception type raised matches this claim.
        # Note: Other PyO3/Rust errors might be raised for serialization/deserialization issues.

    Example:
        >>> client = PerformanceClient(api_base="https://example.api.baseten.co/sync", api_key="your_key")
        >>> custom_payloads = [
        ...     {"data": "request1_data", "id": 1},
        ...     {"data": "request2_data", "id": 2}
        ... ]
        >>> responses = client.batch_post("/v1/process_item", custom_payloads)
        >>> for resp in responses:
        ...     print(resp)
    """
    ...
386+
343387
__version__: builtins.str
344388
"""The version of the bei_client library."""

bei-client/src/lib.rs

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ use numpy::{IntoPyArray, PyArray2};
66
use once_cell::sync::Lazy; // Import Lazy
77
use pyo3::exceptions::PyValueError;
88
use pyo3::prelude::*;
9+
use pyo3::types::PyList;
10+
use pythonize::{depythonize, pythonize};
911
use reqwest::Client;
1012
use serde::{Deserialize, Serialize};
13+
use serde_json::Value as JsonValue; // For handling untyped JSON
1114
use std::sync::atomic::{AtomicBool, Ordering}; // Add this
1215
use std::sync::Arc;
1316
use std::time::Duration;
@@ -513,6 +516,87 @@ impl PerformanceClient {
513516

514517
Python::with_gil(|py| Ok(result_from_async_task?.into_py(py)))
515518
}
519+
520+
/// Sends each JSON-serializable payload in `payloads` as its own POST request
/// to `url_path`, with at most `max_concurrent_requests` in flight.
///
/// Python<->JSON conversion (`depythonize`/`pythonize`) happens on this thread
/// while the GIL is held; the network work runs on the shared Tokio runtime
/// with the GIL released (`py.allow_threads`). Responses come back as a
/// Python list in the same order as the input payloads.
///
/// Errors: `PyValueError` for an empty payload list, invalid parameters, or
/// (de)serialization failures; request-level errors are propagated from
/// `process_batch_post_requests`.
#[pyo3(signature = (url_path, payloads, max_concurrent_requests = DEFAULT_CONCURRENCY, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
fn batch_post(
    &self,
    py: Python,
    url_path: String,
    payloads: Vec<PyObject>,
    max_concurrent_requests: usize,
    timeout_s: f64,
) -> PyResult<PyObject> {
    if payloads.is_empty() {
        return Err(PyValueError::new_err("Payloads list cannot be empty"));
    }
    PerformanceClient::validate_concurrency_parameters(max_concurrent_requests, 1)?; // Batch size is effectively 1
    let timeout_duration = PerformanceClient::validate_and_get_timeout_duration(timeout_s)?;

    // Depythonize all payloads in the current thread (GIL is held).
    // Doing this up front means the async task only ever touches plain
    // serde_json values and never needs the GIL.
    let mut payloads_json: Vec<JsonValue> = Vec::with_capacity(payloads.len());
    for (idx, py_obj) in payloads.into_iter().enumerate() {
        // Bind PyObject to current GIL lifetime to get a Bound object for depythonize
        let bound_obj = py_obj.bind(py);
        let json_val = depythonize(bound_obj).map_err(|e| {
            PyValueError::new_err(format!(
                "Failed to depythonize payload at index {}: {}",
                idx, e
            ))
        })?;
        payloads_json.push(json_val);
    }

    let client = self.client.clone();
    let api_key = self.api_key.clone();
    let api_base = self.api_base.clone();
    let rt = Arc::clone(&self.runtime);

    // Bridge async -> sync: spawn the batch on the runtime and block this
    // (GIL-released) thread on a std mpsc channel until the result arrives.
    // The async task receives Vec<JsonValue> and returns Result<Vec<JsonValue>, PyErr>.
    let result_from_async_task: Result<Vec<JsonValue>, PyErr> = py.allow_threads(move || {
        let (tx, rx) = std::sync::mpsc::channel::<Result<Vec<JsonValue>, PyErr>>();
        rt.spawn(async move {
            let res = process_batch_post_requests(
                client,
                url_path,
                payloads_json, // Pass depythonized JSON values
                api_key,
                api_base,
                max_concurrent_requests,
                timeout_duration,
            )
            .await;
            // Receiver may have gone away if the calling thread errored; ignore.
            let _ = tx.send(res);
        });
        rx.recv()
            .map_err(|e| {
                PyValueError::new_err(format!(
                    "Failed to receive result from async task (channel error): {}",
                    e
                ))
            })
            .and_then(|inner_result| inner_result)
    });

    let response_json_values = result_from_async_task?;

    // Pythonize all results in the current thread (GIL is held)
    let mut results_py: Vec<PyObject> = Vec::with_capacity(response_json_values.len());
    for (idx, json_val) in response_json_values.into_iter().enumerate() {
        let py_obj_bound = pythonize(py, &json_val).map_err(|e| {
            PyValueError::new_err(format!(
                "Failed to pythonize response at index {}: {}",
                idx, e
            ))
        })?;
        // Convert Bound<'_, PyAny> to PyObject
        results_py.push(py_obj_bound.to_object(py));
    }

    // Use the updated PyList::new_bound or PyList::new as per PyO3 v0.21+
    // PyList::new_bound is suitable here for an iterable of PyObjects.
    let py_object_list = PyList::new_bound(py, &results_py);
    Ok(py_object_list.into())
}
516600
}
517601

518602
// --- Send Single Embedding Request ---
@@ -894,6 +978,121 @@ async fn process_classify_requests(
894978
})
895979
}
896980

981+
// --- Send Single Batch Post Request ---
982+
// Now takes JsonValue and returns JsonValue
983+
async fn send_single_batch_post_request(
984+
client: Client,
985+
full_url: String,
986+
payload_json: JsonValue,
987+
api_key: String,
988+
request_timeout: Duration,
989+
) -> Result<JsonValue, PyErr> {
990+
// No depythonize here
991+
992+
let response = client
993+
.post(&full_url)
994+
.bearer_auth(api_key)
995+
.json(&payload_json)
996+
.timeout(request_timeout)
997+
.send()
998+
.await
999+
.map_err(|e| PyValueError::new_err(format!("Request failed: {}", e)))?;
1000+
1001+
let successful_response = ensure_successful_response(response).await?;
1002+
1003+
// Get response as serde_json::Value
1004+
let response_json_value: JsonValue = successful_response
1005+
.json::<JsonValue>()
1006+
.await
1007+
.map_err(|e| PyValueError::new_err(format!("Failed to parse response JSON: {}", e)))?;
1008+
1009+
// No pythonize here, return JsonValue
1010+
Ok(response_json_value)
1011+
}
1012+
1013+
// --- Process Batch Post Requests ---
/// Fans `payloads_json` out as concurrent POST requests (bounded by a
/// semaphore of `max_concurrent_requests`) and returns the JSON responses
/// re-assembled in the original input order.
///
/// On the first failure the shared `cancel_token` is set; tasks still waiting
/// on the semaphore presumably bail out early via `acquire_permit_or_cancel`
/// (TODO confirm — that helper is defined elsewhere in this file). The first
/// error observed is the one propagated to the caller.
async fn process_batch_post_requests(
    client: Client,
    url_path: String,
    payloads_json: Vec<JsonValue>, // Takes Vec<JsonValue>
    api_key: String,
    api_base: String,
    max_concurrent_requests: usize,
    request_timeout_duration: Duration,
) -> Result<Vec<JsonValue>, PyErr> {
    // Limits how many requests are in flight at once.
    let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
    let mut tasks = Vec::new();
    // Shared flag flipped by the first failing task.
    let cancel_token = Arc::new(AtomicBool::new(false));
    let total_payloads = payloads_json.len();

    for (index, payload_item_json) in payloads_json.into_iter().enumerate() {
        // Each task gets its own clones of the shared handles; the payload
        // itself is moved into the task.
        let client_clone = client.clone();
        let api_key_clone = api_key.clone();
        let api_base_clone = api_base.clone();
        let url_path_clone = url_path.clone();
        let semaphore_clone = Arc::clone(&semaphore);
        let cancel_token_clone = Arc::clone(&cancel_token);
        let individual_request_timeout = request_timeout_duration;

        tasks.push(tokio::spawn(async move {
            // Wait for a concurrency slot (or abort if the batch was cancelled).
            let permit_guard =
                acquire_permit_or_cancel(semaphore_clone, cancel_token_clone.clone()).await?;

            // Join base and path without producing a double or missing slash.
            let full_url = format!(
                "{}/{}",
                api_base_clone.trim_end_matches('/'),
                url_path_clone.trim_start_matches('/')
            );

            let result = send_single_batch_post_request(
                client_clone,
                full_url,
                payload_item_json, // Pass JsonValue
                api_key_clone,
                individual_request_timeout,
            )
            .await;

            // Release the concurrency slot before post-processing the result.
            drop(permit_guard);

            match result {
                // Tag the response with its original index so order can be restored.
                Ok(response_json_value) => Ok((index, response_json_value)),
                Err(e) => {
                    // Signal the rest of the batch to stop as soon as possible.
                    cancel_token_clone.store(true, Ordering::SeqCst);
                    Err(e)
                }
            }
        }));
    }

    let task_join_results = join_all(tasks).await;
    // Successful (index, value) pairs; order is whatever join_all produced.
    let mut indexed_results: Vec<(usize, JsonValue)> = Vec::with_capacity(total_payloads);
    let mut first_error: Option<PyErr> = None;

    for result in task_join_results {
        // process_task_outcome records the first error and yields Some(data)
        // only for successful tasks. Here D is (usize, JsonValue).
        if let Some(indexed_data_part) =
            process_task_outcome(result, &mut first_error, &cancel_token)
        {
            indexed_results.push(indexed_data_part);
        }
    }

    // Any failure trumps partial success: report the first error seen.
    if let Some(err) = first_error {
        return Err(err);
    }

    // Restore the caller's payload order before stripping the indices.
    indexed_results.sort_by_key(|&(original_index, _)| original_index);

    let final_results: Vec<JsonValue> = indexed_results.into_iter().map(|(_, val)| val).collect();

    Ok(final_results)
}
1095+
8971096
// Helper function to process task results and manage errors
8981097
fn process_task_outcome<D>(
8991098
task_join_result: Result<Result<D, PyErr>, JoinError>, // Removed OwnedSemaphorePermit from here

bei-client/tests/test_client_embed.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,26 @@ def test_embedding_high_volume_return_instant():
214214
)
215215

216216

217+
@pytest.mark.skipif(
    not EMBEDDINGS_REACHABLE, reason="Deployment is not reachable. Skipping test."
)
def test_batch_post():
    """batch_post returns one JSON response per payload, in order."""
    client = PerformanceClient(api_base=api_base_embed, api_key=api_key)
    assert client.api_key == api_key

    # Same OpenAI-style embedding request sent twice through the generic endpoint.
    embed_payload = {"model": "my_model", "input": ["Hello world"]}
    payload_batch = [embed_payload] * 2

    responses = client.batch_post(
        url_path="/v1/embeddings",
        payloads=payload_batch,
        max_concurrent_requests=1,
    )

    assert responses is not None
    assert len(responses) == 2
    assert responses[0]
236+
217237
@pytest.mark.skipif(
218238
not EMBEDDINGS_REACHABLE, reason="Deployment is not reachable. Skipping test."
219239
)

0 commit comments

Comments
 (0)