From c1ecb0f7a09f23dd0e0ab1e9faa7e8bc643e0552 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 6 May 2025 17:30:56 -0400 Subject: [PATCH 1/2] [Firebase AI] Add support for Gemma models with Developer API --- FirebaseAI/Sources/FirebaseAI.swift | 3 +- FirebaseAI/Sources/GenerativeModel.swift | 3 + FirebaseAI/Sources/ModelContent.swift | 6 ++ .../Tests/TestApp/Sources/Constants.swift | 1 + .../GenerateContentIntegrationTests.swift | 92 +++++++++++++------ 5 files changed, 74 insertions(+), 31 deletions(-) diff --git a/FirebaseAI/Sources/FirebaseAI.swift b/FirebaseAI/Sources/FirebaseAI.swift index b7e3ad2e893..48f7183d4e6 100644 --- a/FirebaseAI/Sources/FirebaseAI.swift +++ b/FirebaseAI/Sources/FirebaseAI.swift @@ -72,7 +72,8 @@ public final class FirebaseAI: Sendable { systemInstruction: ModelContent? = nil, requestOptions: RequestOptions = RequestOptions()) -> GenerativeModel { - if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix) { + if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix) + && !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix) { AILog.warning(code: .unsupportedGeminiModel, """ Unsupported Gemini model "\(modelName)"; see \ https://firebase.google.com/docs/vertex-ai/models for a list supported Gemini model names. diff --git a/FirebaseAI/Sources/GenerativeModel.swift b/FirebaseAI/Sources/GenerativeModel.swift index defe01c4665..8d3f5e043a7 100644 --- a/FirebaseAI/Sources/GenerativeModel.swift +++ b/FirebaseAI/Sources/GenerativeModel.swift @@ -23,6 +23,9 @@ public final class GenerativeModel: Sendable { /// Model name prefix to identify Gemini models. static let geminiModelNamePrefix = "gemini-" + /// Model name prefix to identify Gemma models. + static let gemmaModelNamePrefix = "gemma-" + /// The name of the model, for example "gemini-2.0-flash". 
let modelName: String diff --git a/FirebaseAI/Sources/ModelContent.swift b/FirebaseAI/Sources/ModelContent.swift index ba87736e648..7d82bd76445 100644 --- a/FirebaseAI/Sources/ModelContent.swift +++ b/FirebaseAI/Sources/ModelContent.swift @@ -112,6 +112,12 @@ extension ModelContent: Codable { case role case internalParts = "parts" } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + role = try container.decodeIfPresent(String.self, forKey: .role) + internalParts = try container.decodeIfPresent([InternalPart].self, forKey: .internalParts) ?? [] + } } @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) diff --git a/FirebaseAI/Tests/TestApp/Sources/Constants.swift b/FirebaseAI/Tests/TestApp/Sources/Constants.swift index 3a731813704..ff1e9bb0250 100644 --- a/FirebaseAI/Tests/TestApp/Sources/Constants.swift +++ b/FirebaseAI/Tests/TestApp/Sources/Constants.swift @@ -24,4 +24,5 @@ public enum ModelNames { public static let gemini2Flash = "gemini-2.0-flash-001" public static let gemini2FlashLite = "gemini-2.0-flash-lite-001" public static let gemini2FlashExperimental = "gemini-2.0-flash-exp" + public static let gemma3_27B = "gemma-3-27b-it" } diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift index 0a9c9898291..ecb443b503e 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift @@ -47,12 +47,24 @@ struct GenerateContentIntegrationTests { storage = Storage.storage() } - @Test(arguments: InstanceConfig.allConfigs) - func generateContent(_ config: InstanceConfig) async throws { + @Test(arguments: [ + (InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite), + (InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite), + 
(InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite), + (InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B), + (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B), + (InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B), + ]) + func generateContent(_ config: InstanceConfig, modelName: String) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( - modelName: ModelNames.gemini2FlashLite, + modelName: modelName, generationConfig: generationConfig, - safetySettings: safetySettings + safetySettings: safetySettings ) let prompt = "Where is Google headquarters located? Answer with the city name only." 
@@ -62,17 +74,22 @@ struct GenerateContentIntegrationTests { #expect(text == "Mountain View") let usageMetadata = try #require(response.usageMetadata) - #expect(usageMetadata.promptTokenCount == 13) + #expect(usageMetadata.promptTokenCount.isEqual(to: 13, accuracy: tokenCountAccuracy)) #expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy)) #expect(usageMetadata.totalTokenCount.isEqual(to: 16, accuracy: tokenCountAccuracy)) #expect(usageMetadata.promptTokensDetails.count == 1) let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first) #expect(promptTokensDetails.modality == .text) #expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount) - #expect(usageMetadata.candidatesTokensDetails.count == 1) - let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first) - #expect(candidatesTokensDetails.modality == .text) - #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount) + // The field `candidatesTokensDetails` is not included when using Gemma models. + if modelName == ModelNames.gemma3_27B { + #expect(usageMetadata.candidatesTokensDetails.isEmpty) + } else { + #expect(usageMetadata.candidatesTokensDetails.count == 1) + let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first) + #expect(candidatesTokensDetails.modality == .text) + #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount) + } } @Test( @@ -168,24 +185,35 @@ struct GenerateContentIntegrationTests { // MARK: Streaming Tests - @Test(arguments: InstanceConfig.allConfigs) - func generateContentStream(_ config: InstanceConfig) async throws { - let expectedText = """ - 1. Mercury - 2. Venus - 3. Earth - 4. Mars - 5. Jupiter - 6. Saturn - 7. Uranus - 8. 
Neptune - """ + @Test(arguments: [ + (InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite), + (InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite), + (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite), + (InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B), + (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B), + (InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite), + (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B), + ]) + func generateContentStream(_ config: InstanceConfig, modelName: String) async throws { + let expectedResponse = [ + "Mercury", "Venus", "Earth", "Mars", "Jupiter", "Saturn", "Uranus", "Neptune", + ] let prompt = """ - What are the names of the planets in the solar system, ordered from closest to furthest from - the sun? Answer with a Markdown numbered list of the names and no other text. + Generate a JSON array of strings. The array must contain the names of the planets in Earth's \ + solar system, ordered from closest to furthest from the Sun. + + Constraints: + - Output MUST be only the JSON array. + - Do NOT include any introductory or explanatory text. + - Do NOT wrap the JSON in Markdown code blocks (e.g., ```json ... ``` or ``` ... ```). + - The response must start with '[' and end with ']'. 
""" let model = FirebaseAI.componentInstance(config).generativeModel( - modelName: ModelNames.gemini2FlashLite, + modelName: modelName, generationConfig: generationConfig, safetySettings: safetySettings ) @@ -194,7 +222,13 @@ struct GenerateContentIntegrationTests { let stream = try chat.sendMessageStream(prompt) var textValues = [String]() for try await value in stream { - try textValues.append(#require(value.text)) + if let text = value.text { + textValues.append(text) + } else if let finishReason = value.candidates.first?.finishReason { + #expect(finishReason == .stop) + } else { + Issue.record("Expected a candidate with a `TextPart` or a `finishReason`; got \(value).") + } } let userHistory = try #require(chat.history.first) @@ -206,11 +240,9 @@ struct GenerateContentIntegrationTests { #expect(modelHistory.role == "model") #expect(modelHistory.parts.count == 1) let modelTextPart = try #require(modelHistory.parts.first as? TextPart) - let modelText = modelTextPart.text.trimmingCharacters(in: .whitespacesAndNewlines) - #expect(modelText == expectedText) - #expect(textValues.count > 1) - let text = textValues.joined().trimmingCharacters(in: .whitespacesAndNewlines) - #expect(text == expectedText) + let modelJSONData = try #require(modelTextPart.text.data(using: .utf8)) + let response = try JSONDecoder().decode([String].self, from: modelJSONData) + #expect(response == expectedResponse) } // MARK: - App Check Tests From 01fc72b8151424218eceb39fe7e872cf55924c58 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 6 May 2025 18:41:52 -0400 Subject: [PATCH 2/2] Fix `Candidate` decoding error handling --- .../Sources/GenerateContentResponse.swift | 17 ++++++++++------- .../Unit/GenerativeModelVertexAITests.swift | 18 ++++++++++++------ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/FirebaseAI/Sources/GenerateContentResponse.swift b/FirebaseAI/Sources/GenerateContentResponse.swift index e654389ce82..6d4ba6932ec 100644 --- 
a/FirebaseAI/Sources/GenerateContentResponse.swift +++ b/FirebaseAI/Sources/GenerateContentResponse.swift @@ -371,13 +371,7 @@ extension Candidate: Decodable { content = ModelContent(parts: []) } } catch { - // Check if `content` can be decoded as an empty dictionary to detect the `"content": {}` bug. - if let content = try? container.decode([String: String].self, forKey: .content), - content.isEmpty { - throw InvalidCandidateError.emptyContent(underlyingError: error) - } else { - throw InvalidCandidateError.malformedContent(underlyingError: error) - } + throw InvalidCandidateError.malformedContent(underlyingError: error) } if let safetyRatings = try container.decodeIfPresent( @@ -395,6 +389,15 @@ extension Candidate: Decodable { finishReason = try container.decodeIfPresent(FinishReason.self, forKey: .finishReason) + // The `content` may only be empty if a `finishReason` is included; if neither are included in + // the response then this is likely the `"content": {}` bug. + guard !content.parts.isEmpty || finishReason != nil else { + throw InvalidCandidateError.emptyContent(underlyingError: DecodingError.dataCorrupted(.init( + codingPath: [CodingKeys.content, CodingKeys.finishReason], + debugDescription: "Invalid Candidate: empty content and no finish reason" + ))) + } + citationMetadata = try container.decodeIfPresent( CitationMetadata.self, forKey: .citationMetadata diff --git a/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift b/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift index 930f4efd987..f1092a4c4f6 100644 --- a/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift +++ b/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift @@ -918,6 +918,9 @@ final class GenerativeModelVertexAITests: XCTestCase { func testGenerateContent_failure_malformedContent() async throws { MockURLProtocol .requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + // Note: Although this file does not contain `parts` in `content`, it is not actually + 
// malformed. The `invalid-field` in the payload could be added, as a non-breaking change to + // the proto API. Therefore, this test checks for the `emptyContent` error instead. forResource: "unary-failure-malformed-content", withExtension: "json", subdirectory: vertexSubdirectory @@ -939,13 +942,13 @@ final class GenerativeModelVertexAITests: XCTestCase { return } let invalidCandidateError = try XCTUnwrap(underlyingError as? InvalidCandidateError) - guard case let .malformedContent(malformedContentUnderlyingError) = invalidCandidateError else { - XCTFail("Not a malformed content error: \(invalidCandidateError)") + guard case let .emptyContent(emptyContentUnderlyingError) = invalidCandidateError else { + XCTFail("Not an empty content error: \(invalidCandidateError)") return } _ = try XCTUnwrap( - malformedContentUnderlyingError as? DecodingError, - "Not a decoding error: \(malformedContentUnderlyingError)" + emptyContentUnderlyingError as? DecodingError, + "Not a decoding error: \(emptyContentUnderlyingError)" ) } @@ -1446,6 +1449,9 @@ final class GenerativeModelVertexAITests: XCTestCase { func testGenerateContentStream_malformedContent() async throws { MockURLProtocol .requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + // Note: Although this file does not contain `parts` in `content`, it is not actually + // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to + // the proto API. Therefore, this test checks for the `emptyContent` error instead. 
forResource: "streaming-failure-malformed-content", withExtension: "txt", subdirectory: vertexSubdirectory @@ -1457,8 +1463,8 @@ final class GenerativeModelVertexAITests: XCTestCase { XCTFail("Unexpected content in stream: \(content)") } } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) { - guard case let .malformedContent(contentError) = underlyingError else { - XCTFail("Not a malformed content error: \(underlyingError)") + guard case let .emptyContent(contentError) = underlyingError else { + XCTFail("Not an empty content error: \(underlyingError)") return }