Skip to content

Commit a2ce122

Browse files
authored
feat(batch): add batch APIs (#334)
1 parent 3ecab2e commit a2ce122

File tree

23 files changed

+615
-7
lines changed

23 files changed

+615
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
### Added
44
- **vector-stores**: add vector stores APIs (#324)
5+
- **batch**: add batch APIs (#334)
56

67
### Fixed
78
- **chat**: enhance flow cancel capability (#333)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.aallam.openai.client
2+
3+
import com.aallam.openai.api.batch.Batch
4+
import com.aallam.openai.api.batch.BatchId
5+
import com.aallam.openai.api.batch.BatchRequest
6+
import com.aallam.openai.api.core.RequestOptions
7+
8+
/**
 * Create large batches of API requests for asynchronous processing.
 * The Batch API returns completions within 24 hours for a 50% discount.
 */
public interface Batch {

    /**
     * Creates and executes a batch from an uploaded file of requests.
     *
     * @param request the batch creation request, referencing a previously uploaded input file.
     * @param requestOptions request-scoped options; `null` uses the client defaults.
     * @return the newly created batch.
     */
    public suspend fun batch(request: BatchRequest, requestOptions: RequestOptions? = null): Batch

    /**
     * Retrieves a batch.
     *
     * @param id identifier of the batch to fetch.
     * @param requestOptions request-scoped options; `null` uses the client defaults.
     * @return the batch, or `null` if no batch with [id] exists.
     */
    public suspend fun batch(id: BatchId, requestOptions: RequestOptions? = null): Batch?

    /**
     * Cancels an in-progress batch.
     *
     * @param id identifier of the batch to cancel.
     * @param requestOptions request-scoped options; `null` uses the client defaults.
     * @return the updated batch, or `null` if no batch with [id] exists.
     */
    public suspend fun cancel(id: BatchId, requestOptions: RequestOptions? = null): Batch?

    /**
     * List your organization's batches.
     *
     * @param after A cursor for use in pagination. After is an object ID that defines your place in the list.
     * For instance, if you make a list request and receive 100 objects, ending with obj_foo, your later call can
     * include after=obj_foo to fetch the next page of the list.
     * @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default
     * is 20.
     * @param requestOptions request-scoped options; `null` uses the client defaults.
     */
    public suspend fun batches(
        after: BatchId? = null,
        limit: Int? = null,
        requestOptions: RequestOptions? = null
    ): List<Batch>
}

openai-client/src/commonMain/kotlin/com.aallam.openai.client/VectorStores.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.aallam.openai.client
22

33
import com.aallam.openai.api.BetaOpenAI
4+
import com.aallam.openai.api.batch.BatchId
45
import com.aallam.openai.api.core.RequestOptions
56
import com.aallam.openai.api.core.SortOrder
67
import com.aallam.openai.api.core.Status

openai-client/src/commonMain/kotlin/com.aallam.openai.client/internal/OpenAIApi.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ internal class OpenAIApi(
2828
Runs by RunsApi(requester),
2929
Messages by MessagesApi(requester),
3030
VectorStores by VectorStoresApi(requester),
31+
Batch by BatchApi(requester),
3132
Closeable by requester

openai-client/src/commonMain/kotlin/com.aallam.openai.client/internal/api/ApiPath.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ internal object ApiPath {
2222
const val Assistants = "assistants"
2323
const val Threads = "threads"
2424
const val VectorStores = "vector_stores"
25+
const val Batches = "batches"
2526
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.aallam.openai.client.internal.api
2+
3+
import com.aallam.openai.api.batch.BatchId
4+
import com.aallam.openai.api.batch.BatchRequest
5+
import com.aallam.openai.api.core.PaginatedList
6+
import com.aallam.openai.api.core.RequestOptions
7+
import com.aallam.openai.api.exception.OpenAIAPIException
8+
import com.aallam.openai.client.Batch
9+
import com.aallam.openai.client.internal.extension.beta
10+
import com.aallam.openai.client.internal.extension.requestOptions
11+
import com.aallam.openai.client.internal.http.HttpRequester
12+
import com.aallam.openai.client.internal.http.perform
13+
import io.ktor.client.call.*
14+
import io.ktor.client.request.*
15+
import io.ktor.client.statement.*
16+
import io.ktor.http.*
17+
import com.aallam.openai.api.batch.Batch as BatchObject
18+
19+
/**
 * Implementation of [Batch].
 */
internal class BatchApi(val requester: HttpRequester) : Batch {

    /** Creates and executes a batch from an uploaded file of requests. */
    override suspend fun batch(
        request: BatchRequest,
        requestOptions: RequestOptions?
    ): BatchObject {
        return requester.perform {
            it.post {
                url(path = ApiPath.Batches)
                setBody(request)
                contentType(ContentType.Application.Json)
                requestOptions(requestOptions)
            }.body()
        }
    }

    /** Retrieves a batch, or returns `null` when the API responds with 404 Not Found. */
    override suspend fun batch(id: BatchId, requestOptions: RequestOptions?): BatchObject? {
        try {
            return requester.perform<HttpResponse> {
                it.get {
                    url(path = "${ApiPath.Batches}/${id.id}")
                    requestOptions(requestOptions)
                }
            }.body()
        } catch (e: OpenAIAPIException) {
            if (e.statusCode == HttpStatusCode.NotFound.value) return null
            throw e
        }
    }

    /** Cancels an in-progress batch, or returns `null` when the API responds with 404 Not Found. */
    override suspend fun cancel(id: BatchId, requestOptions: RequestOptions?): BatchObject? {
        // Use the same 404 handling as [batch]: the HTTP transport maps non-2xx responses to
        // [OpenAIAPIException], so inspecting the raw response status would never observe a 404 —
        // the request would already have thrown before the status check.
        try {
            return requester.perform<HttpResponse> {
                it.post {
                    url(path = "${ApiPath.Batches}/${id.id}/cancel")
                    requestOptions(requestOptions)
                }
            }.body()
        } catch (e: OpenAIAPIException) {
            if (e.statusCode == HttpStatusCode.NotFound.value) return null
            throw e
        }
    }

    /** List your organization's batches, paginated via [after] and capped by [limit]. */
    override suspend fun batches(
        after: BatchId?,
        limit: Int?,
        requestOptions: RequestOptions?
    ): PaginatedList<BatchObject> {
        return requester.perform {
            it.get {
                url {
                    path(ApiPath.Batches)
                    limit?.let { parameter("limit", it) }
                    after?.let { parameter("after", it.id) }
                }
                // No "OpenAI-Beta: assistants" header here: the Batch API is not part of the
                // assistants beta, and none of the other batch endpoints in this class send it.
                requestOptions(requestOptions)
            }.body()
        }
    }
}

openai-client/src/commonMain/kotlin/com.aallam.openai.client/internal/api/VectorStoresApi.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.aallam.openai.client.internal.api
22

3+
import com.aallam.openai.api.batch.BatchId
34
import com.aallam.openai.api.core.*
45
import com.aallam.openai.api.exception.OpenAIAPIException
56
import com.aallam.openai.api.file.FileId

openai-client/src/commonMain/kotlin/com.aallam.openai.client/internal/http/HttpTransport.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ internal class HttpTransport(private val httpClient: HttpClient) : HttpRequester
6262
val error = response.body<OpenAIError>()
6363
return when(status) {
6464
429 -> RateLimitException(status, error, exception)
65-
400, 404, 415 -> InvalidRequestException(status, error, exception)
65+
400, 404, 409, 415 -> InvalidRequestException(status, error, exception)
6666
401 -> AuthenticationException(status, error, exception)
6767
403 -> PermissionException(status, error, exception)
6868
else -> UnknownAPIException(status, error, exception)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package com.aallam.openai.client
2+
3+
import com.aallam.openai.api.batch.*
4+
import com.aallam.openai.api.batch.Batch
5+
import com.aallam.openai.api.chat.ChatCompletion
6+
import com.aallam.openai.api.chat.ChatCompletionRequest
7+
import com.aallam.openai.api.chat.ChatMessage
8+
import com.aallam.openai.api.chat.TextContent
9+
import com.aallam.openai.api.core.Endpoint
10+
import com.aallam.openai.api.core.Role
11+
import com.aallam.openai.api.file.Purpose
12+
import com.aallam.openai.api.file.fileSource
13+
import com.aallam.openai.api.file.fileUpload
14+
import com.aallam.openai.api.model.ModelId
15+
import com.aallam.openai.client.internal.JsonLenient
16+
import com.aallam.openai.client.internal.asSource
17+
import kotlinx.serialization.encodeToString
18+
import kotlinx.serialization.json.Json
19+
import kotlinx.serialization.json.decodeFromJsonElement
20+
import kotlin.test.*
21+
22+
class TestBatches : TestOpenAI() {

    /** Verifies that a batch JSON payload deserializes into [Batch] with the expected field values. */
    @Test
    fun batchSerialization() {
        val json = """
            {
              "id": "batch_0mhGzcpyyQnS1T38bFI4vgMN",
              "object": "batch",
              "endpoint": "/v1/chat/completions",
              "errors": null,
              "input_file_id": "file-CmkZMEEBGbVB0YMzuKMjCT0C",
              "completion_window": "24h",
              "status": "validating",
              "output_file_id": null,
              "error_file_id": null,
              "created_at": 1714347843,
              "in_progress_at": null,
              "expires_at": 1714434243,
              "finalizing_at": null,
              "completed_at": null,
              "failed_at": null,
              "expired_at": null,
              "cancelling_at": null,
              "cancelled_at": null,
              "request_counts": {
                "total": 0,
                "completed": 0,
                "failed": 0
              },
              "metadata": null
            }
        """.trimIndent()

        val batch = JsonLenient.decodeFromString<Batch>(json)
        assertEquals("batch_0mhGzcpyyQnS1T38bFI4vgMN", batch.id.id)
        assertEquals("/v1/chat/completions", batch.endpoint.path)
        assertEquals("24h", batch.completionWindow?.value)
    }

    /**
     * End-to-end batch lifecycle: upload a JSONL input file, create a batch, retrieve it,
     * list batches, then cancel the batch and delete the uploaded file.
     */
    @Test
    fun batches() = test {
        val systemPrompt =
            "Your goal is to extract movie categories from movie descriptions, as well as a 1-sentence summary for these movies."
        val descriptions = listOf(
            "Two imprisoned men bond over a number of years, finding solace and eventual redemption through acts of common decency.",
            "An organized crime dynasty's aging patriarch transfers control of his clandestine empire to his reluctant son.",
        )

        // One chat-completion request per description, keyed by a unique custom id.
        val requestInputs = descriptions.mapIndexed { index, input ->
            RequestInput(
                customId = CustomId("task-$index"),
                method = Method.Post,
                url = "/v1/chat/completions",
                body = ChatCompletionRequest(
                    model = ModelId("gpt-3.5-turbo"),
                    messages = listOf(
                        ChatMessage(
                            role = Role.System,
                            messageContent = TextContent(systemPrompt)
                        ),
                        ChatMessage(
                            role = Role.User,
                            messageContent = TextContent(input)
                        )
                    )
                )
            )
        }

        val jsonl = buildJsonlFile(requestInputs)
        val fileRequest = fileUpload {
            file = fileSource {
                name = "input.jsonl"
                source = jsonl.asSource()
            }
            purpose = Purpose("batch")
        }
        val batchFile = openAI.file(fileRequest)

        val request = batchRequest {
            inputFileId = batchFile.id
            endpoint = Endpoint.Completions
            completionWindow = CompletionWindow.TwentyFourHours
        }

        val batch = openAI.batch(request = request)
        val fetchedBatch = openAI.batch(id = batch.id)
        assertEquals(batch.id, fetchedBatch?.id)

        val batches = openAI.batches()
        assertContains(batches.map { it.id }, batch.id)

        // Cleanup: cancel the batch and remove the uploaded input file.
        openAI.cancel(id = batch.id)
        openAI.delete(fileId = batchFile.id)
    }

    /** Serializes [requests] to JSON Lines: one JSON object per line. */
    private fun buildJsonlFile(requests: List<RequestInput>, json: Json = Json): String = buildString {
        for (request in requests) {
            appendLine(json.encodeToString(request))
        }
    }

    /** Verifies that a JSONL batch output payload decodes into [RequestOutput] records. */
    @Test
    fun testDecodeOutput() = test {
        val output = """
            {"id": "batch_req_gS7NOjY66SR4zsPAsZTLCQfy", "custom_id": "task-0", "response": {"status_code": 200, "request_id": "ab750cd57ec6610df04703802ba65f21", "body": {"id": "chatcmpl-9K21h6ZU0DGFi9FA4aC2T4Gd4SfKU", "object": "chat.completion", "created": 1714561377, "model": "gpt-3.5-turbo-0125", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Category: Drama\n\nSummary: Two imprisoned men form a strong bond and find redemption through acts of kindness and decency."}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 57, "completion_tokens": 23, "total_tokens": 80}, "system_fingerprint": "fp_3b956da36b"}}, "error": null}
            {"id": "batch_req_iTjKmQps1zBqDTtXH9cft7ck", "custom_id": "task-1", "response": {"status_code": 200, "request_id": "75b9ca6b6d47baa61e3a3830968ca63a", "body": {"id": "chatcmpl-9K21h3Mv2zlWvj3S4e1YHlXOPWTsI", "object": "chat.completion", "created": 1714561377, "model": "gpt-3.5-turbo-0125", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Movie categories: Crime, Drama\n\nSummary: A reluctant heir must take control of an organized crime empire from his aging father in this intense drama."}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 54, "completion_tokens": 29, "total_tokens": 83}, "system_fingerprint": "fp_3b956da36b"}}, "error": null}
        """
            .trimIndent()
            .encodeToByteArray() // simulate reading from a file using download(fileId)

        val outputs = decodeOutput(output)
        assertEquals(2, outputs.size)
        assertNotNull(outputs.find { it.customId == CustomId("task-0") })
        assertNotNull(outputs.find { it.customId == CustomId("task-1") })

        val response = outputs.first().response ?: fail("response is null")
        assertEquals(200, response.statusCode)
        val completion = JsonLenient.decodeFromJsonElement<ChatCompletion>(response.body)
        assertNotNull(completion.choices.first().message.content)
    }

    /** Decodes a JSONL payload into [RequestOutput] records, one per non-blank line. */
    private fun decodeOutput(output: ByteArray): List<RequestOutput> {
        // Skip blank lines so a trailing newline in a downloaded output file doesn't break decoding.
        return output.decodeToString()
            .lines()
            .filter { it.isNotBlank() }
            .map { Json.decodeFromString<RequestOutput>(it) }
    }
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package com.aallam.openai.api.batch
2+
3+
import com.aallam.openai.api.BetaOpenAI
4+
import com.aallam.openai.api.core.Endpoint
5+
import com.aallam.openai.api.core.PaginatedList
6+
import com.aallam.openai.api.core.Status
7+
import com.aallam.openai.api.exception.OpenAIErrorDetails
8+
import com.aallam.openai.api.file.FileId
9+
import kotlinx.serialization.SerialName
10+
import kotlinx.serialization.Serializable
11+
12+
/**
 * A batch object returned by the OpenAI Batch API.
 */
@BetaOpenAI
@Serializable
public data class Batch(
    /** Unique identifier for the batch. */
    @SerialName("id") public val id: BatchId,

    /** The OpenAI API endpoint used by the batch. */
    @SerialName("endpoint") public val endpoint: Endpoint,

    /** Container for any errors that occurred during batch processing. */
    @SerialName("errors") public val errors: PaginatedList<OpenAIErrorDetails>?,

    /** Identifier of the input file for the batch. */
    @SerialName("input_file_id") public val inputFileId: FileId? = null,

    /** Time frame within which the batch should be processed. */
    @SerialName("completion_window") public val completionWindow: CompletionWindow? = null,

    /** Current processing status of the batch. */
    @SerialName("status") public val status: Status? = null,

    /** Identifier of the output file containing successfully executed requests. */
    @SerialName("output_file_id") public val outputFileId: FileId? = null,

    /** Identifier of the error file containing outputs of requests with errors. */
    @SerialName("error_file_id") public val errorFileId: FileId? = null,

    /** Unix timestamp for when the batch was created. */
    @SerialName("created_at") public val createdAt: Long? = null,

    /** Unix timestamp for when the batch processing started. */
    @SerialName("in_progress_at") public val inProgressAt: Long? = null,

    /** Unix timestamp for when the batch will expire. */
    @SerialName("expires_at") public val expiresAt: Long? = null,

    /** Unix timestamp for when the batch started finalizing. */
    @SerialName("finalizing_at") public val finalizingAt: Long? = null,

    /** Unix timestamp for when the batch was completed. */
    @SerialName("completed_at") public val completedAt: Long? = null,

    /** Unix timestamp for when the batch failed. */
    @SerialName("failed_at") public val failedAt: Long? = null,

    /** Unix timestamp for when the batch expired. */
    @SerialName("expired_at") public val expiredAt: Long? = null,

    /** Unix timestamp for when the batch started cancelling. */
    @SerialName("cancelling_at") public val cancellingAt: Long? = null,

    /** Unix timestamp for when the batch was cancelled. */
    @SerialName("cancelled_at") public val cancelledAt: Long? = null,

    /** Container for the counts of requests by their status. */
    @SerialName("request_counts") public val requestCounts: RequestCounts? = null,

    /** Metadata associated with the batch as key-value pairs. */
    @SerialName("metadata") public val metadata: Map<String, String>? = null,
)

0 commit comments

Comments
 (0)