Skip to content

Commit 3df3d79

Browse files
feat: new mypyc primitives for str.count (#19264)
This PR adds new mypyc primitives for all variations of `str.count`. Fixes mypyc/mypyc#1096.
1 parent 5b0ac32 commit 3df3d79

File tree

5 files changed

+203
-0
lines changed

5 files changed

+203
-0
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,8 @@ bool CPyStr_IsTrue(PyObject *obj);
753753
Py_ssize_t CPyStr_Size_size_t(PyObject *str);
754754
PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors);
755755
PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
756+
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start);
757+
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end);
756758
CPyTagged CPyStr_Ord(PyObject *obj);
757759

758760

mypyc/lib-rt/str_ops.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,30 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors) {
511511
}
512512
}
513513

514+
// Count occurrences of `substring` in `unicode`, searching from index `start`
// up to the length of the string.
//
// `start` arrives as a tagged int and is first converted to Py_ssize_t. If the
// conversion returns -1 with a Python error pending, an OverflowError carrying
// CPYTHON_LARGE_INT_ERRMSG is set and -1 is returned. Otherwise the result of
// PyUnicode_Count() is returned unchanged.
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start) {
515+
Py_ssize_t temp_start = CPyTagged_AsSsize_t(start);
516+
// -1 alone is a legitimate start index; it is only an error when an exception is pending.
if (temp_start == -1 && PyErr_Occurred()) {
517+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
518+
return -1;
519+
}
520+
// No explicit end was supplied, so use the full length of the string as the end bound.
Py_ssize_t end = PyUnicode_GET_LENGTH(unicode);
521+
return PyUnicode_Count(unicode, substring, temp_start, end);
522+
}
523+
524+
// Count occurrences of `substring` in `unicode` between the indices `start`
// and `end`.
//
// Both bounds arrive as tagged ints and are converted to Py_ssize_t. If either
// conversion returns -1 with a Python error pending, an OverflowError carrying
// CPYTHON_LARGE_INT_ERRMSG is set and -1 is returned. Otherwise the result of
// PyUnicode_Count() is returned unchanged.
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end) {
525+
Py_ssize_t temp_start = CPyTagged_AsSsize_t(start);
526+
// -1 alone is a legitimate index; it is only an error when an exception is pending.
if (temp_start == -1 && PyErr_Occurred()) {
527+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
528+
return -1;
529+
}
530+
Py_ssize_t temp_end = CPyTagged_AsSsize_t(end);
531+
// Same check for the end bound.
if (temp_end == -1 && PyErr_Occurred()) {
532+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
533+
return -1;
534+
}
535+
return PyUnicode_Count(unicode, substring, temp_start, temp_end);
536+
}
537+
514538

515539
CPyTagged CPyStr_Ord(PyObject *obj) {
516540
Py_ssize_t s = PyUnicode_GET_LENGTH(obj);

mypyc/primitives/str_ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,34 @@
277277
error_kind=ERR_MAGIC,
278278
)
279279

280+
# str.count(substring)
# Maps to CPyStr_Count. The constant 0 from extra_int_constants is appended as
# the trailing C argument (the start index), so the generated call is
# CPyStr_Count(s, sub, 0). With ERR_NEG_INT the generated IR checks the
# returned native int with `>= 0`.
281+
method_op(
282+
name="count",
283+
arg_types=[str_rprimitive, str_rprimitive],
284+
return_type=c_pyssize_t_rprimitive,
285+
c_function_name="CPyStr_Count",
286+
error_kind=ERR_NEG_INT,
287+
extra_int_constants=[(0, c_pyssize_t_rprimitive)],
288+
)
289+
290+
# str.count(substring, start)
# Maps to the same C helper as the one-argument form, CPyStr_Count, but the
# start index comes from the caller instead of a constant. With ERR_NEG_INT
# the generated IR checks the returned native int with `>= 0`.
291+
method_op(
292+
name="count",
293+
arg_types=[str_rprimitive, str_rprimitive, int_rprimitive],
294+
return_type=c_pyssize_t_rprimitive,
295+
c_function_name="CPyStr_Count",
296+
error_kind=ERR_NEG_INT,
297+
)
298+
299+
# str.count(substring, start, end)
# Maps to CPyStr_CountFull, which takes both bounds from the caller. With
# ERR_NEG_INT the generated IR checks the returned native int with `>= 0`.
300+
method_op(
301+
name="count",
302+
arg_types=[str_rprimitive, str_rprimitive, int_rprimitive, int_rprimitive],
303+
return_type=c_pyssize_t_rprimitive,
304+
c_function_name="CPyStr_CountFull",
305+
error_kind=ERR_NEG_INT,
306+
)
307+
280308
# str.replace(old, new)
281309
method_op(
282310
name="replace",

mypyc/test-data/irbuild-str.test

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,61 @@ L0:
504504
r7 = CPyStr_Strip(s, 0)
505505
r8 = CPyStr_RStrip(s, 0)
506506
return 1
507+
508+
[case testCountAll]
509+
def do_count(s: str) -> int:
510+
return s.count("x") # type: ignore [attr-defined]
511+
[out]
512+
def do_count(s):
513+
s, r0 :: str
514+
r1 :: native_int
515+
r2 :: bit
516+
r3 :: object
517+
r4 :: int
518+
L0:
519+
r0 = 'x'
520+
r1 = CPyStr_Count(s, r0, 0)
521+
r2 = r1 >= 0 :: signed
522+
r3 = box(native_int, r1)
523+
r4 = unbox(int, r3)
524+
return r4
525+
526+
[case testCountStart]
527+
def do_count(s: str, start: int) -> int:
528+
return s.count("x", start) # type: ignore [attr-defined]
529+
[out]
530+
def do_count(s, start):
531+
s :: str
532+
start :: int
533+
r0 :: str
534+
r1 :: native_int
535+
r2 :: bit
536+
r3 :: object
537+
r4 :: int
538+
L0:
539+
r0 = 'x'
540+
r1 = CPyStr_Count(s, r0, start)
541+
r2 = r1 >= 0 :: signed
542+
r3 = box(native_int, r1)
543+
r4 = unbox(int, r3)
544+
return r4
545+
546+
[case testCountStartEnd]
547+
def do_count(s: str, start: int, end: int) -> int:
548+
return s.count("x", start, end) # type: ignore [attr-defined]
549+
[out]
550+
def do_count(s, start, end):
551+
s :: str
552+
start, end :: int
553+
r0 :: str
554+
r1 :: native_int
555+
r2 :: bit
556+
r3 :: object
557+
r4 :: int
558+
L0:
559+
r0 = 'x'
560+
r1 = CPyStr_CountFull(s, r0, start, end)
561+
r2 = r1 >= 0 :: signed
562+
r3 = box(native_int, r1)
563+
r4 = unbox(int, r3)
564+
return r4

mypyc/test-data/run-strings.test

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,3 +815,94 @@ def test_unicode_range() -> None:
815815
assert "\u2029 \U0010AAAA\U00104444B\u205F ".strip() == "\U0010AAAA\U00104444B"
816816
assert " \u3000\u205F ".strip() == ""
817817
assert "\u2029 \U00102865\u205F ".rstrip() == "\u2029 \U00102865"
818+
819+
[case testCount]
820+
# mypy: disable-error-code="attr-defined"
821+
def test_count() -> None:
822+
string = "abcbcb"
823+
assert string.count("a") == 1
824+
assert string.count("b") == 3
825+
assert string.count("c") == 2
826+
def test_count_start() -> None:
827+
string = "abcbcb"
828+
assert string.count("a", 2) == string.count("a", -4) == 0, (string.count("a", 2), string.count("a", -4))
829+
assert string.count("b", 2) == string.count("b", -4) == 2, (string.count("b", 2), string.count("b", -4))
830+
assert string.count("c", 2) == string.count("c", -4) == 2, (string.count("c", 2), string.count("c", -4))
831+
# out of bounds
832+
assert string.count("a", 8) == 0
833+
assert string.count("a", -8) == 1
834+
assert string.count("b", 8) == 0
835+
assert string.count("b", -8) == 3
836+
assert string.count("c", 8) == 0
837+
assert string.count("c", -8) == 2
838+
def test_count_start_end() -> None:
839+
string = "abcbcb"
840+
assert string.count("a", 0, 4) == 1, string.count("a", 0, 4)
841+
assert string.count("b", 0, 4) == 2, string.count("b", 0, 4)
842+
assert string.count("c", 0, 4) == 1, string.count("c", 0, 4)
843+
def test_count_multi() -> None:
844+
string = "aaabbbcccbbbcccbbb"
845+
assert string.count("aaa") == 1, string.count("aaa")
846+
assert string.count("bbb") == 3, string.count("bbb")
847+
assert string.count("ccc") == 2, string.count("ccc")
848+
def test_count_multi_start() -> None:
849+
string = "aaabbbcccbbbcccbbb"
850+
assert string.count("aaa", 6) == string.count("aaa", -12) == 0, (string.count("aaa", 6), string.count("aaa", -12))
851+
assert string.count("bbb", 6) == string.count("bbb", -12) == 2, (string.count("bbb", 6), string.count("bbb", -12))
852+
assert string.count("ccc", 6) == string.count("ccc", -12) == 2, (string.count("ccc", 6), string.count("ccc", -12))
853+
# out of bounds
854+
assert string.count("aaa", 20) == 0
855+
assert string.count("aaa", -20) == 1
856+
assert string.count("bbb", 20) == 0
857+
assert string.count("bbb", -20) == 3
858+
assert string.count("ccc", 20) == 0
859+
assert string.count("ccc", -20) == 2
860+
def test_count_multi_start_end() -> None:
861+
string = "aaabbbcccbbbcccbbb"
862+
assert string.count("aaa", 0, 12) == 1, string.count("aaa", 0, 12)
863+
assert string.count("bbb", 0, 12) == 2, string.count("bbb", 0, 12)
864+
assert string.count("ccc", 0, 12) == 1, string.count("ccc", 0, 12)
865+
def test_count_emoji() -> None:
866+
string = "πŸ˜΄πŸš€Γ±πŸš€Γ±πŸš€"
867+
assert string.count("😴") == 1, string.count("😴")
868+
assert string.count("πŸš€") == 3, string.count("πŸš€")
869+
assert string.count("Γ±") == 2, string.count("Γ±")
870+
def test_count_start_emoji() -> None:
871+
string = "πŸ˜΄πŸš€Γ±πŸš€Γ±πŸš€"
872+
assert string.count("😴", 2) == string.count("😴", -4) == 0, (string.count("😴", 2), string.count("😴", -4))
873+
assert string.count("πŸš€", 2) == string.count("πŸš€", -4) == 2, (string.count("πŸš€", 2), string.count("πŸš€", -4))
874+
assert string.count("Γ±", 2) == string.count("Γ±", -4) == 2, (string.count("Γ±", 2), string.count("Γ±", -4))
875+
# Out of bounds
876+
assert string.count("😴", 8) == 0, string.count("😴", 8)
877+
assert string.count("😴", -8) == 1, string.count("😴", -8)
878+
assert string.count("πŸš€", 8) == 0, string.count("πŸš€", 8)
879+
assert string.count("πŸš€", -8) == 3, string.count("πŸš€", -8)
880+
assert string.count("Γ±", 8) == 0, string.count("Γ±", 8)
881+
assert string.count("Γ±", -8) == 2, string.count("Γ±", -8)
882+
def test_count_start_end_emoji() -> None:
883+
string = "πŸ˜΄πŸš€Γ±πŸš€Γ±πŸš€"
884+
assert string.count("😴", 0, 4) == 1, string.count("😴", 0, 4)
885+
assert string.count("πŸš€", 0, 4) == 2, string.count("πŸš€", 0, 4)
886+
assert string.count("Γ±", 0, 4) == 1, string.count("Γ±", 0, 4)
887+
def test_count_multi_emoji() -> None:
888+
string = "πŸ˜΄πŸ˜΄πŸ˜΄πŸš€πŸš€πŸš€Γ±Γ±Γ±πŸš€πŸš€πŸš€Γ±Γ±Γ±πŸš€πŸš€πŸš€"
889+
assert string.count("😴😴😴") == 1, string.count("😴😴😴")
890+
assert string.count("πŸš€πŸš€πŸš€") == 3, string.count("πŸš€πŸš€πŸš€")
891+
assert string.count("Γ±Γ±Γ±") == 2, string.count("Γ±Γ±Γ±")
892+
def test_count_multi_start_emoji() -> None:
893+
string = "πŸ˜΄πŸ˜΄πŸ˜΄πŸš€πŸš€πŸš€Γ±Γ±Γ±πŸš€πŸš€πŸš€Γ±Γ±Γ±πŸš€πŸš€πŸš€"
894+
assert string.count("😴😴😴", 6) == string.count("😴😴😴", -12) == 0, (string.count("😴😴😴", 6), string.count("😴😴😴", -12))
895+
assert string.count("πŸš€πŸš€πŸš€", 6) == string.count("πŸš€πŸš€πŸš€", -12) == 2, (string.count("πŸš€πŸš€πŸš€", 6), string.count("πŸš€πŸš€πŸš€", -12))
896+
assert string.count("Γ±Γ±Γ±", 6) == string.count("Γ±Γ±Γ±", -12) == 2, (string.count("Γ±Γ±Γ±", 6), string.count("Γ±Γ±Γ±", -12))
897+
# Out of bounds
898+
assert string.count("😴😴😴", 20) == 0, string.count("😴😴😴", 20)
899+
assert string.count("😴😴😴", -20) == 1, string.count("😴😴😴", -20)
900+
assert string.count("πŸš€πŸš€πŸš€", 20) == 0, string.count("πŸš€πŸš€πŸš€", 20)
901+
assert string.count("πŸš€πŸš€πŸš€", -20) == 3, string.count("πŸš€πŸš€πŸš€", -20)
902+
assert string.count("Γ±Γ±Γ±", 20) == 0, string.count("Γ±Γ±Γ±", 20)
903+
assert string.count("Γ±Γ±Γ±", -20) == 2, string.count("Γ±Γ±Γ±", -20)
904+
def test_count_multi_start_end_emoji() -> None:
905+
string = "πŸ˜΄πŸ˜΄πŸ˜΄πŸš€πŸš€πŸš€Γ±Γ±Γ±πŸš€πŸš€πŸš€Γ±Γ±Γ±πŸš€πŸš€πŸš€"
906+
assert string.count("😴😴😴", 0, 12) == 1, string.count("😴😴😴", 0, 12)
907+
assert string.count("πŸš€πŸš€πŸš€", 0, 12) == 2, string.count("πŸš€πŸš€πŸš€", 0, 12)
908+
assert string.count("Γ±Γ±Γ±", 0, 12) == 1, string.count("Γ±Γ±Γ±", 0, 12)

0 commit comments

Comments
(0)