Skip to content

Commit 08d0f44

Browse files
authored
[Firebase AI] Add workaround for invalid SafetyRatings in response (#14817)
1 parent 1a5a42f commit 08d0f44

File tree

5 files changed

+111
-12
lines changed

5 files changed

+111
-12
lines changed

FirebaseAI/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Unreleased
22
- [fixed] Fixed `ModalityTokenCount` decoding when the `tokenCount` field is
33
omitted; this occurs when the count is 0. (#14745)
4+
- [fixed] Fixed `Candidate` decoding when `SafetyRating` values are missing a
5+
category or probability; this may occur when using `gemini-2.0-flash-exp` for
6+
image generation. (#14817)
47

58
# 11.12.0
69
- [added] **Public Preview**: Added support for specifying response modalities

FirebaseAI/Sources/GenerateContentResponse.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,14 @@ extension Candidate: Decodable {
381381
}
382382

383383
if let safetyRatings = try container.decodeIfPresent(
384-
[SafetyRating].self,
385-
forKey: .safetyRatings
384+
[SafetyRating].self, forKey: .safetyRatings
386385
) {
387-
self.safetyRatings = safetyRatings
386+
self.safetyRatings = safetyRatings.filter {
387+
// Due to a bug in the backend, the SDK may receive invalid `SafetyRating` values that do
388+
// not include a category or probability; these are filtered out of the safety ratings.
389+
$0.category != HarmCategory.unspecified
390+
&& $0.probability != SafetyRating.HarmProbability.unspecified
391+
}
388392
} else {
389393
safetyRatings = []
390394
}

FirebaseAI/Sources/Safety.swift

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,16 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
7878
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
7979
public struct HarmProbability: DecodableProtoEnum, Hashable, Sendable {
8080
enum Kind: String {
81+
case unspecified = "HARM_PROBABILITY_UNSPECIFIED"
8182
case negligible = "NEGLIGIBLE"
8283
case low = "LOW"
8384
case medium = "MEDIUM"
8485
case high = "HIGH"
8586
}
8687

88+
/// Internal-only; harm probability is unknown or unspecified by the backend.
89+
static let unspecified = HarmProbability(kind: .unspecified)
90+
8791
/// The probability is zero or close to zero.
8892
///
8993
/// For benign content, the probability across all categories will be this value.
@@ -114,12 +118,16 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
114118
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
115119
public struct HarmSeverity: DecodableProtoEnum, Hashable, Sendable {
116120
enum Kind: String {
121+
case unspecified = "HARM_SEVERITY_UNSPECIFIED"
117122
case negligible = "HARM_SEVERITY_NEGLIGIBLE"
118123
case low = "HARM_SEVERITY_LOW"
119124
case medium = "HARM_SEVERITY_MEDIUM"
120125
case high = "HARM_SEVERITY_HIGH"
121126
}
122127

128+
/// Internal-only; harm severity is unknown or unspecified by the backend.
129+
static let unspecified: HarmSeverity = .init(kind: .unspecified)
130+
123131
/// Negligible level of harm severity.
124132
public static let negligible = HarmSeverity(kind: .negligible)
125133

@@ -234,13 +242,17 @@ public struct SafetySetting: Sendable {
234242
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
235243
public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
236244
enum Kind: String {
245+
case unspecified = "HARM_CATEGORY_UNSPECIFIED"
237246
case harassment = "HARM_CATEGORY_HARASSMENT"
238247
case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
239248
case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
240249
case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
241250
case civicIntegrity = "HARM_CATEGORY_CIVIC_INTEGRITY"
242251
}
243252

253+
/// Internal-only; harm category is unknown or unspecified by the backend.
254+
static let unspecified = HarmCategory(kind: .unspecified)
255+
244256
/// Harassment content.
245257
public static let harassment = HarmCategory(kind: .harassment)
246258

@@ -281,13 +293,14 @@ extension SafetyRating: Decodable {
281293

282294
public init(from decoder: any Decoder) throws {
283295
let container = try decoder.container(keyedBy: CodingKeys.self)
284-
category = try container.decode(HarmCategory.self, forKey: .category)
285-
probability = try container.decode(HarmProbability.self, forKey: .probability)
296+
category = try container.decodeIfPresent(HarmCategory.self, forKey: .category) ?? .unspecified
297+
probability = try container.decodeIfPresent(
298+
HarmProbability.self, forKey: .probability
299+
) ?? .unspecified
286300

287-
// The following 3 fields are only omitted in our test data.
301+
// The following 3 fields are only provided when using the Vertex AI backend (not Google AI).
288302
probabilityScore = try container.decodeIfPresent(Float.self, forKey: .probabilityScore) ?? 0.0
289-
severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ??
290-
HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED")
303+
severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ?? .unspecified
291304
severityScore = try container.decodeIfPresent(Float.self, forKey: .severityScore) ?? 0.0
292305

293306
// The blocked field is only included when true.

FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ struct GenerateContentIntegrationTests {
115115
}
116116

117117
@Test(arguments: [
118-
// TODO(andrewheard): Vertex AI configs temporarily disabled due to empty SafetyRatings bug.
119-
// InstanceConfig.vertexV1,
120-
// InstanceConfig.vertexV1Beta,
118+
InstanceConfig.vertexAI_v1,
119+
InstanceConfig.vertexAI_v1beta,
121120
InstanceConfig.googleAI_v1beta,
122121
InstanceConfig.googleAI_v1beta_staging,
123122
InstanceConfig.googleAI_v1beta_freeTier_bypassProxy,

FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,41 @@ final class GenerativeModelVertexAITests: XCTestCase {
5656
blocked: false
5757
),
5858
].sorted()
59+
let safetyRatingsInvalidIgnored = [
60+
SafetyRating(
61+
category: .hateSpeech,
62+
probability: .negligible,
63+
probabilityScore: 0.00039444832,
64+
severity: .negligible,
65+
severityScore: 0.0,
66+
blocked: false
67+
),
68+
SafetyRating(
69+
category: .dangerousContent,
70+
probability: .negligible,
71+
probabilityScore: 0.0010654529,
72+
severity: .negligible,
73+
severityScore: 0.0049325973,
74+
blocked: false
75+
),
76+
SafetyRating(
77+
category: .harassment,
78+
probability: .negligible,
79+
probabilityScore: 0.00026658305,
80+
severity: .negligible,
81+
severityScore: 0.0,
82+
blocked: false
83+
),
84+
SafetyRating(
85+
category: .sexuallyExplicit,
86+
probability: .negligible,
87+
probabilityScore: 0.0013701695,
88+
severity: .negligible,
89+
severityScore: 0.07626295,
90+
blocked: false
91+
),
92+
// Ignored Invalid Safety Ratings: {},{},{},{}
93+
].sorted()
5994
let testModelName = "test-model"
6095
let testModelResourceName =
6196
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
@@ -399,6 +434,26 @@ final class GenerativeModelVertexAITests: XCTestCase {
399434
XCTAssertEqual(text, "The sum of [1, 2, 3] is")
400435
}
401436

437+
func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
438+
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
439+
forResource: "unary-success-image-invalid-safety-ratings",
440+
withExtension: "json",
441+
subdirectory: vertexSubdirectory
442+
)
443+
444+
let response = try await model.generateContent(testPrompt)
445+
446+
XCTAssertEqual(response.candidates.count, 1)
447+
let candidate = try XCTUnwrap(response.candidates.first)
448+
XCTAssertEqual(candidate.content.parts.count, 1)
449+
XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
450+
let inlineDataParts = response.inlineDataParts
451+
XCTAssertEqual(inlineDataParts.count, 1)
452+
let imagePart = try XCTUnwrap(inlineDataParts.first)
453+
XCTAssertEqual(imagePart.mimeType, "image/png")
454+
XCTAssertGreaterThan(imagePart.data.count, 0)
455+
}
456+
402457
func testGenerateContent_appCheck_validToken() async throws {
403458
let appCheckToken = "test-valid-token"
404459
model = GenerativeModel(
@@ -1118,7 +1173,7 @@ final class GenerativeModelVertexAITests: XCTestCase {
11181173
responses += 1
11191174
}
11201175

1121-
XCTAssertEqual(responses, 6)
1176+
XCTAssertEqual(responses, 4)
11221177
}
11231178

11241179
func testGenerateContentStream_successBasicReplyShort() async throws {
@@ -1220,6 +1275,31 @@ final class GenerativeModelVertexAITests: XCTestCase {
12201275
XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
12211276
}
12221277

1278+
func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
1279+
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
1280+
forResource: "streaming-success-image-invalid-safety-ratings",
1281+
withExtension: "txt",
1282+
subdirectory: vertexSubdirectory
1283+
)
1284+
1285+
let stream = try model.generateContentStream(testPrompt)
1286+
var responses = [GenerateContentResponse]()
1287+
for try await content in stream {
1288+
responses.append(content)
1289+
}
1290+
1291+
let response = try XCTUnwrap(responses.first)
1292+
XCTAssertEqual(response.candidates.count, 1)
1293+
let candidate = try XCTUnwrap(response.candidates.first)
1294+
XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
1295+
XCTAssertEqual(candidate.content.parts.count, 1)
1296+
let inlineDataParts = response.inlineDataParts
1297+
XCTAssertEqual(inlineDataParts.count, 1)
1298+
let imagePart = try XCTUnwrap(inlineDataParts.first)
1299+
XCTAssertEqual(imagePart.mimeType, "image/png")
1300+
XCTAssertGreaterThan(imagePart.data.count, 0)
1301+
}
1302+
12231303
func testGenerateContentStream_appCheck_validToken() async throws {
12241304
let appCheckToken = "test-valid-token"
12251305
model = GenerativeModel(

0 commit comments

Comments (0)