Skip to content

Support vision models and function calling #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions Sources/Parser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,22 @@ func parse(tokens: [Token]) throws -> Program {
try StringLiteral(value: expect(type: .text, error: "Expected text token").value)
}

func parseCallExpression(callee: Expression) throws -> CallExpression {
func parseCallExpression(callee: Expression) throws -> Expression {
let args = try parseArgs()
var callExpression = CallExpression(callee: callee, args: args)
var expression: Expression = CallExpression(callee: callee, args: args)
// Handle potential array indexing after method call
if typeof(.openSquareBracket) {
expression = MemberExpression(
object: expression,
property: try parseMemberExpressionArgumentsList(),
computed: true
)
}
// Handle potential chained method calls
if typeof(.openParen) {
callExpression = try parseCallExpression(callee: callExpression)
expression = try parseCallExpression(callee: expression)
}
return callExpression
return expression
}

func parseMemberExpressionArgumentsList() throws -> Expression {
Expand All @@ -73,7 +82,19 @@ func parse(tokens: [Token]) throws -> Program {
current += 1 // consume colon
isSlice = true
} else {
slices.append(try parseExpression())
// Handle negative numbers as indices
if typeof(.additiveBinaryOperator) && tokens[current].value == "-" {
current += 1 // consume the minus sign
if typeof(.numericLiteral) {
let num = tokens[current].value
current += 1
slices.append(NumericLiteral(value: -Int(num)!))
} else {
throw JinjaError.syntax("Expected number after minus sign in array index")
}
} else {
slices.append(try parseExpression())
}
if typeof(.colon) {
current += 1 // consume colon
isSlice = true
Expand Down Expand Up @@ -111,6 +132,23 @@ func parse(tokens: [Token]) throws -> Program {
if !(property is Identifier) {
throw JinjaError.syntax("Expected identifier following dot operator")
}
// Handle method calls
if typeof(.openParen) {
let methodCall = CallExpression(
callee: MemberExpression(object: object, property: property, computed: false),
args: try parseArgs()
)
// Handle array indexing after method call
if typeof(.openSquareBracket) {
current += 1 // consume [
let index = try parseExpression()
try expect(type: .closeSquareBracket, error: "Expected closing square bracket")
object = MemberExpression(object: methodCall, property: index, computed: true)
continue
}
object = methodCall
continue
}
}
object = MemberExpression(
object: object,
Expand Down Expand Up @@ -364,12 +402,14 @@ func parse(tokens: [Token]) throws -> Program {
func parseSetStatement() throws -> Statement {
let left = try parseExpression()
if typeof(.equals) {
current += 1
current += 1 // consume equals
// Parse the right-hand side as an expression
let value = try parseExpression()
// Explicitly cast 'value' to 'Expression'
try expect(type: .closeStatement, error: "Expected closing statement token")
return Set(assignee: left, value: value)
}
// If there's no equals sign, treat it as an expression statement
try expect(type: .closeStatement, error: "Expected closing statement token")
return left
}

Expand Down Expand Up @@ -552,11 +592,11 @@ func parse(tokens: [Token]) throws -> Program {
// Consume {% %} tokens
try expect(type: .openStatement, error: "Expected opening statement token")
var result: Statement

switch tokens[current].type {
case .set:
current += 1 // consume 'set' token
result = try parseSetStatement()
try expect(type: .closeStatement, error: "Expected closing statement token")
case .if:
current += 1 // consume 'if' token
result = try parseIfStatement()
Expand All @@ -576,8 +616,11 @@ func parse(tokens: [Token]) throws -> Program {
try expect(type: .endFor, error: "Expected endfor token")
try expect(type: .closeStatement, error: "Expected %} token")
default:
throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)")
// Handle expressions within statements
result = try parseExpression()
try expect(type: .closeStatement, error: "Expected closing statement token")
}

return result
}

Expand Down
7 changes: 3 additions & 4 deletions Sources/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -850,10 +850,9 @@ struct Interpreter {
} else if let object = object as? ArrayValue {
if let property = property as? NumericValue {
if let index = property.value as? Int {
if index >= 0 && index < object.value.count {
value = object.value[index]
} else if index < 0 && index >= -object.value.count {
value = object.value[object.value.count + index]
let actualIndex = index < 0 ? object.value.count + index : index
if actualIndex >= 0 && actualIndex < object.value.count {
value = object.value[actualIndex]
} else {
value = UndefinedValue()
}
Expand Down
8 changes: 1 addition & 7 deletions Tests/FilterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -812,13 +812,7 @@ final class FilterTests: XCTestCase {
try runTest(filterName: "max", input: ["b", "a", "d", "c"], expected: "d")
try runTest(filterName: "max", input: [], expected: UndefinedValue())

// Test pprint
try runTest(
filterName: "pprint",
input: [1, 2, 3],
expected: "\(ArrayValue(value: [NumericValue(value: 1), NumericValue(value: 2), NumericValue(value: 3)]))"
)
try runTest(filterName: "pprint", input: "a", expected: "\(StringValue(value: "a"))")
// TODO: Figure out how to test "pprint", given that Swift 5.10 doesn't preserve the key order in dictionaries

// TODO: Figure out how to test "random" filter

Expand Down
22 changes: 20 additions & 2 deletions Tests/Templates/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -594,11 +594,29 @@ final class ChatTemplateTests: XCTestCase {
"bos_token": "<|begin_of_text|>",
"add_generation_prompt": true,
])
print("::: result:")
print(result)
let target = """
<|im_start|>user<|im_sep|>What is the weather in Paris today?<|im_end|><|im_start|>assistant<|im_sep|>
"""
XCTAssertEqual(result, target)
}

func testDeepSeekQwen() throws {
let userMessage = [
"role": "user",
"content": "What is the weather in Paris today?",
]
let chatTemplate = """
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
"""
let template = try Template(chatTemplate)
let result = try template.render([
"messages": [userMessage],
"bos_token": "<|begin_of_text|>",
"add_generation_prompt": true,
])
let target = """
<|begin_of_text|><|User|>What is the weather in Paris today?<|Assistant|>
"""
XCTAssertEqual(result, target)
}
}