
Commit 88f7fa0

Merge pull request #1691 from basetenlabs/bump-version-0.9.98
Release 0.9.98
2 parents 6180bb7 + fd551ed commit 88f7fa0

File tree

27 files changed: +470 −160 lines changed

.github/workflows/main.yml

Lines changed: 23 additions & 0 deletions
@@ -30,6 +30,29 @@ jobs:
     with:
       run_only_integration: false

+  report_to_slack:
+    runs-on: ubuntu-22.04
+    if: always() && github.ref == 'refs/heads/main'
+    needs:
+      - all-tests
+    steps:
+      - name: get-branch
+        run: echo ${{ github.ref }}
+      - name: show-slack-status
+        uses: 8398a7/action-slack@v3
+        with:
+          status: custom
+          fields: author, job, commit, repo
+          custom_payload: |
+            {
+              attachments: [{
+                color: "${{ needs.all-tests.result == 'failure' && 'danger' || 'good' }}",
+                text: `Truss post-commit tests ${{ needs.all-tests.result }}: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}`,
+              }]
+            }
+        env:
+          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
+
   publish-rc-to-pypi:
     needs: [detect-version-changed]
     if: ${{ !failure() && !cancelled() && needs.detect-version-changed.outputs.release_version == 'true' && needs.detect-version-changed.outputs.is_prerelease_version == 'true' }}

baseten-inference-client/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

baseten-inference-client/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [package]
 name = "baseten_inference_client"
-version = "0.0.1-rc3"
+version = "0.0.1"
 edition = "2021"

 [dependencies]
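
To confirm which build ends up installed after this version bump, a minimal sketch (assuming the Python wheel's distribution name matches the Cargo package name above, which this diff does not show):

```python
# Hypothetical check, not part of this commit: report the installed client version.
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("baseten_inference_client"))  # expected to report 0.0.1 after this release
except PackageNotFoundError:
    print("baseten_inference_client is not installed")
```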

baseten-inference-client/README.md

Lines changed: 54 additions & 40 deletions
Original file: README.md

@@ -21,8 +21,8 @@ base_url_embed = "https://model-yqv0rjjw.api.baseten.co/environments/production/
 # base_url_embed = "https://api.openai.com" or "https://api.mixedbread.com"
 client = InferenceClient(base_url=base_url_embed, api_key=api_key)
 ```
-
-### Synchronous Embedding
+### Embeddings
+#### Synchronous Embedding

 ```python
 texts = ["Hello world", "Example text", "Another sample"]
@@ -58,12 +58,12 @@ if numpy_array.shape[0] > 0:

 Note: The embed method is versatile and can be used with any embeddings service, e.g. OpenAI API embeddings, not just for Baseten deployments.

-### Asynchronous Embedding
+#### Asynchronous Embedding

 ```python
 async def async_embed():
     texts = ["Async hello", "Async example"]
-    response = await client.aembed(
+    response = await client.async_embed(
         input=texts,
         model="my_model",
         batch_size=2,
@@ -76,8 +76,22 @@ async def async_embed():
 # asyncio.run(async_embed())
 ```

-### Synchronous Batch POST
+#### Embedding Benchmarks
+Comparison against `pip install openai` for `/v1/embeddings`. Tested with `./scripts/compare_latency_openai.py`, a mini_batch_size of 128, and 4 server-side replicas. Results with OpenAI are similar; OpenAI allows a max mini_batch_size of 2048.
+
+| Number of inputs / embeddings | Number of Tasks | InferenceClient (s) | AsyncOpenAI (s) | Speedup |
+|-------------------------------:|---------------:|---------------------:|----------------:|--------:|
+| 128 | 1 | 0.12 | 0.13 | 1.08× |
+| 512 | 4 | 0.14 | 0.21 | 1.50× |
+| 8 192 | 64 | 0.83 | 1.95 | 2.35× |
+| 131 072 | 1 024 | 4.63 | 39.07 | 8.44× |
+| 2 097 152 | 16 384 | 70.92 | 903.68 | 12.74× |
+
+### General Batch POST

+The batch_post method is generic. It can be used to send POST requests to any URL, not limited to Baseten endpoints. The input and output can be any JSON item.
+
+#### Synchronous Batch POST
 ```python
 payload1 = {"model": "my_model", "input": ["Batch request sample 1"]}
 payload2 = {"model": "my_model", "input": ["Batch request sample 2"]}
@@ -90,15 +104,12 @@ response1, response2 = client.batch_post(
 print("Batch POST responses:", response1, response2)
 ```

-Note: The batch_post method is generic. It can be used to send POST requests to any URL,
-not limited to Baseten endpoints.
-
-### Asynchronous Batch POST
+#### Asynchronous Batch POST

 ```python
 async def async_batch_post():
     payload = {"model": "my_model", "input": ["Async batch sample"]}
-    responses = await client.abatch_post(
+    responses = await client.async_batch_post(
         url_path="/v1/embeddings",
         payloads=[payload, payload],
         max_concurrent_requests=4,
@@ -109,8 +120,10 @@ async def async_batch_post():
 # To run:
 # asyncio.run(async_batch_post())
 ```
+### Reranking
+Reranking compatible with BEI or text-embeddings-inference.

-### Synchronous Reranking
+#### Synchronous Reranking

 ```python
 query = "What is the best framework?"
@@ -127,13 +140,13 @@ for res in rerank_response.data:
     print(f"Index: {res.index} Score: {res.score}")
 ```

-### Asynchronous Reranking
+#### Asynchronous Reranking

 ```python
 async def async_rerank():
     query = "Async query sample"
     docs = ["Async doc1", "Async doc2"]
-    response = await client.arerank(
+    response = await client.async_rerank(
         query=query,
         texts=docs,
         return_text=True,
@@ -148,7 +161,9 @@ async def async_rerank():
 # asyncio.run(async_rerank())
 ```

-### Synchronous Classification
+### Classification
+Predict (classification endpoint) compatible with BEI or text-embeddings-inference.
+#### Synchronous Classification

 ```python
 texts_to_classify = [
@@ -167,12 +182,11 @@ for group in classify_response.data:
         print(f"Label: {result.label}, Score: {result.score}")
 ```

-### Asynchronous Classification
-
+#### Asynchronous Classification
 ```python
 async def async_classify():
     texts = ["Async positive", "Async negative"]
-    response = await client.aclassify(
+    response = await client.async_classify(
         inputs=texts,
         batch_size=1,
         max_concurrent_requests=8,
@@ -187,28 +201,7 @@ async def async_classify():
 ```


-## Development
-
-```bash
-# Install prerequisites
-sudo apt-get install patchelf
-# Install cargo if not already installed.
-
-# Set up a Python virtual environment
-python -m venv .venv
-source .venv/bin/activate
-
-# Install development dependencies
-pip install maturin[patchelf] pytest requests numpy
-
-# Build and install the Rust extension in development mode
-maturin develop
-cargo fmt
-# Run tests
-pytest tests
-```
-
-## Error Handling
+### Error Handling

 The client can raise several types of errors. Here's how to handle common ones:

@@ -243,7 +236,28 @@ except requests.exceptions.HTTPError as e:

 ```

-For asynchronous methods (`aembed`, `arerank`, `aclassify`, `abatch_post`), the same exceptions will be raised by the `await` call and can be caught using a `try...except` block within an `async def` function.
+For asynchronous methods (`async_embed`, `async_rerank`, `async_classify`, `async_batch_post`), the same exceptions will be raised by the `await` call and can be caught using a `try...except` block within an `async def` function.
+
+## Development
+
+```bash
+# Install prerequisites
+sudo apt-get install patchelf
+# Install cargo if not already installed.
+
+# Set up a Python virtual environment
+python -m venv .venv
+source .venv/bin/activate
+
+# Install development dependencies
+pip install maturin[patchelf] pytest requests numpy
+
+# Build and install the Rust extension in development mode
+maturin develop
+cargo fmt
+# Run tests
+pytest tests
+```

 ## Contributions
 Feel free to contribute to this repo, tag @michaelfeil for review.
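
Taken together, the README changes above amount to a rename of every async coroutine (`aembed` → `async_embed`, `arerank` → `async_rerank`, `aclassify` → `async_classify`, `abatch_post` → `async_batch_post`). A minimal sketch stitching those snippets into one script, with placeholder values for the base URL, API key, and model name (none of which come from this commit):

```python
# Minimal sketch of the renamed async API; URL, key, and model are placeholders.
import asyncio

import requests
from baseten_inference_client import InferenceClient

client = InferenceClient(
    base_url="https://example.api.baseten.co/sync",  # placeholder
    api_key="YOUR_API_KEY",                          # placeholder
)

async def main() -> None:
    try:
        embed_response = await client.async_embed(
            input=["Hello world", "Example text"],
            model="my_model",
            batch_size=2,
            max_concurrent_requests=4,
        )
        rerank_response = await client.async_rerank(
            query="What is the best framework?",
            texts=["doc one", "doc two"],
            return_text=True,
        )
        classify_response = await client.async_classify(inputs=["positive", "negative"])
        print(
            len(embed_response.data),
            len(rerank_response.data),
            len(classify_response.data),
        )
    except requests.exceptions.HTTPError as exc:
        # Per the Error Handling section, the same exceptions surface from the await call.
        print("Request failed:", exc)

asyncio.run(main())
```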

baseten-inference-client/baseten_inference_client.pyi

Lines changed: 8 additions & 8 deletions
Original file: baseten_inference_client.pyi

@@ -383,7 +383,7 @@ class InferenceClient:
         """
         ...

-    async def aembed(
+    async def async_embed(
         self,
         input: builtins.list[builtins.str],
         model: builtins.str,
@@ -415,12 +415,12 @@ class InferenceClient:
            requests.exceptions.HTTPError: If the request fails.

        Example:
-            >>> response = await client.aembed(["hello", "world"], model="model-id")
+            >>> response = await client.async_embed(["hello", "world"], model="model-id")
            >>> print(response.data[0].embedding)
        """
        ...

-    async def arerank(
+    async def async_rerank(
        self,
        query: builtins.str,
        texts: builtins.list[builtins.str],
@@ -454,13 +454,13 @@ class InferenceClient:
            requests.exceptions.HTTPError: If the request fails.

        Example:
-            >>> response = await client.arerank("find", ["doc1", "doc2"])
+            >>> response = await client.async_rerank("find", ["doc1", "doc2"])
            >>> for result in response.data:
            ...     print(result.index, result.score)
        """
        ...

-    async def aclassify(
+    async def async_classify(
        self,
        inputs: builtins.list[builtins.str],
        raw_scores: builtins.bool = False,
@@ -490,14 +490,14 @@ class InferenceClient:
            requests.exceptions.HTTPError: If the request fails.

        Example:
-            >>> response = await client.aclassify(["text1", "text2"])
+            >>> response = await client.async_classify(["text1", "text2"])
            >>> for group in response.data:
            ...     for result in group:
            ...         print(result.label, result.score)
        """
        ...

-    async def abatch_post(
+    async def async_batch_post(
        self,
        url_path: builtins.str,
        payloads: builtins.list[typing.Any],
@@ -521,7 +521,7 @@ class InferenceClient:
            requests.exceptions.HTTPError: If any underlying HTTP requests fail.

        Example:
-            >>> responses = await client.abatch_post("/v1/process_item", [{"data": "r1"}, {"data": "r2"}])
+            >>> responses = await client.async_batch_post("/v1/process_item", [{"data": "r1"}, {"data": "r2"}])
            >>> for resp in responses:
            ...     print(resp)
        """

baseten-inference-client/src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file: src/lib.rs

@@ -411,8 +411,8 @@ impl InferenceClient {
         Python::with_gil(|py_gil| Ok(successful_response.into_py(py_gil)))
     }

-    #[pyo3(name = "aembed", signature = (input, model, encoding_format = None, dimensions = None, user = None, max_concurrent_requests = DEFAULT_CONCURRENCY, batch_size = DEFAULT_BATCH_SIZE, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
-    fn aembed<'py>(
+    #[pyo3(name = "async_embed", signature = (input, model, encoding_format = None, dimensions = None, user = None, max_concurrent_requests = DEFAULT_CONCURRENCY, batch_size = DEFAULT_BATCH_SIZE, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
+    fn async_embed<'py>(
         &self,
         py: Python<'py>,
         input: Vec<String>,
@@ -513,8 +513,8 @@ impl InferenceClient {
         Python::with_gil(|py| Ok(successful_response.into_py(py)))
     }

-    #[pyo3(name = "arerank", signature = (query, texts, raw_scores = false, return_text = false, truncate = false, truncation_direction = "Right", max_concurrent_requests = DEFAULT_CONCURRENCY, batch_size = DEFAULT_BATCH_SIZE, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
-    fn arerank<'py>(
+    #[pyo3(name = "async_rerank", signature = (query, texts, raw_scores = false, return_text = false, truncate = false, truncation_direction = "Right", max_concurrent_requests = DEFAULT_CONCURRENCY, batch_size = DEFAULT_BATCH_SIZE, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
+    fn async_rerank<'py>(
         &self,
         py: Python<'py>,
         query: String,
@@ -614,8 +614,8 @@ impl InferenceClient {
         Python::with_gil(|py| Ok(result_from_async_task?.into_py(py)))
     }

-    #[pyo3(name = "aclassify", signature = (inputs, raw_scores = false, truncate = false, truncation_direction = "Right", max_concurrent_requests = DEFAULT_CONCURRENCY, batch_size = DEFAULT_BATCH_SIZE, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
-    fn aclassify<'py>(
+    #[pyo3(name = "async_classify", signature = (inputs, raw_scores = false, truncate = false, truncation_direction = "Right", max_concurrent_requests = DEFAULT_CONCURRENCY, batch_size = DEFAULT_BATCH_SIZE, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
+    fn async_classify<'py>(
         &self,
         py: Python<'py>,
         inputs: Vec<String>,
@@ -668,7 +668,7 @@ impl InferenceClient {
         if payloads.is_empty() {
             return Err(PyValueError::new_err("Payloads list cannot be empty"));
         }
-        InferenceClient::validate_concurrency_parameters(max_concurrent_requests, 1)?; // Batch size is effectively 1
+        InferenceClient::validate_concurrency_parameters(max_concurrent_requests, 1000)?; // sent batch size to 1000 to allow higher batch
         let timeout_duration = InferenceClient::validate_and_get_timeout_duration(timeout_s)?;

         // Depythonize all payloads in the current thread (GIL is held)
@@ -737,8 +737,8 @@ impl InferenceClient {
         Ok(py_object_list.into())
     }

-    #[pyo3(name = "abatch_post", signature = (url_path, payloads, max_concurrent_requests = DEFAULT_CONCURRENCY, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
-    fn abatch_post<'py>(
+    #[pyo3(name = "async_batch_post", signature = (url_path, payloads, max_concurrent_requests = DEFAULT_CONCURRENCY, timeout_s = DEFAULT_REQUEST_TIMEOUT_S))]
+    fn async_batch_post<'py>(
         &self,
         py: Python<'py>,
         url_path: String,
@@ -749,7 +749,7 @@ impl InferenceClient {
         if payloads.is_empty() {
             return Err(PyValueError::new_err("Payloads list cannot be empty"));
         }
-        InferenceClient::validate_concurrency_parameters(max_concurrent_requests, 1)?; // Batch size is effectively 1
+        InferenceClient::validate_concurrency_parameters(max_concurrent_requests, 1000)?; // sent batch size to 1000 to allow higher batch
         let timeout_duration = InferenceClient::validate_and_get_timeout_duration(timeout_s)?;

         // Depythonize all payloads in the current thread (GIL is held by `py` argument)
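
Beyond the renames, the only functional change in this file is the second argument to `validate_concurrency_parameters` on the two batch-POST paths (1 → 1000), loosening the internal check for large batches. A hedged Python sketch of the kind of call this affects; the endpoint path, payload shape, and concurrency value are placeholders, and the exact limits enforced by `validate_concurrency_parameters` are not visible in these hunks:

```python
# Illustrative only: a larger batch_post call of the kind the relaxed validation targets.
from baseten_inference_client import InferenceClient

client = InferenceClient(
    base_url="https://example.api.baseten.co/sync",  # placeholder
    api_key="YOUR_API_KEY",                          # placeholder
)

# Any JSON-serializable payloads are accepted; these mirror the README examples.
payloads = [{"model": "my_model", "input": [f"sample {i}"]} for i in range(256)]

responses = client.batch_post(
    url_path="/v1/embeddings",
    payloads=payloads,
    max_concurrent_requests=32,
)
print(len(responses), "responses received")
```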

baseten-inference-client/tests/test_bindings.py

Lines changed: 4 additions & 4 deletions
Original file: tests/test_bindings.py

@@ -2,10 +2,10 @@ def test_baseten_inference_client_bindings_basic_test():
     from baseten_inference_client import InferenceClient

     InferenceClient.embed
-    InferenceClient.aembed
+    InferenceClient.async_embed
     InferenceClient.rerank
-    InferenceClient.arerank
+    InferenceClient.async_rerank
     InferenceClient.classify
-    InferenceClient.aclassify
+    InferenceClient.async_classify
     InferenceClient.batch_post
-    InferenceClient.abatch_post
+    InferenceClient.async_batch_post

baseten-inference-client/tests/test_client_embed.py

Lines changed: 3 additions & 3 deletions
Original file: tests/test_client_embed.py

@@ -274,7 +274,7 @@ def embed_job(start_time):
 async def test_embed_async():
     client = InferenceClient(base_url=base_url_embed, api_key=api_key)

-    response = await client.aembed(
+    response = await client.async_embed(
         ["Hello world", "Hello world 2"],
         model="my_model",
         batch_size=1,
@@ -296,7 +296,7 @@ async def test_embed_async():
 async def test_classify_async():
     client = InferenceClient(base_url=base_url_rerank, api_key=api_key)

-    response = await client.aclassify(
+    response = await client.async_classify(
         inputs=["who, who?", "Paris france"], batch_size=2, max_concurrent_requests=2
     )
     assert response is not None
@@ -313,7 +313,7 @@ async def test_classify_async():
 async def test_rerank_async():
     client = InferenceClient(base_url=base_url_rerank, api_key=api_key)

-    response = await client.arerank(
+    response = await client.async_rerank(
         query="Who let the dogs out?",
         texts=["who, who?", "Paris france"],
         batch_size=2,
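
The tests above are bare `async def` functions, so they rely on whatever asyncio plugin or event-loop fixture the suite configures (not visible in this diff). A minimal standalone sketch that exercises the renamed method without any plugin, using placeholder credentials:

```python
# Minimal sketch, not part of this commit: drive async_embed with a plain event loop.
import asyncio

from baseten_inference_client import InferenceClient

def smoke_test_async_embed():
    client = InferenceClient(
        base_url="https://example.api.baseten.co/sync",  # placeholder
        api_key="YOUR_API_KEY",                          # placeholder
    )

    async def run():
        return await client.async_embed(
            ["Hello world", "Hello world 2"],
            model="my_model",
            batch_size=1,
        )

    response = asyncio.run(run())
    assert response is not None
```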
