
Commit 0be8cec

[Firebase AI] Add support for Gemma models with Developer API (#14823)
1 parent 4e7e75f commit 0be8cec

7 files changed: +96 additions, -44 deletions

FirebaseAI/Sources/FirebaseAI.swift

Lines changed: 2 additions & 1 deletion
@@ -72,7 +72,8 @@ public final class FirebaseAI: Sendable {
                               systemInstruction: ModelContent? = nil,
                               requestOptions: RequestOptions = RequestOptions())
     -> GenerativeModel {
-    if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix) {
+    if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix)
+      && !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix) {
       AILog.warning(code: .unsupportedGeminiModel, """
       Unsupported Gemini model "\(modelName)"; see \
       https://firebase.google.com/docs/vertex-ai/models for a list supported Gemini model names.
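
Note: with this change, model names beginning with "gemma-" are treated like "gemini-" names and no longer trigger the unsupported-model warning. A minimal usage sketch follows; it assumes the public FirebaseAI.firebaseAI(backend:) entry point of the Firebase AI Logic SDK and a configured Firebase app, and it reuses the model name from the tests below, so treat it as illustrative rather than the commit's own code.

import FirebaseAI

// Sketch only: assumes FirebaseApp.configure() has already been called.
// Gemma models are served through the Developer API (Google AI backend).
func askGemma() async throws -> String? {
  let ai = FirebaseAI.firebaseAI(backend: .googleAI())
  let model = ai.generativeModel(modelName: "gemma-3-27b-it")
  let response = try await model.generateContent("Where is Google headquarters located?")
  return response.text
}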

FirebaseAI/Sources/GenerateContentResponse.swift

Lines changed: 10 additions & 7 deletions
@@ -371,13 +371,7 @@ extension Candidate: Decodable {
         content = ModelContent(parts: [])
       }
     } catch {
-      // Check if `content` can be decoded as an empty dictionary to detect the `"content": {}` bug.
-      if let content = try? container.decode([String: String].self, forKey: .content),
-         content.isEmpty {
-        throw InvalidCandidateError.emptyContent(underlyingError: error)
-      } else {
-        throw InvalidCandidateError.malformedContent(underlyingError: error)
-      }
+      throw InvalidCandidateError.malformedContent(underlyingError: error)
     }

     if let safetyRatings = try container.decodeIfPresent(
@@ -395,6 +389,15 @@ extension Candidate: Decodable {

     finishReason = try container.decodeIfPresent(FinishReason.self, forKey: .finishReason)

+    // The `content` may only be empty if a `finishReason` is included; if neither are included in
+    // the response then this is likely the `"content": {}` bug.
+    guard !content.parts.isEmpty || finishReason != nil else {
+      throw InvalidCandidateError.emptyContent(underlyingError: DecodingError.dataCorrupted(.init(
+        codingPath: [CodingKeys.content, CodingKeys.finishReason],
+        debugDescription: "Invalid Candidate: empty content and no finish reason"
+      )))
+    }
+
     citationMetadata = try container.decodeIfPresent(
       CitationMetadata.self,
       forKey: .citationMetadata
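
Note: the rewritten decoder above stops inferring the '"content": {}' bug inside the catch block and instead accepts empty content only when a finishReason accompanies it. The following self-contained sketch mirrors that classification with simplified stand-in types, not the SDK's real ones:

import Foundation

// Illustrative stand-ins only.
struct Part: Decodable { let text: String }
struct Content: Decodable { let parts: [Part]? }

struct Candidate: Decodable {
  let content: Content
  let finishReason: String?

  enum CodingKeys: String, CodingKey { case content, finishReason }

  init(from decoder: Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    content = try container.decodeIfPresent(Content.self, forKey: .content) ?? Content(parts: nil)
    finishReason = try container.decodeIfPresent(String.self, forKey: .finishReason)
    // Mirrors the new guard: empty content is only valid alongside a finish reason.
    guard !(content.parts ?? []).isEmpty || finishReason != nil else {
      throw DecodingError.dataCorrupted(.init(
        codingPath: [CodingKeys.content, CodingKeys.finishReason],
        debugDescription: "empty content and no finish reason"
      ))
    }
  }
}

// `"content": {}` with no finish reason is rejected as empty content...
let buggy = #"{"content": {}}"#.data(using: .utf8)!
print((try? JSONDecoder().decode(Candidate.self, from: buggy)) == nil) // true

// ...but an empty final chunk that does report a finish reason decodes fine.
let finalChunk = #"{"content": {}, "finishReason": "STOP"}"#.data(using: .utf8)!
print((try? JSONDecoder().decode(Candidate.self, from: finalChunk)) != nil) // true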

FirebaseAI/Sources/GenerativeModel.swift

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ public final class GenerativeModel: Sendable {
   /// Model name prefix to identify Gemini models.
   static let geminiModelNamePrefix = "gemini-"

+  /// Model name prefix to identify Gemma models.
+  static let gemmaModelNamePrefix = "gemma-"
+
   /// The name of the model, for example "gemini-2.0-flash".
   let modelName: String

FirebaseAI/Sources/ModelContent.swift

Lines changed: 6 additions & 0 deletions
@@ -112,6 +112,12 @@ extension ModelContent: Codable {
     case role
     case internalParts = "parts"
   }
+
+  public init(from decoder: any Decoder) throws {
+    let container = try decoder.container(keyedBy: CodingKeys.self)
+    role = try container.decodeIfPresent(String.self, forKey: .role)
+    internalParts = try container.decodeIfPresent([InternalPart].self, forKey: .internalParts) ?? []
+  }
 }

 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
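
Note: the new ModelContent initializer above treats both keys as optional on the wire, so a candidate whose content carries no "parts" array (as can happen once the empty-content case is allowed through) decodes to an empty parts list instead of failing. A small sketch of the same decodeIfPresent-with-default pattern on simplified, illustrative types:

import Foundation

// Illustrative stand-in; the real ModelContent wraps richer Part types.
struct Message: Decodable {
  let role: String?
  let parts: [String]

  enum CodingKeys: String, CodingKey { case role, parts }

  init(from decoder: Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    role = try container.decodeIfPresent(String.self, forKey: .role)
    // A missing "parts" key becomes an empty array rather than a DecodingError.
    parts = try container.decodeIfPresent([String].self, forKey: .parts) ?? []
  }
}

let data = #"{"role": "model"}"#.data(using: .utf8)!
let message = try! JSONDecoder().decode(Message.self, from: data) // sketch: force-try for brevity
print(message.role ?? "none", message.parts) // prints: model []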

FirebaseAI/Tests/TestApp/Sources/Constants.swift

Lines changed: 1 addition & 0 deletions
@@ -24,4 +24,5 @@ public enum ModelNames {
   public static let gemini2Flash = "gemini-2.0-flash-001"
   public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
   public static let gemini2FlashExperimental = "gemini-2.0-flash-exp"
+  public static let gemma3_27B = "gemma-3-27b-it"
 }

FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift

Lines changed: 62 additions & 30 deletions
@@ -47,12 +47,24 @@ struct GenerateContentIntegrationTests {
     storage = Storage.storage()
   }

-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContent(_ config: InstanceConfig) async throws {
+  @Test(arguments: [
+    (InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B),
+  ])
+  func generateContent(_ config: InstanceConfig, modelName: String) async throws {
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
+      modelName: modelName,
       generationConfig: generationConfig,
-      safetySettings: safetySettings
+      safetySettings: safetySettings,
     )
     let prompt = "Where is Google headquarters located? Answer with the city name only."

@@ -62,17 +74,22 @@ struct GenerateContentIntegrationTests {
     #expect(text == "Mountain View")

     let usageMetadata = try #require(response.usageMetadata)
-    #expect(usageMetadata.promptTokenCount == 13)
+    #expect(usageMetadata.promptTokenCount.isEqual(to: 13, accuracy: tokenCountAccuracy))
     #expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy))
     #expect(usageMetadata.totalTokenCount.isEqual(to: 16, accuracy: tokenCountAccuracy))
     #expect(usageMetadata.promptTokensDetails.count == 1)
     let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first)
     #expect(promptTokensDetails.modality == .text)
     #expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount)
-    #expect(usageMetadata.candidatesTokensDetails.count == 1)
-    let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
-    #expect(candidatesTokensDetails.modality == .text)
-    #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
+    // The field `candidatesTokensDetails` is not included when using Gemma models.
+    if modelName == ModelNames.gemma3_27B {
+      #expect(usageMetadata.candidatesTokensDetails.isEmpty)
+    } else {
+      #expect(usageMetadata.candidatesTokensDetails.count == 1)
+      let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
+      #expect(candidatesTokensDetails.modality == .text)
+      #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
+    }
   }

   @Test(
@@ -168,24 +185,35 @@ struct GenerateContentIntegrationTests {

   // MARK: Streaming Tests

-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContentStream(_ config: InstanceConfig) async throws {
-    let expectedText = """
-    1. Mercury
-    2. Venus
-    3. Earth
-    4. Mars
-    5. Jupiter
-    6. Saturn
-    7. Uranus
-    8. Neptune
-    """
+  @Test(arguments: [
+    (InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B),
+  ])
+  func generateContentStream(_ config: InstanceConfig, modelName: String) async throws {
+    let expectedResponse = [
+      "Mercury", "Venus", "Earth", "Mars", "Jupiter", "Saturn", "Uranus", "Neptune",
+    ]
     let prompt = """
-    What are the names of the planets in the solar system, ordered from closest to furthest from
-    the sun? Answer with a Markdown numbered list of the names and no other text.
+    Generate a JSON array of strings. The array must contain the names of the planets in Earth's \
+    solar system, ordered from closest to furthest from the Sun.
+
+    Constraints:
+    - Output MUST be only the JSON array.
+    - Do NOT include any introductory or explanatory text.
+    - Do NOT wrap the JSON in Markdown code blocks (e.g., ```json ... ``` or ``` ... ```).
+    - The response must start with '[' and end with ']'.
     """
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
+      modelName: modelName,
       generationConfig: generationConfig,
       safetySettings: safetySettings
     )
@@ -194,7 +222,13 @@ struct GenerateContentIntegrationTests {
     let stream = try chat.sendMessageStream(prompt)
     var textValues = [String]()
     for try await value in stream {
-      try textValues.append(#require(value.text))
+      if let text = value.text {
+        textValues.append(text)
+      } else if let finishReason = value.candidates.first?.finishReason {
+        #expect(finishReason == .stop)
+      } else {
+        Issue.record("Expected a candidate with a `TextPart` or a `finishReason`; got \(value).")
+      }
     }

     let userHistory = try #require(chat.history.first)
@@ -206,11 +240,9 @@ struct GenerateContentIntegrationTests {
     #expect(modelHistory.role == "model")
     #expect(modelHistory.parts.count == 1)
     let modelTextPart = try #require(modelHistory.parts.first as? TextPart)
-    let modelText = modelTextPart.text.trimmingCharacters(in: .whitespacesAndNewlines)
-    #expect(modelText == expectedText)
-    #expect(textValues.count > 1)
-    let text = textValues.joined().trimmingCharacters(in: .whitespacesAndNewlines)
-    #expect(text == expectedText)
+    let modelJSONData = try #require(modelTextPart.text.data(using: .utf8))
+    let response = try JSONDecoder().decode([String].self, from: modelJSONData)
+    #expect(response == expectedResponse)
   }

   // MARK: - App Check Tests

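Note: the streaming test above now tolerates chunks that carry a finishReason but no text, which can occur with Gemma responses. A hedged sketch of consuming a stream defensively in app code, under the same assumptions as the earlier setup sketch (public FirebaseAI.firebaseAI(backend:) entry point, illustrative model name and prompt):

import FirebaseAI

// Sketch only: assumes FirebaseApp.configure() and an async throwing context.
func streamPlanets() async throws -> String {
  let model = FirebaseAI.firebaseAI(backend: .googleAI())
    .generativeModel(modelName: "gemma-3-27b-it")

  var answer = ""
  let stream = try model.generateContentStream("List the planets as a JSON array of names.")
  for try await chunk in stream {
    if let text = chunk.text {
      answer += text // ordinary text chunk
    } else if let reason = chunk.candidates.first?.finishReason {
      // With Gemma, the final chunk may carry only a finish reason and no text.
      print("stream finished: \(reason)")
    }
  }
  return answer
}
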
FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift

Lines changed: 12 additions & 6 deletions
@@ -918,6 +918,9 @@ final class GenerativeModelVertexAITests: XCTestCase {
   func testGenerateContent_failure_malformedContent() async throws {
     MockURLProtocol
       .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+        // Note: Although this file does not contain `parts` in `content`, it is not actually
+        // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
+        // the proto API. Therefore, this test checks for the `emptyContent` error instead.
         forResource: "unary-failure-malformed-content",
         withExtension: "json",
         subdirectory: vertexSubdirectory
@@ -939,13 +942,13 @@ final class GenerativeModelVertexAITests: XCTestCase {
        return
      }
      let invalidCandidateError = try XCTUnwrap(underlyingError as? InvalidCandidateError)
-      guard case let .malformedContent(malformedContentUnderlyingError) = invalidCandidateError else {
-        XCTFail("Not a malformed content error: \(invalidCandidateError)")
+      guard case let .emptyContent(emptyContentUnderlyingError) = invalidCandidateError else {
+        XCTFail("Not an empty content error: \(invalidCandidateError)")
        return
      }
      _ = try XCTUnwrap(
-        malformedContentUnderlyingError as? DecodingError,
-        "Not a decoding error: \(malformedContentUnderlyingError)"
+        emptyContentUnderlyingError as? DecodingError,
+        "Not a decoding error: \(emptyContentUnderlyingError)"
      )
    }

@@ -1446,6 +1449,9 @@ final class GenerativeModelVertexAITests: XCTestCase {
   func testGenerateContentStream_malformedContent() async throws {
     MockURLProtocol
       .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+        // Note: Although this file does not contain `parts` in `content`, it is not actually
+        // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
+        // the proto API. Therefore, this test checks for the `emptyContent` error instead.
         forResource: "streaming-failure-malformed-content",
         withExtension: "txt",
         subdirectory: vertexSubdirectory
@@ -1457,8 +1463,8 @@ final class GenerativeModelVertexAITests: XCTestCase {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
-      guard case let .malformedContent(contentError) = underlyingError else {
-        XCTFail("Not a malformed content error: \(underlyingError)")
+      guard case let .emptyContent(contentError) = underlyingError else {
+        XCTFail("Not an empty content error: \(underlyingError)")
        return
      }
