Skip to content

[mypyc] Add faster primitive for string equality #19402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,12 @@
unary_ops,
)
from mypyc.primitives.set_ops import new_set_op
from mypyc.primitives.str_ops import str_check_if_true, str_ssize_t_size_op, unicode_compare
from mypyc.primitives.str_ops import (
str_check_if_true,
str_eq,
str_ssize_t_size_op,
unicode_compare,
)
from mypyc.primitives.tuple_ops import list_tuple_op, new_tuple_op, new_tuple_with_length_op
from mypyc.rt_subtype import is_runtime_subtype
from mypyc.sametype import is_same_type
Expand Down Expand Up @@ -1471,6 +1476,11 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""
if op == "==":
return self.primitive_op(str_eq, [lhs, rhs], line)
elif op == "!=":
eq = self.primitive_op(str_eq, [lhs, rhs], line)
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
error_constant = Integer(-1, c_int_rprimitive, line)
compare_error_check = self.add(
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
#define RIGHTSTRIP 1
#define BOTHSTRIP 2

char CPyStr_Equal(PyObject *str1, PyObject *str2);
PyObject *CPyStr_Build(Py_ssize_t len, ...);
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
Expand Down
16 changes: 16 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,22 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
#undef BLOOM_UPDATE
}

// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
if (str1 == str2) {
return 1;
}
Py_ssize_t len = PyUnicode_GET_LENGTH(str1);
if (PyUnicode_GET_LENGTH(str2) != len)
return 0;
int kind = PyUnicode_KIND(str1);
if (PyUnicode_KIND(str2) != kind)
return 0;
const void *data1 = PyUnicode_DATA(str1);
const void *data2 = PyUnicode_DATA(str2);
return memcmp(data1, data2, len * kind) == 0;
}

PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
if (PyUnicode_READY(str) != -1) {
if (CPyTagged_CheckShort(index)) {
Expand Down
10 changes: 10 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ERR_NEG_INT,
binary_op,
custom_op,
custom_primitive_op,
function_op,
load_address_op,
method_op,
Expand Down Expand Up @@ -69,6 +70,15 @@
steals=[True, False],
)

# str1 == str2 (very common operation, so we provide our own)
str_eq = custom_primitive_op(
name="str_eq",
c_function_name="CPyStr_Equal",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

unicode_compare = custom_op(
arg_types=[str_rprimitive, str_rprimitive],
return_type=c_int_rprimitive,
Expand Down
31 changes: 9 additions & 22 deletions mypyc/test-data/irbuild-dict.test
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,9 @@ def typeddict(d):
r9, k :: str
v :: object
r10 :: str
r11 :: i32
r12 :: bit
r13 :: object
r14, r15, r16 :: bit
r11 :: bool
name :: object
r17, r18 :: bit
r12, r13 :: bit
L0:
r0 = 0
r1 = PyDict_Size(d)
Expand All @@ -415,35 +412,25 @@ L1:
r5 = r4[1]
r0 = r5
r6 = r4[0]
if r6 goto L2 else goto L9 :: bool
if r6 goto L2 else goto L6 :: bool
L2:
r7 = r4[2]
r8 = r4[3]
r9 = cast(str, r7)
k = r9
v = r8
r10 = 'name'
r11 = PyUnicode_Compare(k, r10)
r12 = r11 == -1
if r12 goto L3 else goto L5 :: bool
r11 = CPyStr_Equal(k, r10)
if r11 goto L3 else goto L4 :: bool
L3:
r13 = PyErr_Occurred()
r14 = r13 != 0
if r14 goto L4 else goto L5 :: bool
name = v
L4:
r15 = CPy_KeepPropagating()
L5:
r16 = r11 == 0
if r16 goto L6 else goto L7 :: bool
r12 = CPyDict_CheckSize(d, r2)
goto L1
L6:
name = v
r13 = CPy_NoErrOccurred()
L7:
L8:
r17 = CPyDict_CheckSize(d, r2)
goto L1
L9:
r18 = CPy_NoErrOccurred()
L10:
return 1

[case testDictLoadAddress]
Expand Down
38 changes: 7 additions & 31 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -65,42 +65,18 @@ def neq(x: str, y: str) -> bool:
[out]
def eq(x, y):
x, y :: str
r0 :: i32
r1 :: bit
r2 :: object
r3, r4, r5 :: bit
r0 :: bool
L0:
r0 = PyUnicode_Compare(x, y)
r1 = r0 == -1
if r1 goto L1 else goto L3 :: bool
L1:
r2 = PyErr_Occurred()
r3 = r2 != 0
if r3 goto L2 else goto L3 :: bool
L2:
r4 = CPy_KeepPropagating()
L3:
r5 = r0 == 0
return r5
r0 = CPyStr_Equal(x, y)
return r0
def neq(x, y):
x, y :: str
r0 :: i32
r0 :: bool
r1 :: bit
r2 :: object
r3, r4, r5 :: bit
L0:
r0 = PyUnicode_Compare(x, y)
r1 = r0 == -1
if r1 goto L1 else goto L3 :: bool
L1:
r2 = PyErr_Occurred()
r3 = r2 != 0
if r3 goto L2 else goto L3 :: bool
L2:
r4 = CPy_KeepPropagating()
L3:
r5 = r0 != 0
return r5
r0 = CPyStr_Equal(x, y)
r1 = r0 == 0
return r1

[case testStrReplace]
from typing import Optional
Expand Down
76 changes: 24 additions & 52 deletions mypyc/test-data/irbuild-unreachable.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,27 @@ def f():
r1 :: str
r2 :: object
r3, r4 :: str
r5 :: i32
r6 :: bit
r7 :: object
r8, r9, r10 :: bit
r11, r12 :: bool
r13 :: object
r14, y :: bool
r5, r6, r7 :: bool
r8 :: object
r9, y :: bool
L0:
r0 = sys :: module
r1 = 'platform'
r2 = CPyObject_GetAttr(r0, r1)
r3 = cast(str, r2)
r4 = 'x'
r5 = PyUnicode_Compare(r3, r4)
r6 = r5 == -1
if r6 goto L1 else goto L3 :: bool
r5 = CPyStr_Equal(r3, r4)
if r5 goto L2 else goto L1 :: bool
L1:
r7 = PyErr_Occurred()
r8 = r7 != 0
if r8 goto L2 else goto L3 :: bool
r6 = r5
goto L3
L2:
r9 = CPy_KeepPropagating()
r7 = raise RuntimeError('mypyc internal error: should be unreachable')
r8 = box(None, 1)
r9 = unbox(bool, r8)
r6 = r9
L3:
r10 = r5 == 0
if r10 goto L5 else goto L4 :: bool
L4:
r11 = r10
goto L6
L5:
r12 = raise RuntimeError('mypyc internal error: should be unreachable')
r13 = box(None, 1)
r14 = unbox(bool, r13)
r11 = r14
L6:
y = r11
y = r6
return 1

[case testUnreachableNameExpr]
Expand All @@ -59,41 +45,27 @@ def f():
r1 :: str
r2 :: object
r3, r4 :: str
r5 :: i32
r6 :: bit
r7 :: object
r8, r9, r10 :: bit
r11, r12 :: bool
r13 :: object
r14, y :: bool
r5, r6, r7 :: bool
r8 :: object
r9, y :: bool
L0:
r0 = sys :: module
r1 = 'platform'
r2 = CPyObject_GetAttr(r0, r1)
r3 = cast(str, r2)
r4 = 'x'
r5 = PyUnicode_Compare(r3, r4)
r6 = r5 == -1
if r6 goto L1 else goto L3 :: bool
r5 = CPyStr_Equal(r3, r4)
if r5 goto L2 else goto L1 :: bool
L1:
r7 = PyErr_Occurred()
r8 = r7 != 0
if r8 goto L2 else goto L3 :: bool
r6 = r5
goto L3
L2:
r9 = CPy_KeepPropagating()
r7 = raise RuntimeError('mypyc internal error: should be unreachable')
r8 = box(None, 1)
r9 = unbox(bool, r8)
r6 = r9
L3:
r10 = r5 == 0
if r10 goto L5 else goto L4 :: bool
L4:
r11 = r10
goto L6
L5:
r12 = raise RuntimeError('mypyc internal error: should be unreachable')
r13 = box(None, 1)
r14 = unbox(bool, r13)
r11 = r14
L6:
y = r11
y = r6
return 1

[case testUnreachableStatementAfterReturn]
Expand Down