Skip to content

Commit 10f95e6

Browse files
authored
[mypyc] Add faster primitive for string equality (#19402)
This speeds up self check by ~1.4%. String equality is one of the top five most common primitive function calls in self check. We previously used a string comparison primitive that calculated the relative order of two strings. Usually we only care about equality, which we can do quicker since we can fast path using a length check, for example. I checked the CPython implementation of string equality in 3.9 (lowest supported Python version) and 3.13, and both of them had a fast path based on string object kind, and equality checks overall had the same semantics. Current CPython implementation: https://github.com/python/cpython/blob/main/Objects/stringlib/eq.h Tests for this were added in #19401.
1 parent 4a427e9 commit 10f95e6

File tree

7 files changed

+78
-106
lines changed

7 files changed

+78
-106
lines changed

mypyc/irbuild/ll_builder.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,12 @@
175175
unary_ops,
176176
)
177177
from mypyc.primitives.set_ops import new_set_op
178-
from mypyc.primitives.str_ops import str_check_if_true, str_ssize_t_size_op, unicode_compare
178+
from mypyc.primitives.str_ops import (
179+
str_check_if_true,
180+
str_eq,
181+
str_ssize_t_size_op,
182+
unicode_compare,
183+
)
179184
from mypyc.primitives.tuple_ops import list_tuple_op, new_tuple_op, new_tuple_with_length_op
180185
from mypyc.rt_subtype import is_runtime_subtype
181186
from mypyc.sametype import is_same_type
@@ -1471,6 +1476,11 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -
14711476

14721477
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
14731478
"""Compare two strings"""
1479+
if op == "==":
1480+
return self.primitive_op(str_eq, [lhs, rhs], line)
1481+
elif op == "!=":
1482+
eq = self.primitive_op(str_eq, [lhs, rhs], line)
1483+
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))
14741484
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
14751485
error_constant = Integer(-1, c_int_rprimitive, line)
14761486
compare_error_check = self.add(

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
726726
#define RIGHTSTRIP 1
727727
#define BOTHSTRIP 2
728728

729+
char CPyStr_Equal(PyObject *str1, PyObject *str2);
729730
PyObject *CPyStr_Build(Py_ssize_t len, ...);
730731
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
731732
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);

mypyc/lib-rt/str_ops.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
6464
#undef BLOOM_UPDATE
6565
}
6666

67+
// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
68+
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
69+
if (str1 == str2) {
70+
return 1;
71+
}
72+
Py_ssize_t len = PyUnicode_GET_LENGTH(str1);
73+
if (PyUnicode_GET_LENGTH(str2) != len)
74+
return 0;
75+
int kind = PyUnicode_KIND(str1);
76+
if (PyUnicode_KIND(str2) != kind)
77+
return 0;
78+
const void *data1 = PyUnicode_DATA(str1);
79+
const void *data2 = PyUnicode_DATA(str2);
80+
return memcmp(data1, data2, len * kind) == 0;
81+
}
82+
6783
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
6884
if (PyUnicode_READY(str) != -1) {
6985
if (CPyTagged_CheckShort(index)) {

mypyc/primitives/str_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ERR_NEG_INT,
2222
binary_op,
2323
custom_op,
24+
custom_primitive_op,
2425
function_op,
2526
load_address_op,
2627
method_op,
@@ -69,6 +70,15 @@
6970
steals=[True, False],
7071
)
7172

73+
# str1 == str2 (very common operation, so we provide our own)
74+
str_eq = custom_primitive_op(
75+
name="str_eq",
76+
c_function_name="CPyStr_Equal",
77+
arg_types=[str_rprimitive, str_rprimitive],
78+
return_type=bool_rprimitive,
79+
error_kind=ERR_NEVER,
80+
)
81+
7282
unicode_compare = custom_op(
7383
arg_types=[str_rprimitive, str_rprimitive],
7484
return_type=c_int_rprimitive,

mypyc/test-data/irbuild-dict.test

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -399,12 +399,9 @@ def typeddict(d):
399399
r9, k :: str
400400
v :: object
401401
r10 :: str
402-
r11 :: i32
403-
r12 :: bit
404-
r13 :: object
405-
r14, r15, r16 :: bit
402+
r11 :: bool
406403
name :: object
407-
r17, r18 :: bit
404+
r12, r13 :: bit
408405
L0:
409406
r0 = 0
410407
r1 = PyDict_Size(d)
@@ -415,35 +412,25 @@ L1:
415412
r5 = r4[1]
416413
r0 = r5
417414
r6 = r4[0]
418-
if r6 goto L2 else goto L9 :: bool
415+
if r6 goto L2 else goto L6 :: bool
419416
L2:
420417
r7 = r4[2]
421418
r8 = r4[3]
422419
r9 = cast(str, r7)
423420
k = r9
424421
v = r8
425422
r10 = 'name'
426-
r11 = PyUnicode_Compare(k, r10)
427-
r12 = r11 == -1
428-
if r12 goto L3 else goto L5 :: bool
423+
r11 = CPyStr_Equal(k, r10)
424+
if r11 goto L3 else goto L4 :: bool
429425
L3:
430-
r13 = PyErr_Occurred()
431-
r14 = r13 != 0
432-
if r14 goto L4 else goto L5 :: bool
426+
name = v
433427
L4:
434-
r15 = CPy_KeepPropagating()
435428
L5:
436-
r16 = r11 == 0
437-
if r16 goto L6 else goto L7 :: bool
429+
r12 = CPyDict_CheckSize(d, r2)
430+
goto L1
438431
L6:
439-
name = v
432+
r13 = CPy_NoErrOccurred()
440433
L7:
441-
L8:
442-
r17 = CPyDict_CheckSize(d, r2)
443-
goto L1
444-
L9:
445-
r18 = CPy_NoErrOccurred()
446-
L10:
447434
return 1
448435

449436
[case testDictLoadAddress]

mypyc/test-data/irbuild-str.test

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -65,42 +65,18 @@ def neq(x: str, y: str) -> bool:
6565
[out]
6666
def eq(x, y):
6767
x, y :: str
68-
r0 :: i32
69-
r1 :: bit
70-
r2 :: object
71-
r3, r4, r5 :: bit
68+
r0 :: bool
7269
L0:
73-
r0 = PyUnicode_Compare(x, y)
74-
r1 = r0 == -1
75-
if r1 goto L1 else goto L3 :: bool
76-
L1:
77-
r2 = PyErr_Occurred()
78-
r3 = r2 != 0
79-
if r3 goto L2 else goto L3 :: bool
80-
L2:
81-
r4 = CPy_KeepPropagating()
82-
L3:
83-
r5 = r0 == 0
84-
return r5
70+
r0 = CPyStr_Equal(x, y)
71+
return r0
8572
def neq(x, y):
8673
x, y :: str
87-
r0 :: i32
74+
r0 :: bool
8875
r1 :: bit
89-
r2 :: object
90-
r3, r4, r5 :: bit
9176
L0:
92-
r0 = PyUnicode_Compare(x, y)
93-
r1 = r0 == -1
94-
if r1 goto L1 else goto L3 :: bool
95-
L1:
96-
r2 = PyErr_Occurred()
97-
r3 = r2 != 0
98-
if r3 goto L2 else goto L3 :: bool
99-
L2:
100-
r4 = CPy_KeepPropagating()
101-
L3:
102-
r5 = r0 != 0
103-
return r5
77+
r0 = CPyStr_Equal(x, y)
78+
r1 = r0 == 0
79+
return r1
10480

10581
[case testStrReplace]
10682
from typing import Optional

mypyc/test-data/irbuild-unreachable.test

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,41 +11,27 @@ def f():
1111
r1 :: str
1212
r2 :: object
1313
r3, r4 :: str
14-
r5 :: i32
15-
r6 :: bit
16-
r7 :: object
17-
r8, r9, r10 :: bit
18-
r11, r12 :: bool
19-
r13 :: object
20-
r14, y :: bool
14+
r5, r6, r7 :: bool
15+
r8 :: object
16+
r9, y :: bool
2117
L0:
2218
r0 = sys :: module
2319
r1 = 'platform'
2420
r2 = CPyObject_GetAttr(r0, r1)
2521
r3 = cast(str, r2)
2622
r4 = 'x'
27-
r5 = PyUnicode_Compare(r3, r4)
28-
r6 = r5 == -1
29-
if r6 goto L1 else goto L3 :: bool
23+
r5 = CPyStr_Equal(r3, r4)
24+
if r5 goto L2 else goto L1 :: bool
3025
L1:
31-
r7 = PyErr_Occurred()
32-
r8 = r7 != 0
33-
if r8 goto L2 else goto L3 :: bool
26+
r6 = r5
27+
goto L3
3428
L2:
35-
r9 = CPy_KeepPropagating()
29+
r7 = raise RuntimeError('mypyc internal error: should be unreachable')
30+
r8 = box(None, 1)
31+
r9 = unbox(bool, r8)
32+
r6 = r9
3633
L3:
37-
r10 = r5 == 0
38-
if r10 goto L5 else goto L4 :: bool
39-
L4:
40-
r11 = r10
41-
goto L6
42-
L5:
43-
r12 = raise RuntimeError('mypyc internal error: should be unreachable')
44-
r13 = box(None, 1)
45-
r14 = unbox(bool, r13)
46-
r11 = r14
47-
L6:
48-
y = r11
34+
y = r6
4935
return 1
5036

5137
[case testUnreachableNameExpr]
@@ -59,41 +45,27 @@ def f():
5945
r1 :: str
6046
r2 :: object
6147
r3, r4 :: str
62-
r5 :: i32
63-
r6 :: bit
64-
r7 :: object
65-
r8, r9, r10 :: bit
66-
r11, r12 :: bool
67-
r13 :: object
68-
r14, y :: bool
48+
r5, r6, r7 :: bool
49+
r8 :: object
50+
r9, y :: bool
6951
L0:
7052
r0 = sys :: module
7153
r1 = 'platform'
7254
r2 = CPyObject_GetAttr(r0, r1)
7355
r3 = cast(str, r2)
7456
r4 = 'x'
75-
r5 = PyUnicode_Compare(r3, r4)
76-
r6 = r5 == -1
77-
if r6 goto L1 else goto L3 :: bool
57+
r5 = CPyStr_Equal(r3, r4)
58+
if r5 goto L2 else goto L1 :: bool
7859
L1:
79-
r7 = PyErr_Occurred()
80-
r8 = r7 != 0
81-
if r8 goto L2 else goto L3 :: bool
60+
r6 = r5
61+
goto L3
8262
L2:
83-
r9 = CPy_KeepPropagating()
63+
r7 = raise RuntimeError('mypyc internal error: should be unreachable')
64+
r8 = box(None, 1)
65+
r9 = unbox(bool, r8)
66+
r6 = r9
8467
L3:
85-
r10 = r5 == 0
86-
if r10 goto L5 else goto L4 :: bool
87-
L4:
88-
r11 = r10
89-
goto L6
90-
L5:
91-
r12 = raise RuntimeError('mypyc internal error: should be unreachable')
92-
r13 = box(None, 1)
93-
r14 = unbox(bool, r13)
94-
r11 = r14
95-
L6:
96-
y = r11
68+
y = r6
9769
return 1
9870

9971
[case testUnreachableStatementAfterReturn]

0 commit comments

Comments
 (0)