diff --git a/FirebaseAI/Sources/Chat.swift b/FirebaseAI/Sources/Chat.swift index 1aa2c3490c7..01d0b4a0056 100644 --- a/FirebaseAI/Sources/Chat.swift +++ b/FirebaseAI/Sources/Chat.swift @@ -146,11 +146,13 @@ public final class Chat: Sendable { for aggregate in chunks { // Loop through all the parts, aggregating the text and adding the images. for part in aggregate.parts { - switch part { - case let textPart as TextPart: - combinedText += textPart.text + guard !part.isThought else { + continue + } - default: + if let textPart = part as? TextPart { + combinedText += textPart.text + } else { // Don't combine it, just add to the content. If there's any text pending, add that as // a part. if !combinedText.isEmpty { diff --git a/FirebaseAI/Sources/GenerateContentResponse.swift b/FirebaseAI/Sources/GenerateContentResponse.swift index b0e348d2192..19e4d797419 100644 --- a/FirebaseAI/Sources/GenerateContentResponse.swift +++ b/FirebaseAI/Sources/GenerateContentResponse.swift @@ -66,12 +66,10 @@ public struct GenerateContentResponse: Sendable { return nil } let textValues: [String] = candidate.content.parts.compactMap { part in - switch part { - case let textPart as TextPart: - return textPart.text - default: + guard let textPart = part as? TextPart, !part.isThought else { return nil } + return textPart.text } guard textValues.count > 0 else { AILog.error( @@ -89,12 +87,10 @@ public struct GenerateContentResponse: Sendable { return [] } return candidate.content.parts.compactMap { part in - switch part { - case let functionCallPart as FunctionCallPart: - return functionCallPart - default: + guard let functionCallPart = part as? FunctionCallPart, !part.isThought else { return nil } + return functionCallPart } } @@ -107,7 +103,12 @@ public struct GenerateContentResponse: Sendable { """) return [] } - return candidate.content.parts.compactMap { $0 as? InlineDataPart } + return candidate.content.parts.compactMap { part in + guard let inlineDataPart = part as? 
InlineDataPart, !part.isThought else { + return nil + } + return inlineDataPart + } } /// Initializer for SwiftUI previews or tests. diff --git a/FirebaseAI/Sources/ModelContent.swift b/FirebaseAI/Sources/ModelContent.swift index 7d82bd76445..7fdd7d428c0 100644 --- a/FirebaseAI/Sources/ModelContent.swift +++ b/FirebaseAI/Sources/ModelContent.swift @@ -31,19 +31,31 @@ extension [ModelContent] { } } -/// A type describing data in media formats interpretable by an AI model. Each generative AI -/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value -/// may comprise multiple heterogeneous ``Part``s. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public struct ModelContent: Equatable, Sendable { - enum InternalPart: Equatable, Sendable { +struct InternalPart: Equatable, Sendable { + enum OneOfData: Equatable, Sendable { case text(String) - case inlineData(mimetype: String, Data) - case fileData(mimetype: String, uri: String) + case inlineData(InlineData) + case fileData(FileData) case functionCall(FunctionCall) case functionResponse(FunctionResponse) } + let data: OneOfData + + let isThought: Bool? + + init(_ data: OneOfData, isThought: Bool?) { + self.data = data + self.isThought = isThought + } +} + +/// A type describing data in media formats interpretable by an AI model. Each generative AI +/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value +/// may comprise multiple heterogeneous ``Part``s. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ModelContent: Equatable, Sendable { /// The role of the entity creating the ``ModelContent``. For user-generated client requests, /// for example, the role is `user`. public let role: String? 
@@ -52,17 +64,17 @@ public struct ModelContent: Equatable, Sendable { public var parts: [any Part] { var convertedParts = [any Part]() for part in internalParts { - switch part { + switch part.data { case let .text(text): - convertedParts.append(TextPart(text)) - case let .inlineData(mimetype, data): - convertedParts.append(InlineDataPart(data: data, mimeType: mimetype)) - case let .fileData(mimetype, uri): - convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype)) + convertedParts.append(TextPart(text, isThought: part.isThought)) + case let .inlineData(inlineData): + convertedParts.append(InlineDataPart(inlineData, isThought: part.isThought)) + case let .fileData(fileData): + convertedParts.append(FileDataPart(fileData, isThought: part.isThought)) case let .functionCall(functionCall): - convertedParts.append(FunctionCallPart(functionCall)) + convertedParts.append(FunctionCallPart(functionCall, isThought: part.isThought)) case let .functionResponse(functionResponse): - convertedParts.append(FunctionResponsePart(functionResponse)) + convertedParts.append(FunctionResponsePart(functionResponse, isThought: part.isThought)) } } return convertedParts @@ -78,17 +90,28 @@ public struct ModelContent: Equatable, Sendable { for part in parts { switch part { case let textPart as TextPart: - convertedParts.append(.text(textPart.text)) + convertedParts.append(InternalPart(.text(textPart.text), isThought: textPart._isThought)) case let inlineDataPart as InlineDataPart: - let inlineData = inlineDataPart.inlineData - convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data)) + convertedParts.append( + InternalPart(.inlineData(inlineDataPart.inlineData), isThought: inlineDataPart._isThought) + ) case let fileDataPart as FileDataPart: - let fileData = fileDataPart.fileData - convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)) + convertedParts.append( + InternalPart(.fileData(fileDataPart.fileData), isThought: 
fileDataPart._isThought) + ) case let functionCallPart as FunctionCallPart: - convertedParts.append(.functionCall(functionCallPart.functionCall)) + convertedParts.append( + InternalPart( + .functionCall(functionCallPart.functionCall), isThought: functionCallPart._isThought + ) + ) case let functionResponsePart as FunctionResponsePart: - convertedParts.append(.functionResponse(functionResponsePart.functionResponse)) + convertedParts.append( + InternalPart( + .functionResponse(functionResponsePart.functionResponse), + isThought: functionResponsePart._isThought + ) + ) default: fatalError() } @@ -121,7 +144,26 @@ extension ModelContent: Codable { } @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ModelContent.InternalPart: Codable { +extension InternalPart: Codable { + enum CodingKeys: String, CodingKey { + case isThought = "thought" + } + + public func encode(to encoder: Encoder) throws { + try data.encode(to: encoder) + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encodeIfPresent(isThought, forKey: .isThought) + } + + public init(from decoder: Decoder) throws { + data = try OneOfData(from: decoder) + let container = try decoder.container(keyedBy: CodingKeys.self) + isThought = try container.decodeIfPresent(Bool.self, forKey: .isThought) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension InternalPart.OneOfData: Codable { enum CodingKeys: String, CodingKey { case text case inlineData @@ -135,10 +177,10 @@ extension ModelContent.InternalPart: Codable { switch self { case let .text(text): try container.encode(text, forKey: .text) - case let .inlineData(mimetype, bytes): - try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData) - case let .fileData(mimetype: mimetype, url): - try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData) + case let .inlineData(inlineData): + try 
container.encode(inlineData, forKey: .inlineData) + case let .fileData(fileData): + try container.encode(fileData, forKey: .fileData) case let .functionCall(functionCall): try container.encode(functionCall, forKey: .functionCall) case let .functionResponse(functionResponse): @@ -151,11 +193,9 @@ extension ModelContent.InternalPart: Codable { if values.contains(.text) { self = try .text(values.decode(String.self, forKey: .text)) } else if values.contains(.inlineData) { - let inlineData = try values.decode(InlineData.self, forKey: .inlineData) - self = .inlineData(mimetype: inlineData.mimeType, inlineData.data) + self = try .inlineData(values.decode(InlineData.self, forKey: .inlineData)) } else if values.contains(.fileData) { - let fileData = try values.decode(FileData.self, forKey: .fileData) - self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI) + self = try .fileData(values.decode(FileData.self, forKey: .fileData)) } else if values.contains(.functionCall) { self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall)) } else if values.contains(.functionResponse) { diff --git a/FirebaseAI/Sources/Types/Internal/InternalPart.swift b/FirebaseAI/Sources/Types/Internal/InternalPart.swift index d543fb80f38..062ae8aa93d 100644 --- a/FirebaseAI/Sources/Types/Internal/InternalPart.swift +++ b/FirebaseAI/Sources/Types/Internal/InternalPart.swift @@ -67,6 +67,8 @@ struct FunctionResponse: Codable, Equatable, Sendable { struct ErrorPart: Part, Error { let error: Error + let isThought = false + init(_ error: Error) { self.error = error } diff --git a/FirebaseAI/Sources/Types/Public/Part.swift b/FirebaseAI/Sources/Types/Public/Part.swift index 4890b725f4d..4fea490a7bf 100644 --- a/FirebaseAI/Sources/Types/Public/Part.swift +++ b/FirebaseAI/Sources/Types/Public/Part.swift @@ -18,7 +18,9 @@ import Foundation /// /// Within a single value of ``Part``, different data types may not mix. 
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public protocol Part: PartsRepresentable, Codable, Sendable, Equatable {} +public protocol Part: PartsRepresentable, Codable, Sendable, Equatable { + var isThought: Bool { get } +} /// A text part containing a string value. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) @@ -26,8 +28,17 @@ public struct TextPart: Part { /// Text value. public let text: String + public var isThought: Bool { _isThought ?? false } + + let _isThought: Bool? + public init(_ text: String) { + self.init(text, isThought: nil) + } + + init(_ text: String, isThought: Bool?) { self.text = text + _isThought = isThought } } @@ -45,6 +56,7 @@ public struct TextPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct InlineDataPart: Part { let inlineData: InlineData + let _isThought: Bool? /// The data provided in the inline data part. public var data: Data { inlineData.data } @@ -52,6 +64,8 @@ public struct InlineDataPart: Part { /// The IANA standard MIME type of the data. public var mimeType: String { inlineData.mimeType } + public var isThought: Bool { _isThought ?? false } + /// Creates an inline data part from data and a MIME type. /// /// > Important: Supported input types depend on the model on the model being used; see [input @@ -67,11 +81,12 @@ public struct InlineDataPart: Part { /// requirements](https://firebase.google.com/docs/vertex-ai/input-file-requirements) for /// supported values. public init(data: Data, mimeType: String) { - self.init(InlineData(data: data, mimeType: mimeType)) + self.init(InlineData(data: data, mimeType: mimeType), isThought: nil) } - init(_ inlineData: InlineData) { + init(_ inlineData: InlineData, isThought: Bool?) 
{ self.inlineData = inlineData + _isThought = isThought } } @@ -79,9 +94,11 @@ public struct InlineDataPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FileDataPart: Part { let fileData: FileData + let _isThought: Bool? public var uri: String { fileData.fileURI } public var mimeType: String { fileData.mimeType } + public var isThought: Bool { _isThought ?? false } /// Constructs a new file data part. /// @@ -93,11 +110,12 @@ public struct FileDataPart: Part { /// requirements](https://firebase.google.com/docs/vertex-ai/input-file-requirements) for /// supported values. public init(uri: String, mimeType: String) { - self.init(FileData(fileURI: uri, mimeType: mimeType)) + self.init(FileData(fileURI: uri, mimeType: mimeType), isThought: nil) } - init(_ fileData: FileData) { + init(_ fileData: FileData, isThought: Bool?) { self.fileData = fileData + _isThought = isThought } } @@ -105,6 +123,7 @@ public struct FileDataPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FunctionCallPart: Part { let functionCall: FunctionCall + let _isThought: Bool? /// The name of the function to call. public var name: String { functionCall.name } @@ -112,6 +131,8 @@ public struct FunctionCallPart: Part { /// The function parameters and values. public var args: JSONObject { functionCall.args } + public var isThought: Bool { _isThought ?? false } + /// Constructs a new function call part. /// /// > Note: A `FunctionCallPart` is typically received from the model, rather than created @@ -121,11 +142,12 @@ public struct FunctionCallPart: Part { /// - name: The name of the function to call. /// - args: The function parameters and values. 
public init(name: String, args: JSONObject) { - self.init(FunctionCall(name: name, args: args)) + self.init(FunctionCall(name: name, args: args), isThought: nil) } - init(_ functionCall: FunctionCall) { + init(_ functionCall: FunctionCall, isThought: Bool?) { self.functionCall = functionCall + _isThought = isThought } } @@ -137,6 +159,7 @@ public struct FunctionCallPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FunctionResponsePart: Part { let functionResponse: FunctionResponse + let _isThought: Bool? /// The name of the function that was called. public var name: String { functionResponse.name } @@ -144,16 +167,19 @@ public struct FunctionResponsePart: Part { /// The function's response or return value. public var response: JSONObject { functionResponse.response } + public var isThought: Bool { _isThought ?? false } + /// Constructs a new `FunctionResponse`. /// /// - Parameters: /// - name: The name of the function that was called. /// - response: The function's response. public init(name: String, response: JSONObject) { - self.init(FunctionResponse(name: name, response: response)) + self.init(FunctionResponse(name: name, response: response), isThought: nil) } - init(_ functionResponse: FunctionResponse) { + init(_ functionResponse: FunctionResponse, isThought: Bool?) { self.functionResponse = functionResponse + _isThought = isThought } } diff --git a/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift b/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift index c0e8f31465b..c9e2cc43f4e 100644 --- a/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift +++ b/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift @@ -37,12 +37,15 @@ public struct ThinkingConfig: Sendable { /// feature or if the specified budget is not within the model's supported range. let thinkingBudget: Int? + let includeThoughts: Bool? + /// Initializes a new `ThinkingConfig`. 
/// /// - Parameters: /// - thinkingBudget: The maximum number of tokens to be used for the model's thinking process. - public init(thinkingBudget: Int? = nil) { + public init(thinkingBudget: Int? = nil, includeThoughts: Bool? = nil) { self.thinkingBudget = thinkingBudget + self.includeThoughts = includeThoughts } } diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift index da77cea9df7..0092efe33ce 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift @@ -134,47 +134,83 @@ struct GenerateContentIntegrationTests { #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount) } - @Test(arguments: [ - (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2_5_Flash, 0), - (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2_5_Flash, 24576), - (InstanceConfig.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, 128), - (InstanceConfig.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, 32768), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Flash, 0), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Flash, 24576), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Pro, 128), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Pro, 32768), - (InstanceConfig.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, 0), - (InstanceConfig.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, 24576), - ]) + @Test( + arguments: [ + (.vertexAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 0)), + (.vertexAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 24576)), + (.vertexAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 
128)), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 32768)), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig( + thinkingBudget: 32768, includeThoughts: true + )), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 0)), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 24576)), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 128)), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 32768)), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig( + thinkingBudget: 32768, includeThoughts: true + )), + (.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 0)), + ( + .googleAI_v1beta_freeTier, + ModelNames.gemini2_5_Flash, + ThinkingConfig(thinkingBudget: 24576) + ), + (.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + (.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 0 + )), + (.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576 + )), + (.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + ] as [(InstanceConfig, String, ThinkingConfig)] + ) func generateContentThinking(_ config: InstanceConfig, modelName: String, - thinkingBudget: Int) async throws { + thinkingConfig: ThinkingConfig) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( modelName: modelName, generationConfig: GenerationConfig( temperature: 0.0, topP: 0.0, topK: 1, - thinkingConfig: ThinkingConfig(thinkingBudget: thinkingBudget) + thinkingConfig: thinkingConfig ), safetySettings: 
safetySettings ) + let chat = model.startChat() let prompt = "Where is Google headquarters located? Answer with the city name only." - let response = try await model.generateContent(prompt) + let response = try await chat.sendMessage(prompt) let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines) #expect(text == "Mountain View") + let candidate = try #require(response.candidates.first) + let thoughtParts = candidate.content.parts.filter { $0.isThought } + #expect(thoughtParts.isEmpty != (thinkingConfig.includeThoughts ?? false)) + let usageMetadata = try #require(response.usageMetadata) #expect(usageMetadata.promptTokenCount.isEqual(to: 13, accuracy: tokenCountAccuracy)) #expect(usageMetadata.promptTokensDetails.count == 1) let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first) #expect(promptTokensDetails.modality == .text) #expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount) - if thinkingBudget == 0 { - #expect(usageMetadata.thoughtsTokenCount == 0) - } else { + if let thinkingBudget = thinkingConfig.thinkingBudget, thinkingBudget > 0 { + #expect(usageMetadata.thoughtsTokenCount > 0) #expect(usageMetadata.thoughtsTokenCount <= thinkingBudget) + } else { + #expect(usageMetadata.thoughtsTokenCount == 0) } #expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy)) // The `candidatesTokensDetails` field is erroneously omitted when using the Google AI (Gemini