Skip to content

Add Azure OpenAI Content Filter Streaming Support #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ internal class ChatMessageAssembler {
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null
private val toolCallsAssemblers = mutableMapOf<Int, ToolCallAssembler>()
private var chatContentFilterOffsets = mutableListOf<ContentFilterOffsets>()
private var chatContentFilterResults = mutableListOf<ContentFilterResults>()

/**
* Merges a chat chunk into the chat message being assembled.
*/
fun merge(chunk: ChatChunk): ChatMessageAssembler {
chunk.delta.run {
chunk.delta?.run {
role?.let { chatRole = it }
content?.let { chatContent.append(it) }
functionCall?.let { call ->
Expand All @@ -30,6 +32,12 @@ internal class ChatMessageAssembler {
assembler.merge(toolCall)
}
}
chunk.contentFilterOffsets?.also {
chatContentFilterOffsets.add(it)
}
chunk.contentFilterResults?.also {
chatContentFilterResults.add(it)
}
return this
}

Expand All @@ -39,6 +47,8 @@ internal class ChatMessageAssembler {
fun build(): ChatMessage = chatMessage {
this.role = chatRole
this.content = chatContent.toString()
this.contentFilterOffsets = chatContentFilterOffsets
this.contentFilterResults = chatContentFilterResults
if (chatFuncName.isNotEmpty() || chatFuncArgs.isNotEmpty()) {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import com.aallam.openai.api.chat.ChatChunk
import com.aallam.openai.api.chat.ChatDelta
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.ContentFilterOffsets
import com.aallam.openai.api.chat.ContentFilterResult
import com.aallam.openai.api.chat.ContentFilterResults
import com.aallam.openai.api.core.FinishReason
import com.aallam.openai.client.extension.mergeToChatMessage
import kotlin.test.Test
Expand All @@ -20,6 +23,8 @@ class TestChatChunk {
role = ChatRole(role = "assistant"),
content = ""
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -28,6 +33,8 @@ class TestChatChunk {
role = null,
content = "The"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -36,6 +43,8 @@ class TestChatChunk {
role = null,
content = " World"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -44,6 +53,8 @@ class TestChatChunk {
role = null,
content = " Series"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -52,6 +63,8 @@ class TestChatChunk {
role = null,
content = " in"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -60,6 +73,8 @@ class TestChatChunk {
role = null,
content = " "
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -68,6 +83,8 @@ class TestChatChunk {
role = null,
content = "202"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -76,6 +93,8 @@ class TestChatChunk {
role = null,
content = "0"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -84,6 +103,8 @@ class TestChatChunk {
role = null,
content = " is"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -92,6 +113,8 @@ class TestChatChunk {
role = null,
content = " being held"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -100,6 +123,8 @@ class TestChatChunk {
role = null,
content = " in"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -108,6 +133,8 @@ class TestChatChunk {
role = null,
content = " Texas"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -116,6 +143,8 @@ class TestChatChunk {
role = null,
content = "."
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -124,6 +153,24 @@ class TestChatChunk {
role = null,
content = null
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = FinishReason(value = "stop")
),
ChatChunk(
index = 0,
delta = null,
contentFilterOffsets = ContentFilterOffsets(
checkOffset = 1,
startOffset = 1,
endOffset = 1,
),
contentFilterResults = ContentFilterResults(
hate = ContentFilterResult(
filtered = false,
severity = "high",
)
),
finishReason = FinishReason(value = "stop")
)
)
Expand All @@ -132,6 +179,21 @@ class TestChatChunk {
role = ChatRole.Assistant,
content = "The World Series in 2020 is being held in Texas.",
name = null,
contentFilterResults = listOf(
ContentFilterResults(
hate = ContentFilterResult(
filtered = false,
severity = "high",
)
)
),
contentFilterOffsets = listOf(
ContentFilterOffsets(
checkOffset = 1,
startOffset = 1,
endOffset = 1,
)
),
)
assertEquals(chatMessage, message)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.aallam.openai.client

import com.aallam.openai.api.chat.ChatCompletionChunk
import com.aallam.openai.api.file.FileSource
import com.aallam.openai.client.internal.JsonLenient
import com.aallam.openai.client.internal.TestFileSystem
import com.aallam.openai.client.internal.testFilePath
import kotlin.test.Test
import okio.buffer

class TestChatCompletionChunk {
@Test
fun testContentFilterDeserialization() {
val json = FileSource(path = testFilePath("json/azureContentFilterChunk.json"), fileSystem = TestFileSystem)
val actualJson = json.source.buffer().readByteArray().decodeToString()
JsonLenient.decodeFromString<ChatCompletionChunk>(actualJson)
}

@Test
fun testDeserialization() {
val json = FileSource(path = testFilePath("json/chatChunk.json"), fileSystem = TestFileSystem)
val actualJson = json.source.buffer().readByteArray().decodeToString()
JsonLenient.decodeFromString<ChatCompletionChunk>(actualJson)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"choices": [
{
"content_filter_offsets": {
"check_offset": 33188,
"start_offset": 33188,
"end_offset": 33557
},
"content_filter_results": {
"hate": {
"filtered": false,
"severity": "safe"
},
"self_harm": {
"filtered": false,
"severity": "safe"
},
"sexual": {
"filtered": false,
"severity": "safe"
},
"violence": {
"filtered": false,
"severity": "safe"
}
},
"finish_reason": null,
"index": 0
}
],
"created": 0,
"id": "",
"model": "",
"object": ""
}
16 changes: 16 additions & 0 deletions openai-client/src/commonTest/resources/json/chatChunk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"choices": [
{
"delta": {
"content": " engineering"
},
"finish_reason": null,
"index": 0
}
],
"created": 1716855566,
"id": "chatcmpl-9TeqkT3BJs5zXQq12b204deXcY5nj",
"model": "gpt-4o-2024-05-13",
"object": "chat.completion.chunk",
"system_fingerprint": "fp_5f4bad809a"
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.aallam.openai.api.chat;

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.core.FinishReason
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand All @@ -19,7 +18,17 @@ public data class ChatChunk(
/**
* The generated chat message.
*/
@SerialName("delta") public val delta: ChatDelta,
@SerialName("delta") public val delta: ChatDelta? = null,
Copy link
Contributor Author

@rasharab rasharab May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delta must be nullable for content filtering.


/**
* Azure content filter offsets
*/
@SerialName("content_filter_offsets") public val contentFilterOffsets: ContentFilterOffsets? = null,

/**
* Azure content filter results
*/
@SerialName("content_filter_results") public val contentFilterResults: ContentFilterResults? = null,

/**
* The reason why OpenAI stopped generating.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ public data class ChatMessage(
* Tool call ID.
*/
@SerialName("tool_call_id") public val toolCallId: ToolId? = null,

/**
* Azure Content Filter Results
*/
@SerialName("content_filter_results") public val contentFilterResults: List<ContentFilterResults>? = null,

/**
* Azure Content Filter Offsets
*/
@SerialName("content_filter_offsets") public val contentFilterOffsets: List<ContentFilterOffsets>? = null,
) {

public constructor(
Expand All @@ -54,13 +64,17 @@ public data class ChatMessage(
functionCall: FunctionCall? = null,
toolCalls: List<ToolCall>? = null,
toolCallId: ToolId? = null,
contentFilterResults: List<ContentFilterResults>? = null,
contentFilterOffsets: List<ContentFilterOffsets>? = null,
) : this(
role = role,
messageContent = content?.let { TextContent(it) },
name = name,
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)

public constructor(
Expand All @@ -70,13 +84,17 @@ public data class ChatMessage(
functionCall: FunctionCall? = null,
toolCalls: List<ToolCall>? = null,
toolCallId: ToolId? = null,
contentFilterResults: List<ContentFilterResults>? = null,
contentFilterOffsets: List<ContentFilterOffsets>? = null,
) : this(
role = role,
messageContent = content?.let { ListContent(it) },
name = name,
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)

val content: String?
Expand Down Expand Up @@ -282,6 +300,16 @@ public class ChatMessageBuilder {
*/
public var toolCalls: List<ToolCall>? = null

/**
* Azure content filter offsets
*/
public var contentFilterOffsets: List<ContentFilterOffsets>? = null

/**
* Azure content filter results
*/
public var contentFilterResults: List<ContentFilterResults>? = null

/**
* Tool call ID.
*/
Expand Down Expand Up @@ -313,6 +341,8 @@ public class ChatMessageBuilder {
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)
}
}
Expand Down
Loading