From 820ba02133a5bb005c7a2980440a0ead1f359a5d Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Paz Date: Wed, 18 Sep 2024 12:46:45 -0400 Subject: [PATCH] Add support for allowed_function_names A note about `@SerialName("allowed_function_names")`: This is not strictly necessary, the backend will parse it if it's camel case too. We should eventually remove all unnecessary `@SerialName` declarations, but for now, and to keep consistency, I'm adding it to this declaration. --- .../google/firebase/vertexai/common/client/Types.kt | 5 ++++- .../firebase/vertexai/internal/util/conversions.kt | 3 ++- .../firebase/vertexai/type/FunctionCallingConfig.kt | 5 ++++- .../com/google/firebase/vertexai/type/ToolConfig.kt | 9 +++++++-- .../firebase/vertexai/common/APIControllerTests.kt | 12 ++++++++---- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt index b16c662d360..054d221a0bb 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt @@ -47,7 +47,10 @@ internal data class ToolConfig( ) @Serializable -internal data class FunctionCallingConfig(val mode: Mode) { +internal data class FunctionCallingConfig( + val mode: Mode, + @SerialName("allowed_function_names") val allowedFunctionNames: List? = null +) { @Serializable enum class Mode { @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index 0a67271b3a1..5a0cb4cadbe 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -137,7 +137,8 @@ internal fun ToolConfig.toInternal() = com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.AUTO FunctionCallingConfig.Mode.NONE -> com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.NONE - } + }, + functionCallingConfig.allowedFunctionNames ) ) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index 14b32e293c4..8517164dce2 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -21,8 +21,11 @@ package com.google.firebase.vertexai.type * calling predictions or disable them. * * @param mode The function calling mode of the model + * @param allowedFunctionNames Function names to call. Only set when the [Mode.ANY]. Function names + * should match [FunctionDeclaration.name]. With [Mode.ANY], model will predict a function call from + * the set of function names provided. */ -class FunctionCallingConfig(val mode: Mode) { +class FunctionCallingConfig(val mode: Mode, val allowedFunctionNames: List? = null) { /** Configuration for dictating when the model should call the attached function. */ enum class Mode { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index 9575cb18844..609d930e3e6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -27,7 +27,12 @@ class ToolConfig(val functionCallingConfig: FunctionCallingConfig) { companion object { /** Shorthand to construct a ToolConfig that restricts the model from calling any functions */ fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE)) - /** Shorthand to construct a ToolConfig that restricts the model to always call some function */ - fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY)) + /** + * Shorthand to construct a ToolConfig that restricts the model to always call some function. + * You can optionally [allowedFunctionNames] to restrict the model to only call these functions. + * See [FunctionCallingConfig] for more information. + */ + fun always(allowedFunctionNames: List? = null): ToolConfig = + ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY, allowedFunctionNames)) } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index 9ba3a429157..b683c1ba742 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -185,10 +185,12 @@ internal class RequestFormatTests { contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), toolConfig = ToolConfig( - functionCallingConfig = - FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO) - ), - ) + FunctionCallingConfig( + mode = FunctionCallingConfig.Mode.ANY, + allowedFunctionNames = listOf("allowedFunctionName") + ) + ) + ), ) .collect { channel.close() } } @@ -196,6 +198,8 @@ internal class RequestFormatTests { val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode" + requestBodyAsText shouldContainJsonKey + "tool_config.function_calling_config.allowed_function_names" } @Test