diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index be4178a4a71a..e25079c1146b 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -175,7 +175,12 @@ unary_ops, ) from mypyc.primitives.set_ops import new_set_op -from mypyc.primitives.str_ops import str_check_if_true, str_ssize_t_size_op, unicode_compare +from mypyc.primitives.str_ops import ( + str_check_if_true, + str_eq, + str_ssize_t_size_op, + unicode_compare, +) from mypyc.primitives.tuple_ops import list_tuple_op, new_tuple_op, new_tuple_with_length_op from mypyc.rt_subtype import is_runtime_subtype from mypyc.sametype import is_same_type @@ -1471,6 +1476,11 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) - def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: """Compare two strings""" + if op == "==": + return self.primitive_op(str_eq, [lhs, rhs], line) + elif op == "!=": + eq = self.primitive_op(str_eq, [lhs, rhs], line) + return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line)) compare_result = self.call_c(unicode_compare, [lhs, rhs], line) error_constant = Integer(-1, c_int_rprimitive, line) compare_error_check = self.add( diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index bdf3e0130a4c..a0f1b06cc0d5 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -726,6 +726,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { #define RIGHTSTRIP 1 #define BOTHSTRIP 2 +char CPyStr_Equal(PyObject *str1, PyObject *str2); PyObject *CPyStr_Build(Py_ssize_t len, ...); PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 210172c57497..5fd376f21cfa 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -64,6 +64,22 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len) #undef BLOOM_UPDATE } +// Adapted from CPython 3.13.1 (_PyUnicode_Equal) +char CPyStr_Equal(PyObject *str1, PyObject *str2) { + if (str1 == str2) { + return 1; + } + Py_ssize_t len = PyUnicode_GET_LENGTH(str1); + if (PyUnicode_GET_LENGTH(str2) != len) + return 0; + int kind = PyUnicode_KIND(str1); + if (PyUnicode_KIND(str2) != kind) + return 0; + const void *data1 = PyUnicode_DATA(str1); + const void *data2 = PyUnicode_DATA(str2); + return memcmp(data1, data2, len * kind) == 0; +} + PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) { if (PyUnicode_READY(str) != -1) { if (CPyTagged_CheckShort(index)) { diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 9d46da9c3514..37dbdf21bb5d 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -21,6 +21,7 @@ ERR_NEG_INT, binary_op, custom_op, + custom_primitive_op, function_op, load_address_op, method_op, @@ -69,6 +70,15 @@ steals=[True, False], ) +# str1 == str2 (very common operation, so we provide our own) +str_eq = custom_primitive_op( + name="str_eq", + c_function_name="CPyStr_Equal", + arg_types=[str_rprimitive, str_rprimitive], + return_type=bool_rprimitive, + error_kind=ERR_NEVER, +) + unicode_compare = custom_op( arg_types=[str_rprimitive, str_rprimitive], return_type=c_int_rprimitive, diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index a71f5aa2d8a2..cacb14dae273 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -399,12 +399,9 @@ def typeddict(d): r9, k :: str v :: object r10 :: str - r11 :: i32 - r12 :: bit - r13 :: object - r14, r15, r16 :: bit + r11 :: bool name :: object - r17, r18 :: bit + r12, r13 :: bit L0: r0 = 0 r1 = PyDict_Size(d) @@ -415,7 +412,7 @@ L1: r5 = r4[1] r0 = r5 r6 = r4[0] - if r6 goto L2 else goto L9 :: bool + if r6 goto L2 else goto L6 :: bool L2: r7 = r4[2] r8 = r4[3] @@ -423,27 +420,17 @@ L2: k = r9 v = r8 r10 = 'name' - r11 = PyUnicode_Compare(k, r10) - r12 = r11 == -1 - if r12 goto L3 else goto L5 :: bool + r11 = CPyStr_Equal(k, r10) + if r11 goto L3 else goto L4 :: bool L3: - r13 = PyErr_Occurred() - r14 = r13 != 0 - if r14 goto L4 else goto L5 :: bool + name = v L4: - r15 = CPy_KeepPropagating() L5: - r16 = r11 == 0 - if r16 goto L6 else goto L7 :: bool + r12 = CPyDict_CheckSize(d, r2) + goto L1 L6: - name = v + r13 = CPy_NoErrOccurred() L7: -L8: - r17 = CPyDict_CheckSize(d, r2) - goto L1 -L9: - r18 = CPy_NoErrOccurred() -L10: return 1 [case testDictLoadAddress] diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 2bf77a6cb556..4a4992d41a5d 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -65,42 +65,18 @@ def neq(x: str, y: str) -> bool: [out] def eq(x, y): x, y :: str - r0 :: i32 - r1 :: bit - r2 :: object - r3, r4, r5 :: bit + r0 :: bool L0: - r0 = PyUnicode_Compare(x, y) - r1 = r0 == -1 - if r1 goto L1 else goto L3 :: bool -L1: - r2 = PyErr_Occurred() - r3 = r2 != 0 - if r3 goto L2 else goto L3 :: bool -L2: - r4 = CPy_KeepPropagating() -L3: - r5 = r0 == 0 - return r5 + r0 = CPyStr_Equal(x, y) + return r0 def neq(x, y): x, y :: str - r0 :: i32 + r0 :: bool r1 :: bit - r2 :: object - r3, r4, r5 :: bit L0: - r0 = PyUnicode_Compare(x, y) - r1 = r0 == -1 - if r1 goto L1 else goto L3 :: bool -L1: - r2 = PyErr_Occurred() - r3 = r2 != 0 - if r3 goto L2 else goto L3 :: bool -L2: - r4 = CPy_KeepPropagating() -L3: - r5 = r0 != 0 - return r5 + r0 = CPyStr_Equal(x, y) + r1 = r0 == 0 + return r1 [case testStrReplace] from typing import Optional diff --git a/mypyc/test-data/irbuild-unreachable.test b/mypyc/test-data/irbuild-unreachable.test index cebd4582923b..a4f1ef8c7dba 100644 --- a/mypyc/test-data/irbuild-unreachable.test +++ b/mypyc/test-data/irbuild-unreachable.test @@ -11,41 +11,27 @@ def f(): r1 :: str r2 :: object r3, r4 :: str - r5 :: i32 - r6 :: bit - r7 :: object - r8, r9, r10 :: bit - r11, r12 :: bool - r13 :: object - r14, y :: bool + r5, r6, r7 :: bool + r8 :: object + r9, y :: bool L0: r0 = sys :: module r1 = 'platform' r2 = CPyObject_GetAttr(r0, r1) r3 = cast(str, r2) r4 = 'x' - r5 = PyUnicode_Compare(r3, r4) - r6 = r5 == -1 - if r6 goto L1 else goto L3 :: bool + r5 = CPyStr_Equal(r3, r4) + if r5 goto L2 else goto L1 :: bool L1: - r7 = PyErr_Occurred() - r8 = r7 != 0 - if r8 goto L2 else goto L3 :: bool + r6 = r5 + goto L3 L2: - r9 = CPy_KeepPropagating() + r7 = raise RuntimeError('mypyc internal error: should be unreachable') + r8 = box(None, 1) + r9 = unbox(bool, r8) + r6 = r9 L3: - r10 = r5 == 0 - if r10 goto L5 else goto L4 :: bool -L4: - r11 = r10 - goto L6 -L5: - r12 = raise RuntimeError('mypyc internal error: should be unreachable') - r13 = box(None, 1) - r14 = unbox(bool, r13) - r11 = r14 -L6: - y = r11 + y = r6 return 1 [case testUnreachableNameExpr] @@ -59,41 +45,27 @@ def f(): r1 :: str r2 :: object r3, r4 :: str - r5 :: i32 - r6 :: bit - r7 :: object - r8, r9, r10 :: bit - r11, r12 :: bool - r13 :: object - r14, y :: bool + r5, r6, r7 :: bool + r8 :: object + r9, y :: bool L0: r0 = sys :: module r1 = 'platform' r2 = CPyObject_GetAttr(r0, r1) r3 = cast(str, r2) r4 = 'x' - r5 = PyUnicode_Compare(r3, r4) - r6 = r5 == -1 - if r6 goto L1 else goto L3 :: bool + r5 = CPyStr_Equal(r3, r4) + if r5 goto L2 else goto L1 :: bool L1: - r7 = PyErr_Occurred() - r8 = r7 != 0 - if r8 goto L2 else goto L3 :: bool + r6 = r5 + goto L3 L2: - r9 = CPy_KeepPropagating() + r7 = raise RuntimeError('mypyc internal error: should be unreachable') + r8 = box(None, 1) + r9 = unbox(bool, r8) + r6 = r9 L3: - r10 = r5 == 0 - if r10 goto L5 else goto L4 :: bool -L4: - r11 = r10 - goto L6 -L5: - r12 = raise RuntimeError('mypyc internal error: should be unreachable') - r13 = box(None, 1) - r14 = unbox(bool, r13) - r11 = r14 -L6: - y = r11 + y = r6 return 1 [case testUnreachableStatementAfterReturn]