Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ public struct GeminiLanguageModel: LanguageModel {
/// When set to `.enabled`, the model will output valid JSON.
/// When set to `.schema(_:)`, the model will output JSON
/// conforming to the provided schema.
///
/// - Note: When generating a non-`String` ``Generable`` type, the model
/// always uses the generated schema for structured output and ignores
/// this setting.
public var jsonMode: JSONMode?

/// Creates custom generation options for Gemini models.
Expand Down Expand Up @@ -262,10 +266,6 @@ public struct GeminiLanguageModel: LanguageModel {
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
guard type == String.self else {
fatalError("GeminiLanguageModel only supports generating String content")
}

// Extract effective configuration from custom options or fall back to model defaults
let customOptions = options[custom: GeminiLanguageModel.self]
let effectiveThinking = customOptions?.thinking ?? _thinking
Expand All @@ -287,6 +287,7 @@ public struct GeminiLanguageModel: LanguageModel {
let params = try createGenerateContentParams(
contents: transcript.toGeminiContent(),
tools: geminiTools,
generating: type,
options: options,
thinking: effectiveThinking,
jsonMode: effectiveJsonMode
Expand Down Expand Up @@ -319,9 +320,10 @@ public struct GeminiLanguageModel: LanguageModel {
if !calls.isEmpty {
transcript.append(.toolCalls(Transcript.ToolCalls(calls)))
}
let empty = try emptyResponseContent(for: type)
return LanguageModelSession.Response(
content: "" as! Content,
rawContent: GeneratedContent(""),
content: empty.content,
rawContent: empty.rawContent,
transcriptEntries: ArraySlice(transcript)
)
case .invocations(let invocations):
Expand All @@ -346,9 +348,19 @@ public struct GeminiLanguageModel: LanguageModel {
}
}.joined() ?? ""

if type == String.self {
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
transcriptEntries: ArraySlice(transcript)
)
}

let generatedContent = try GeneratedContent(json: text)
let content = try type.init(generatedContent)
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
content: content,
rawContent: generatedContent,
transcriptEntries: ArraySlice(transcript)
)
}
Expand All @@ -362,10 +374,6 @@ public struct GeminiLanguageModel: LanguageModel {
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
guard type == String.self else {
fatalError("GeminiLanguageModel only supports generating String content")
}

// Extract effective configuration from custom options or fall back to model defaults
let customOptions = options[custom: GeminiLanguageModel.self]
let effectiveThinking = customOptions?.thinking ?? _thinking
Expand All @@ -390,6 +398,7 @@ public struct GeminiLanguageModel: LanguageModel {
let params = try createGenerateContentParams(
contents: session.transcript.toGeminiContent(),
tools: geminiTools,
generating: type,
options: options,
thinking: effectiveThinking,
jsonMode: effectiveJsonMode
Expand All @@ -416,10 +425,27 @@ public struct GeminiLanguageModel: LanguageModel {
if case .text(let textPart) = part {
accumulatedText += textPart.text

let raw = GeneratedContent(accumulatedText)
let content: Content.PartiallyGenerated = (accumulatedText as! Content)
.asPartiallyGenerated()
continuation.yield(.init(content: content, rawContent: raw))
var raw: GeneratedContent
let content: Content.PartiallyGenerated?

if type == String.self {
raw = GeneratedContent(accumulatedText)
content = (accumulatedText as! Content).asPartiallyGenerated()
} else {
raw =
(try? GeneratedContent(json: accumulatedText))
?? GeneratedContent(accumulatedText)
if let parsed = try? type.init(raw) {
content = parsed.asPartiallyGenerated()
} else {
// Skip invalid partial JSON until it parses cleanly.
content = nil
}
}

if let content {
continuation.yield(.init(content: content, rawContent: raw))
}
}
}
}
Expand Down Expand Up @@ -451,7 +477,12 @@ public struct GeminiLanguageModel: LanguageModel {

if !tools.isEmpty {
let functionDeclarations: [GeminiFunctionDeclaration] = try tools.map { tool in
try convertToolToGeminiFormat(tool)
let schema = try convertSchemaToGeminiFormat(tool.parameters)
return GeminiFunctionDeclaration(
name: tool.name,
description: tool.description,
parameters: schema
)
}
geminiTools.append(.functionDeclarations(functionDeclarations))
}
Expand All @@ -473,9 +504,18 @@ public struct GeminiLanguageModel: LanguageModel {
}
}

private func createGenerateContentParams(
/// Converts a `GenerationSchema` into the `JSONSchema` value used in Gemini requests.
///
/// The conversion is a JSON round-trip: the schema (with its root resolved when
/// `withResolvedRoot()` returns a value, otherwise the schema unchanged) is encoded
/// with `JSONEncoder` and the resulting data is decoded as `JSONSchema`.
///
/// - Parameter schema: The generation schema to convert.
/// - Returns: The decoded `JSONSchema`.
/// - Throws: Any error raised while encoding the schema or decoding the result.
private func convertSchemaToGeminiFormat(_ schema: GenerationSchema) throws -> JSONSchema {
    let schemaEncoder = JSONEncoder()
    // Sets the `omitAdditionalPropertiesKey` flag for the encoder; presumably this
    // drops `additionalProperties` from the encoded schema — confirm against
    // `GenerationSchema`'s `Encodable` implementation.
    schemaEncoder.userInfo[GenerationSchema.omitAdditionalPropertiesKey] = true

    let rootResolved = schema.withResolvedRoot() ?? schema
    let encoded = try schemaEncoder.encode(rootResolved)
    return try JSONDecoder().decode(JSONSchema.self, from: encoded)
}

private func createGenerateContentParams<Content: Generable>(
contents: [GeminiContent],
tools: [GeminiTool]?,
generating type: Content.Type,
options: GenerationOptions,
thinking: GeminiLanguageModel.CustomGenerationOptions.Thinking,
jsonMode: GeminiLanguageModel.CustomGenerationOptions.JSONMode?
Expand Down Expand Up @@ -518,7 +558,11 @@ private func createGenerateContentParams(
}
generationConfig["thinkingConfig"] = .object(thinkingConfig)

if let jsonMode {
if type != String.self {
let schema = try convertSchemaToGeminiFormat(type.generationSchema)
generationConfig["responseMimeType"] = .string("application/json")
generationConfig["responseSchema"] = try JSONValue(schema)
} else if let jsonMode {
switch jsonMode {
case .disabled:
break
Expand Down Expand Up @@ -652,19 +696,23 @@ private func resolveFunctionCalls(
return .invocations(results)
}

private func convertToolToGeminiFormat(_ tool: any Tool) throws -> GeminiFunctionDeclaration {
let resolvedSchema = tool.parameters.withResolvedRoot() ?? tool.parameters

let encoder = JSONEncoder()
encoder.userInfo[GenerationSchema.omitAdditionalPropertiesKey] = true
let data = try encoder.encode(resolvedSchema)
let schema = try JSONDecoder().decode(JSONSchema.self, from: data)
/// Builds placeholder response content for a response that carries no generated text.
///
/// - For `String`, returns an empty string paired with `GeneratedContent("")`.
/// - For any other `Generable` type, first tries to initialize the type from
///   `GeneratedContent(properties: [:])`; if that throws, retries with
///   `GeneratedContent(json: "null")`.
///
/// - Parameter type: The `Generable` type the caller asked for.
/// - Returns: The placeholder value together with the raw content it was built from.
/// - Throws: The error from the `null` fallback when the type cannot be initialized
///   from either placeholder.
private func emptyResponseContent<Content: Generable>(
for type: Content.Type
) throws -> (content: Content, rawContent: GeneratedContent) {
if type == String.self {
let raw = GeneratedContent("")
// Guarded by the `type == String.self` check above.
return ("" as! Content, raw)
}

// NOTE(review): the next five lines look like the tail of the removed
// `convertToolToGeminiFormat` function, interleaved here by the diff view — confirm.
return GeminiFunctionDeclaration(
name: tool.name,
description: tool.description,
parameters: schema
)
// First attempt: an object with no properties.
let rawEmpty = GeneratedContent(properties: [:])
do {
let content = try type.init(rawEmpty)
return (content, rawEmpty)
} catch {
// Fallback: JSON `null`. Errors from this path propagate to the caller.
let rawNull = try GeneratedContent(json: "null")
let content = try type.init(rawNull)
return (content, rawNull)
}
}

private func toGeneratedContent(_ value: [String: JSONValue]?) throws -> GeneratedContent {
Expand Down
89 changes: 89 additions & 0 deletions Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,93 @@ struct GeminiLanguageModelTests {
#expect(response.content.contains("Alice"))
#expect(response.content.contains("30"))
}

/// Tests that request non-`String` `@Generable` types from `GeminiLanguageModel`
/// and check the decoded fields, for both one-shot and streaming responses.
///
/// NOTE(review): `model` force-unwraps `geminiAPIKey`; presumably the enclosing
/// suite is only enabled when the key is set — confirm.
@Suite("Structured Output")
struct StructuredOutputTests {
    /// A flat record with two required fields and one optional field.
    @Generable
    struct Person {
        @Guide(description: "The person's full name")
        var name: String

        @Guide(description: "The person's age in years")
        var age: Int

        @Guide(description: "The person's email address")
        var email: String?
    }

    /// A flat record with three required fields.
    @Generable
    struct Book {
        @Guide(description: "The book's title")
        var title: String

        @Guide(description: "The book's author")
        var author: String

        @Guide(description: "The publication year")
        var year: Int
    }

    /// The model under test; a new value is created on every access.
    private var model: GeminiLanguageModel {
        GeminiLanguageModel(apiKey: geminiAPIKey!, model: "gemini-2.5-flash")
    }

    /// Every field, including the optional one, is populated when the prompt supplies it.
    @Test func basicStructuredOutput() async throws {
        let session = LanguageModelSession(model: model)
        let response = try await session.respond(
            to: "Generate a person named John Doe, age 30, email john@example.com",
            generating: Person.self
        )

        #expect(!response.content.name.isEmpty)
        #expect(response.content.name.contains("John") || response.content.name.contains("Doe"))
        #expect(response.content.age > 0)
        #expect(response.content.age <= 100)
        #expect(response.content.email != nil)
    }

    /// The optional field may come back as `nil` or as an empty string.
    @Test func structuredOutputWithOptionalField() async throws {
        let session = LanguageModelSession(model: model)
        let response = try await session.respond(
            to: "Generate a person named Jane Smith, age 25, with no email",
            generating: Person.self
        )

        #expect(!response.content.name.isEmpty)
        #expect(response.content.name.contains("Jane") || response.content.name.contains("Smith"))
        #expect(response.content.age > 0)
        #expect(response.content.age <= 100)
        #expect(response.content.email == nil || response.content.email?.isEmpty == true)
    }

    /// Generates a second `Generable` type.
    ///
    /// NOTE(review): despite the name, `Book` contains no nested `Generable` types.
    @Test func structuredOutputWithNestedTypes() async throws {
        let session = LanguageModelSession(model: model)
        let response = try await session.respond(
            to: "Generate a book titled 'The Swift Programming Language' by 'Apple Inc.' published in 2024",
            generating: Book.self
        )

        #expect(!response.content.title.isEmpty)
        #expect(!response.content.author.isEmpty)
        #expect(response.content.year >= 2020)
    }

    /// The last streamed snapshot carries a non-empty name and a positive age.
    @Test func streamingStructuredOutput() async throws {
        let session = LanguageModelSession(model: model)
        let stream = session.streamResponse(
            to: "Generate a person named Alice, age 28, email alice@example.com",
            generating: Person.self
        )

        var snapshots: [LanguageModelSession.ResponseStream<Person>.Snapshot] = []
        for try await snapshot in stream {
            snapshots.append(snapshot)
        }

        // `#expect` does not stop the test, so a following `snapshots.last!` would
        // trap the whole test process on an empty stream. `#require` records the
        // failure and ends only this test.
        let finalSnapshot = try #require(snapshots.last)
        #expect((finalSnapshot.content.name?.isEmpty ?? true) == false)
        #expect((finalSnapshot.content.age ?? 0) > 0)
    }
}
}