Skip to content

Commit 30edf9f

Browse files
committed
Handle DeepSeek R1 Qwen chat template
1 parent e9822c2 commit 30edf9f

File tree

4 files changed

+76
-22
lines changed

4 files changed

+76
-22
lines changed

Sources/Parser.swift

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,22 @@ func parse(tokens: [Token]) throws -> Program {
5555
try StringLiteral(value: expect(type: .text, error: "Expected text token").value)
5656
}
5757

58-
func parseCallExpression(callee: Expression) throws -> CallExpression {
58+
func parseCallExpression(callee: Expression) throws -> Expression {
5959
let args = try parseArgs()
60-
var callExpression = CallExpression(callee: callee, args: args)
60+
var expression: Expression = CallExpression(callee: callee, args: args)
61+
// Handle potential array indexing after method call
62+
if typeof(.openSquareBracket) {
63+
expression = MemberExpression(
64+
object: expression,
65+
property: try parseMemberExpressionArgumentsList(),
66+
computed: true
67+
)
68+
}
69+
// Handle potential chained method calls
6170
if typeof(.openParen) {
62-
callExpression = try parseCallExpression(callee: callExpression)
71+
expression = try parseCallExpression(callee: expression)
6372
}
64-
return callExpression
73+
return expression
6574
}
6675

6776
func parseMemberExpressionArgumentsList() throws -> Expression {
@@ -73,7 +82,19 @@ func parse(tokens: [Token]) throws -> Program {
7382
current += 1 // consume colon
7483
isSlice = true
7584
} else {
76-
slices.append(try parseExpression())
85+
// Handle negative numbers as indices
86+
if typeof(.additiveBinaryOperator) && tokens[current].value == "-" {
87+
current += 1 // consume the minus sign
88+
if typeof(.numericLiteral) {
89+
let num = tokens[current].value
90+
current += 1
91+
slices.append(NumericLiteral(value: -Int(num)!))
92+
} else {
93+
throw JinjaError.syntax("Expected number after minus sign in array index")
94+
}
95+
} else {
96+
slices.append(try parseExpression())
97+
}
7798
if typeof(.colon) {
7899
current += 1 // consume colon
79100
isSlice = true
@@ -111,6 +132,23 @@ func parse(tokens: [Token]) throws -> Program {
111132
if !(property is Identifier) {
112133
throw JinjaError.syntax("Expected identifier following dot operator")
113134
}
135+
// Handle method calls
136+
if typeof(.openParen) {
137+
let methodCall = CallExpression(
138+
callee: MemberExpression(object: object, property: property, computed: false),
139+
args: try parseArgs()
140+
)
141+
// Handle array indexing after method call
142+
if typeof(.openSquareBracket) {
143+
current += 1 // consume [
144+
let index = try parseExpression()
145+
try expect(type: .closeSquareBracket, error: "Expected closing square bracket")
146+
object = MemberExpression(object: methodCall, property: index, computed: true)
147+
continue
148+
}
149+
object = methodCall
150+
continue
151+
}
114152
}
115153
object = MemberExpression(
116154
object: object,
@@ -364,12 +402,14 @@ func parse(tokens: [Token]) throws -> Program {
364402
func parseSetStatement() throws -> Statement {
365403
let left = try parseExpression()
366404
if typeof(.equals) {
367-
current += 1
405+
current += 1 // consume equals
368406
// Parse the right-hand side as an expression
369407
let value = try parseExpression()
370-
// Explicitly cast 'value' to 'Expression'
408+
try expect(type: .closeStatement, error: "Expected closing statement token")
371409
return Set(assignee: left, value: value)
372410
}
411+
// If there's no equals sign, treat it as an expression statement
412+
try expect(type: .closeStatement, error: "Expected closing statement token")
373413
return left
374414
}
375415

@@ -552,11 +592,11 @@ func parse(tokens: [Token]) throws -> Program {
552592
// Consume {% %} tokens
553593
try expect(type: .openStatement, error: "Expected opening statement token")
554594
var result: Statement
595+
555596
switch tokens[current].type {
556597
case .set:
557598
current += 1 // consume 'set' token
558599
result = try parseSetStatement()
559-
try expect(type: .closeStatement, error: "Expected closing statement token")
560600
case .if:
561601
current += 1 // consume 'if' token
562602
result = try parseIfStatement()
@@ -576,8 +616,11 @@ func parse(tokens: [Token]) throws -> Program {
576616
try expect(type: .endFor, error: "Expected endfor token")
577617
try expect(type: .closeStatement, error: "Expected %} token")
578618
default:
579-
throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)")
619+
// Handle expressions within statements
620+
result = try parseExpression()
621+
try expect(type: .closeStatement, error: "Expected closing statement token")
580622
}
623+
581624
return result
582625
}
583626

Sources/Runtime.swift

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,10 +850,9 @@ struct Interpreter {
850850
} else if let object = object as? ArrayValue {
851851
if let property = property as? NumericValue {
852852
if let index = property.value as? Int {
853-
if index >= 0 && index < object.value.count {
854-
value = object.value[index]
855-
} else if index < 0 && index >= -object.value.count {
856-
value = object.value[object.value.count + index]
853+
let actualIndex = index < 0 ? object.value.count + index : index
854+
if actualIndex >= 0 && actualIndex < object.value.count {
855+
value = object.value[actualIndex]
857856
} else {
858857
value = UndefinedValue()
859858
}

Tests/FilterTests.swift

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -812,13 +812,7 @@ final class FilterTests: XCTestCase {
812812
try runTest(filterName: "max", input: ["b", "a", "d", "c"], expected: "d")
813813
try runTest(filterName: "max", input: [], expected: UndefinedValue())
814814

815-
// Test pprint
816-
try runTest(
817-
filterName: "pprint",
818-
input: [1, 2, 3],
819-
expected: "\(ArrayValue(value: [NumericValue(value: 1), NumericValue(value: 2), NumericValue(value: 3)]))"
820-
)
821-
try runTest(filterName: "pprint", input: "a", expected: "\(StringValue(value: "a"))")
815+
// TODO: Figure out how to test "pprint", given that Swift 5.10 doesn't preserve the key order in dictionaries
822816

823817
// TODO: Figure out how to test "random" filter
824818

Tests/Templates/ChatTemplateTests.swift

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,11 +594,29 @@ final class ChatTemplateTests: XCTestCase {
594594
"bos_token": "<|begin_of_text|>",
595595
"add_generation_prompt": true,
596596
])
597-
print("::: result:")
598-
print(result)
599597
let target = """
600598
<|im_start|>user<|im_sep|>What is the weather in Paris today?<|im_end|><|im_start|>assistant<|im_sep|>
601599
"""
602600
XCTAssertEqual(result, target)
603601
}
602+
603+
func testDeepSeekQwen() throws {
604+
let userMessage = [
605+
"role": "user",
606+
"content": "What is the weather in Paris today?",
607+
]
608+
let chatTemplate = """
609+
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
610+
"""
611+
let template = try Template(chatTemplate)
612+
let result = try template.render([
613+
"messages": [userMessage],
614+
"bos_token": "<|begin_of_text|>",
615+
"add_generation_prompt": true,
616+
])
617+
let target = """
618+
<|begin_of_text|><|User|>What is the weather in Paris today?<|Assistant|>
619+
"""
620+
XCTAssertEqual(result, target)
621+
}
604622
}

0 commit comments

Comments
 (0)