Skip to content

[Firebase AI] Add support for Gemma models with Developer API #14823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion FirebaseAI/Sources/FirebaseAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ public final class FirebaseAI: Sendable {
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix) {
if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix)
&& !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix) {
AILog.warning(code: .unsupportedGeminiModel, """
Unsupported Gemini model "\(modelName)"; see \
https://firebase.google.com/docs/vertex-ai/models for a list of supported Gemini model names.
Expand Down
17 changes: 10 additions & 7 deletions FirebaseAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,7 @@ extension Candidate: Decodable {
content = ModelContent(parts: [])
}
} catch {
// Check if `content` can be decoded as an empty dictionary to detect the `"content": {}` bug.
if let content = try? container.decode([String: String].self, forKey: .content),
content.isEmpty {
throw InvalidCandidateError.emptyContent(underlyingError: error)
} else {
throw InvalidCandidateError.malformedContent(underlyingError: error)
}
throw InvalidCandidateError.malformedContent(underlyingError: error)
}

if let safetyRatings = try container.decodeIfPresent(
Expand All @@ -395,6 +389,15 @@ extension Candidate: Decodable {

finishReason = try container.decodeIfPresent(FinishReason.self, forKey: .finishReason)

// The `content` may only be empty if a `finishReason` is included; if neither are included in
// the response then this is likely the `"content": {}` bug.
guard !content.parts.isEmpty || finishReason != nil else {
throw InvalidCandidateError.emptyContent(underlyingError: DecodingError.dataCorrupted(.init(
codingPath: [CodingKeys.content, CodingKeys.finishReason],
debugDescription: "Invalid Candidate: empty content and no finish reason"
)))
}

citationMetadata = try container.decodeIfPresent(
CitationMetadata.self,
forKey: .citationMetadata
Expand Down
3 changes: 3 additions & 0 deletions FirebaseAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public final class GenerativeModel: Sendable {
/// Model name prefix to identify Gemini models.
static let geminiModelNamePrefix = "gemini-"

/// Model name prefix to identify Gemma models.
static let gemmaModelNamePrefix = "gemma-"

/// The name of the model, for example "gemini-2.0-flash".
let modelName: String

Expand Down
6 changes: 6 additions & 0 deletions FirebaseAI/Sources/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ extension ModelContent: Codable {
case role
case internalParts = "parts"
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
role = try container.decodeIfPresent(String.self, forKey: .role)
internalParts = try container.decodeIfPresent([InternalPart].self, forKey: .internalParts) ?? []
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
1 change: 1 addition & 0 deletions FirebaseAI/Tests/TestApp/Sources/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ public enum ModelNames {
public static let gemini2Flash = "gemini-2.0-flash-001"
public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
public static let gemini2FlashExperimental = "gemini-2.0-flash-exp"
public static let gemma3_27B = "gemma-3-27b-it"
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,24 @@ struct GenerateContentIntegrationTests {
storage = Storage.storage()
}

@Test(arguments: InstanceConfig.allConfigs)
func generateContent(_ config: InstanceConfig) async throws {
@Test(arguments: [
(InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite),
(InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite),
(InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite),
(InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B),
(InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B),
(InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B),
])
func generateContent(_ config: InstanceConfig, modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2FlashLite,
modelName: modelName,
generationConfig: generationConfig,
safetySettings: safetySettings
safetySettings: safetySettings,
)
let prompt = "Where is Google headquarters located? Answer with the city name only."

Expand All @@ -62,17 +74,22 @@ struct GenerateContentIntegrationTests {
#expect(text == "Mountain View")

let usageMetadata = try #require(response.usageMetadata)
#expect(usageMetadata.promptTokenCount == 13)
#expect(usageMetadata.promptTokenCount.isEqual(to: 13, accuracy: tokenCountAccuracy))
#expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy))
#expect(usageMetadata.totalTokenCount.isEqual(to: 16, accuracy: tokenCountAccuracy))
#expect(usageMetadata.promptTokensDetails.count == 1)
let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first)
#expect(promptTokensDetails.modality == .text)
#expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount)
#expect(usageMetadata.candidatesTokensDetails.count == 1)
let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
#expect(candidatesTokensDetails.modality == .text)
#expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
// The field `candidatesTokensDetails` is not included when using Gemma models.
if modelName == ModelNames.gemma3_27B {
#expect(usageMetadata.candidatesTokensDetails.isEmpty)
} else {
#expect(usageMetadata.candidatesTokensDetails.count == 1)
let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
#expect(candidatesTokensDetails.modality == .text)
#expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
}
}

@Test(
Expand Down Expand Up @@ -168,24 +185,35 @@ struct GenerateContentIntegrationTests {

// MARK: Streaming Tests

@Test(arguments: InstanceConfig.allConfigs)
func generateContentStream(_ config: InstanceConfig) async throws {
let expectedText = """
1. Mercury
2. Venus
3. Earth
4. Mars
5. Jupiter
6. Saturn
7. Uranus
8. Neptune
"""
@Test(arguments: [
(InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite),
(InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite),
(InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite),
(InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B),
(InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B),
(InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
(InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B),
])
func generateContentStream(_ config: InstanceConfig, modelName: String) async throws {
let expectedResponse = [
"Mercury", "Venus", "Earth", "Mars", "Jupiter", "Saturn", "Uranus", "Neptune",
]
let prompt = """
What are the names of the planets in the solar system, ordered from closest to furthest from
the sun? Answer with a Markdown numbered list of the names and no other text.
Generate a JSON array of strings. The array must contain the names of the planets in Earth's \
solar system, ordered from closest to furthest from the Sun.

Constraints:
- Output MUST be only the JSON array.
- Do NOT include any introductory or explanatory text.
- Do NOT wrap the JSON in Markdown code blocks (e.g., ```json ... ``` or ``` ... ```).
- The response must start with '[' and end with ']'.
"""
let model = FirebaseAI.componentInstance(config).generativeModel(
modelName: ModelNames.gemini2FlashLite,
modelName: modelName,
generationConfig: generationConfig,
safetySettings: safetySettings
)
Expand All @@ -194,7 +222,13 @@ struct GenerateContentIntegrationTests {
let stream = try chat.sendMessageStream(prompt)
var textValues = [String]()
for try await value in stream {
try textValues.append(#require(value.text))
if let text = value.text {
textValues.append(text)
} else if let finishReason = value.candidates.first?.finishReason {
#expect(finishReason == .stop)
} else {
Issue.record("Expected a candidate with a `TextPart` or a `finishReason`; got \(value).")
}
}

let userHistory = try #require(chat.history.first)
Expand All @@ -206,11 +240,9 @@ struct GenerateContentIntegrationTests {
#expect(modelHistory.role == "model")
#expect(modelHistory.parts.count == 1)
let modelTextPart = try #require(modelHistory.parts.first as? TextPart)
let modelText = modelTextPart.text.trimmingCharacters(in: .whitespacesAndNewlines)
#expect(modelText == expectedText)
#expect(textValues.count > 1)
let text = textValues.joined().trimmingCharacters(in: .whitespacesAndNewlines)
#expect(text == expectedText)
let modelJSONData = try #require(modelTextPart.text.data(using: .utf8))
let response = try JSONDecoder().decode([String].self, from: modelJSONData)
#expect(response == expectedResponse)
}

// MARK: - App Check Tests
Expand Down
18 changes: 12 additions & 6 deletions FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,9 @@ final class GenerativeModelVertexAITests: XCTestCase {
func testGenerateContent_failure_malformedContent() async throws {
MockURLProtocol
.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
// Note: Although this file does not contain `parts` in `content`, it is not actually
// malformed. The `invalid-field` in the payload could be added as a non-breaking change to
// the proto API. Therefore, this test checks for the `emptyContent` error instead.
forResource: "unary-failure-malformed-content",
withExtension: "json",
subdirectory: vertexSubdirectory
Expand All @@ -939,13 +942,13 @@ final class GenerativeModelVertexAITests: XCTestCase {
return
}
let invalidCandidateError = try XCTUnwrap(underlyingError as? InvalidCandidateError)
guard case let .malformedContent(malformedContentUnderlyingError) = invalidCandidateError else {
XCTFail("Not a malformed content error: \(invalidCandidateError)")
guard case let .emptyContent(emptyContentUnderlyingError) = invalidCandidateError else {
XCTFail("Not an empty content error: \(invalidCandidateError)")
return
}
_ = try XCTUnwrap(
malformedContentUnderlyingError as? DecodingError,
"Not a decoding error: \(malformedContentUnderlyingError)"
emptyContentUnderlyingError as? DecodingError,
"Not a decoding error: \(emptyContentUnderlyingError)"
)
}

Expand Down Expand Up @@ -1446,6 +1449,9 @@ final class GenerativeModelVertexAITests: XCTestCase {
func testGenerateContentStream_malformedContent() async throws {
MockURLProtocol
.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
// Note: Although this file does not contain `parts` in `content`, it is not actually
// malformed. The `invalid-field` in the payload could be added as a non-breaking change to
// the proto API. Therefore, this test checks for the `emptyContent` error instead.
forResource: "streaming-failure-malformed-content",
withExtension: "txt",
subdirectory: vertexSubdirectory
Expand All @@ -1457,8 +1463,8 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTFail("Unexpected content in stream: \(content)")
}
} catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
guard case let .malformedContent(contentError) = underlyingError else {
XCTFail("Not a malformed content error: \(underlyingError)")
guard case let .emptyContent(contentError) = underlyingError else {
XCTFail("Not an empty content error: \(underlyingError)")
return
}

Expand Down
Loading