Skip to content

[Firebase AI] Add support for thought summaries #15096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions FirebaseAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,13 @@ public final class Chat: Sendable {
for aggregate in chunks {
// Loop through all the parts, aggregating the text and adding the images.
for part in aggregate.parts {
switch part {
case let textPart as TextPart:
combinedText += textPart.text
guard !part.isThought else {
continue
}

default:
if let textPart = part as? TextPart {
combinedText += textPart.text
} else {
// Don't combine it, just add to the content. If there's any text pending, add that as
// a part.
if !combinedText.isEmpty {
Expand Down
19 changes: 10 additions & 9 deletions FirebaseAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,10 @@ public struct GenerateContentResponse: Sendable {
return nil
}
let textValues: [String] = candidate.content.parts.compactMap { part in
switch part {
case let textPart as TextPart:
return textPart.text
default:
guard let textPart = part as? TextPart, !part.isThought else {
return nil
}
return textPart.text
}
guard textValues.count > 0 else {
AILog.error(
Expand All @@ -89,12 +87,10 @@ public struct GenerateContentResponse: Sendable {
return []
}
return candidate.content.parts.compactMap { part in
switch part {
case let functionCallPart as FunctionCallPart:
return functionCallPart
default:
guard let functionCallPart = part as? FunctionCallPart, !part.isThought else {
return nil
}
return functionCallPart
}
}

Expand All @@ -107,7 +103,12 @@ public struct GenerateContentResponse: Sendable {
""")
return []
}
return candidate.content.parts.compactMap { $0 as? InlineDataPart }
return candidate.content.parts.compactMap { part in
guard let inlineDataPart = part as? InlineDataPart, !part.isThought else {
return nil
}
return inlineDataPart
}
}

/// Initializer for SwiftUI previews or tests.
Expand Down
102 changes: 71 additions & 31 deletions FirebaseAI/Sources/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,31 @@ extension [ModelContent] {
}
}

/// A type describing data in media formats interpretable by an AI model. Each generative AI
/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value
/// may comprise multiple heterogeneous ``Part``s.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ModelContent: Equatable, Sendable {
enum InternalPart: Equatable, Sendable {
struct InternalPart: Equatable, Sendable {
enum OneOfData: Equatable, Sendable {
case text(String)
case inlineData(mimetype: String, Data)
case fileData(mimetype: String, uri: String)
case inlineData(InlineData)
case fileData(FileData)
case functionCall(FunctionCall)
case functionResponse(FunctionResponse)
}

let data: OneOfData

let isThought: Bool?

init(_ data: OneOfData, isThought: Bool?) {
self.data = data
self.isThought = isThought
}
}

/// A type describing data in media formats interpretable by an AI model. Each generative AI
/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value
/// may comprise multiple heterogeneous ``Part``s.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ModelContent: Equatable, Sendable {
/// The role of the entity creating the ``ModelContent``. For user-generated client requests,
/// for example, the role is `user`.
public let role: String?
Expand All @@ -52,17 +64,17 @@ public struct ModelContent: Equatable, Sendable {
public var parts: [any Part] {
var convertedParts = [any Part]()
for part in internalParts {
switch part {
switch part.data {
case let .text(text):
convertedParts.append(TextPart(text))
case let .inlineData(mimetype, data):
convertedParts.append(InlineDataPart(data: data, mimeType: mimetype))
case let .fileData(mimetype, uri):
convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype))
convertedParts.append(TextPart(text, isThought: part.isThought))
case let .inlineData(inlineData):
convertedParts.append(InlineDataPart(inlineData, isThought: part.isThought))
case let .fileData(fileData):
convertedParts.append(FileDataPart(fileData, isThought: part.isThought))
case let .functionCall(functionCall):
convertedParts.append(FunctionCallPart(functionCall))
convertedParts.append(FunctionCallPart(functionCall, isThought: part.isThought))
case let .functionResponse(functionResponse):
convertedParts.append(FunctionResponsePart(functionResponse))
convertedParts.append(FunctionResponsePart(functionResponse, isThought: part.isThought))
}
}
return convertedParts
Expand All @@ -78,17 +90,28 @@ public struct ModelContent: Equatable, Sendable {
for part in parts {
switch part {
case let textPart as TextPart:
convertedParts.append(.text(textPart.text))
convertedParts.append(InternalPart(.text(textPart.text), isThought: textPart._isThought))
case let inlineDataPart as InlineDataPart:
let inlineData = inlineDataPart.inlineData
convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data))
convertedParts.append(
InternalPart(.inlineData(inlineDataPart.inlineData), isThought: inlineDataPart._isThought)
)
case let fileDataPart as FileDataPart:
let fileData = fileDataPart.fileData
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI))
convertedParts.append(
InternalPart(.fileData(fileDataPart.fileData), isThought: fileDataPart._isThought)
)
case let functionCallPart as FunctionCallPart:
convertedParts.append(.functionCall(functionCallPart.functionCall))
convertedParts.append(
InternalPart(
.functionCall(functionCallPart.functionCall), isThought: functionCallPart._isThought
)
)
case let functionResponsePart as FunctionResponsePart:
convertedParts.append(.functionResponse(functionResponsePart.functionResponse))
convertedParts.append(
InternalPart(
.functionResponse(functionResponsePart.functionResponse),
isThought: functionResponsePart._isThought
)
)
default:
fatalError()
}
Expand Down Expand Up @@ -121,7 +144,26 @@ extension ModelContent: Codable {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ModelContent.InternalPart: Codable {
/// `Codable` conformance for `InternalPart`.
///
/// The `data` payload and the optional `isThought` flag are written through the same
/// encoder (and read from the same decoder) rather than being nested under separate keys.
extension InternalPart: Codable {
/// Keys handled directly by `InternalPart`; `isThought` is serialized under the name
/// `"thought"`.
enum CodingKeys: String, CodingKey {
case isThought = "thought"
}

/// Encodes the `data` payload, then adds the `thought` key when `isThought` is non-nil.
public func encode(to encoder: Encoder) throws {
// The payload encodes itself directly into `encoder`, not under a key of its own.
try data.encode(to: encoder)
var container = encoder.container(keyedBy: CodingKeys.self)
// `encodeIfPresent` writes nothing for a `nil` flag.
try container.encodeIfPresent(isThought, forKey: .isThought)
}

/// Decodes the `data` payload and the optional `thought` flag from the same decoder.
public init(from decoder: Decoder) throws {
// `OneOfData` reads its payload from the same decoder that holds the `thought` key.
data = try OneOfData(from: decoder)
let container = try decoder.container(keyedBy: CodingKeys.self)
// `decodeIfPresent` leaves the flag `nil` when the key is absent.
isThought = try container.decodeIfPresent(Bool.self, forKey: .isThought)
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension InternalPart.OneOfData: Codable {
enum CodingKeys: String, CodingKey {
case text
case inlineData
Expand All @@ -135,10 +177,10 @@ extension ModelContent.InternalPart: Codable {
switch self {
case let .text(text):
try container.encode(text, forKey: .text)
case let .inlineData(mimetype, bytes):
try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData)
case let .fileData(mimetype: mimetype, url):
try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData)
case let .inlineData(inlineData):
try container.encode(inlineData, forKey: .inlineData)
case let .fileData(fileData):
try container.encode(fileData, forKey: .fileData)
case let .functionCall(functionCall):
try container.encode(functionCall, forKey: .functionCall)
case let .functionResponse(functionResponse):
Expand All @@ -151,11 +193,9 @@ extension ModelContent.InternalPart: Codable {
if values.contains(.text) {
self = try .text(values.decode(String.self, forKey: .text))
} else if values.contains(.inlineData) {
let inlineData = try values.decode(InlineData.self, forKey: .inlineData)
self = .inlineData(mimetype: inlineData.mimeType, inlineData.data)
self = try .inlineData(values.decode(InlineData.self, forKey: .inlineData))
} else if values.contains(.fileData) {
let fileData = try values.decode(FileData.self, forKey: .fileData)
self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)
self = try .fileData(values.decode(FileData.self, forKey: .fileData))
} else if values.contains(.functionCall) {
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
} else if values.contains(.functionResponse) {
Expand Down
2 changes: 2 additions & 0 deletions FirebaseAI/Sources/Types/Internal/InternalPart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ struct FunctionResponse: Codable, Equatable, Sendable {
struct ErrorPart: Part, Error {
let error: Error

let isThought = false

init(_ error: Error) {
self.error = error
}
Expand Down
44 changes: 35 additions & 9 deletions FirebaseAI/Sources/Types/Public/Part.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,27 @@ import Foundation
///
/// Within a single value of ``Part``, different data types may not mix.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public protocol Part: PartsRepresentable, Codable, Sendable, Equatable {}
public protocol Part: PartsRepresentable, Codable, Sendable, Equatable {
var isThought: Bool { get }
}

/// A text part containing a string value.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct TextPart: Part {
  /// Text value.
  public let text: String

  /// Whether this part is flagged as a thought. An unset flag reads as `false`.
  public var isThought: Bool {
    guard let flag = _isThought else {
      return false
    }
    return flag
  }

  /// The thought flag exactly as supplied at initialization; `nil` when none was given.
  let _isThought: Bool?

  /// Creates a text part from a string, leaving the thought flag unset.
  public init(_ text: String) {
    self.init(text, isThought: nil)
  }

  /// Creates a text part from a string together with an optional thought flag.
  init(_ text: String, isThought: Bool?) {
    self.text = text
    self._isThought = isThought
  }
}

Expand All @@ -45,13 +56,16 @@ public struct TextPart: Part {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct InlineDataPart: Part {
let inlineData: InlineData
let _isThought: Bool?

/// The data provided in the inline data part.
public var data: Data { inlineData.data }

/// The IANA standard MIME type of the data.
public var mimeType: String { inlineData.mimeType }

public var isThought: Bool { _isThought ?? false }

/// Creates an inline data part from data and a MIME type.
///
/// > Important: Supported input types depend on the model being used; see [input
Expand All @@ -67,21 +81,24 @@ public struct InlineDataPart: Part {
/// requirements](https://firebase.google.com/docs/vertex-ai/input-file-requirements) for
/// supported values.
public init(data: Data, mimeType: String) {
self.init(InlineData(data: data, mimeType: mimeType))
self.init(InlineData(data: data, mimeType: mimeType), isThought: nil)
}

init(_ inlineData: InlineData) {
init(_ inlineData: InlineData, isThought: Bool?) {
self.inlineData = inlineData
_isThought = isThought
}
}

/// File data stored in Cloud Storage for Firebase, referenced by URI.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FileDataPart: Part {
let fileData: FileData
let _isThought: Bool?

public var uri: String { fileData.fileURI }
public var mimeType: String { fileData.mimeType }
public var isThought: Bool { _isThought ?? false }

/// Constructs a new file data part.
///
Expand All @@ -93,25 +110,29 @@ public struct FileDataPart: Part {
/// requirements](https://firebase.google.com/docs/vertex-ai/input-file-requirements) for
/// supported values.
public init(uri: String, mimeType: String) {
self.init(FileData(fileURI: uri, mimeType: mimeType))
self.init(FileData(fileURI: uri, mimeType: mimeType), isThought: nil)
}

init(_ fileData: FileData) {
init(_ fileData: FileData, isThought: Bool?) {
self.fileData = fileData
_isThought = isThought
}
}

/// A predicted function call returned from the model.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FunctionCallPart: Part {
let functionCall: FunctionCall
let _isThought: Bool?

/// The name of the function to call.
public var name: String { functionCall.name }

/// The function parameters and values.
public var args: JSONObject { functionCall.args }

public var isThought: Bool { _isThought ?? false }

/// Constructs a new function call part.
///
/// > Note: A `FunctionCallPart` is typically received from the model, rather than created
Expand All @@ -121,11 +142,12 @@ public struct FunctionCallPart: Part {
/// - name: The name of the function to call.
/// - args: The function parameters and values.
public init(name: String, args: JSONObject) {
self.init(FunctionCall(name: name, args: args))
self.init(FunctionCall(name: name, args: args), isThought: nil)
}

init(_ functionCall: FunctionCall) {
init(_ functionCall: FunctionCall, isThought: Bool?) {
self.functionCall = functionCall
_isThought = isThought
}
}

Expand All @@ -137,23 +159,27 @@ public struct FunctionCallPart: Part {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct FunctionResponsePart: Part {
let functionResponse: FunctionResponse
let _isThought: Bool?

/// The name of the function that was called.
public var name: String { functionResponse.name }

/// The function's response or return value.
public var response: JSONObject { functionResponse.response }

public var isThought: Bool { _isThought ?? false }

/// Constructs a new function response part.
///
/// - Parameters:
/// - name: The name of the function that was called.
/// - response: The function's response.
public init(name: String, response: JSONObject) {
self.init(FunctionResponse(name: name, response: response))
self.init(FunctionResponse(name: name, response: response), isThought: nil)
}

init(_ functionResponse: FunctionResponse) {
init(_ functionResponse: FunctionResponse, isThought: Bool?) {
self.functionResponse = functionResponse
_isThought = isThought
}
}
5 changes: 4 additions & 1 deletion FirebaseAI/Sources/Types/Public/ThinkingConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ public struct ThinkingConfig: Sendable {
/// feature or if the specified budget is not within the model's supported range.
let thinkingBudget: Int?

let includeThoughts: Bool?

/// Initializes a new `ThinkingConfig`.
///
/// - Parameters:
/// - thinkingBudget: The maximum number of tokens to be used for the model's thinking process.
public init(thinkingBudget: Int? = nil) {
public init(thinkingBudget: Int? = nil, includeThoughts: Bool? = nil) {
self.thinkingBudget = thinkingBudget
self.includeThoughts = includeThoughts
}
}

Expand Down
Loading
Loading