Skip to content

Commit 31797bd

Browse files
fix: Case handling in byte and bytearray methods, converting unicode to ascii array
CPython will never hit the unicode-to-ASCII-array conversion path (since its char sequences only ever contain characters below 0xFF). Fix uppercase/lowercase handling in the bytes and bytearray methods, since Python leaves code points of 128 and above unchanged.
1 parent 3942327 commit 31797bd

File tree

6 files changed

+139
-27
lines changed

6 files changed

+139
-27
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonByteArray.java

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ public static PythonByteArray fromIntTuple(PythonLikeTuple tuple) {
339339
}
340340

341341
public final PythonLikeTuple asIntTuple() {
342-
return IntStream.range(0, valueBuffer.limit()).mapToObj(index -> PythonBytes.BYTE_TO_INT[valueBuffer.get(index) & 0xFF])
342+
return IntStream.range(0, valueBuffer.limit()).mapToObj(index -> PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(index))])
343343
.collect(Collectors.toCollection(PythonLikeTuple::new));
344344
}
345345

@@ -364,7 +364,7 @@ public PythonInteger getCharAt(PythonInteger position) {
364364
throw new IndexError("position " + position + " is less than 0");
365365
}
366366

367-
return PythonBytes.BYTE_TO_INT[valueBuffer.get(index) & 0xFF];
367+
return PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(index))];
368368
}
369369

370370
public PythonByteArray getSubsequence(PythonSlice slice) {
@@ -435,7 +435,7 @@ public PythonByteArray repeat(PythonInteger times) {
435435

436436
public DelegatePythonIterator<PythonInteger> getIterator() {
437437
return new DelegatePythonIterator<>(IntStream.range(0, valueBuffer.limit())
438-
.mapToObj(index -> PythonBytes.BYTE_TO_INT[valueBuffer.get(index)])
438+
.mapToObj(index -> PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(index))])
439439
.iterator());
440440
}
441441

@@ -707,7 +707,7 @@ public PythonInteger pop() {
707707
if (valueBuffer.limit() == 0) {
708708
throw new IndexError("pop from empty bytearray");
709709
}
710-
PythonInteger out = PythonBytes.BYTE_TO_INT[valueBuffer.get(valueBuffer.limit() - 1) & 0xFF];
710+
PythonInteger out = PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(valueBuffer.limit() - 1))];
711711
valueBuffer.limit(valueBuffer.limit() - 1);
712712
return out;
713713
}
@@ -721,7 +721,7 @@ public PythonInteger pop(PythonInteger index) {
721721
if (indexAsInt < 0 || indexAsInt > valueBuffer.limit()) {
722722
throw new IndexError("index out of range for bytearray");
723723
}
724-
PythonInteger out = PythonBytes.BYTE_TO_INT[valueBuffer.get(indexAsInt) & 0xFF];
724+
PythonInteger out = PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(indexAsInt))];
725725
removeBytesStartingAt(indexAsInt, 1);
726726
return out;
727727
}
@@ -1824,7 +1824,17 @@ public PythonLikeList<PythonByteArray> rightSplit(PythonNone seperator, PythonIn
18241824
}
18251825

18261826
public PythonByteArray capitalize() {
1827-
return asAsciiString().capitalize().asAsciiByteArray();
1827+
var asString = asAsciiString();
1828+
if (asString.value.isEmpty()) {
1829+
return asString.asAsciiByteArray();
1830+
}
1831+
var tail = PythonString.valueOf(asString.value.substring(1))
1832+
.withModifiedCodepoints(cp -> cp < 128? Character.toLowerCase(cp) : cp).value;
1833+
var head = asString.value.charAt(0);
1834+
if (head < 128) {
1835+
head = Character.toTitleCase(head);
1836+
}
1837+
return (PythonString.valueOf(head + tail)).asAsciiByteArray();
18281838
}
18291839

18301840
public PythonByteArray expandTabs() {
@@ -1874,7 +1884,9 @@ public PythonBoolean isUpper() {
18741884
}
18751885

18761886
public PythonByteArray lower() {
1877-
return asAsciiString().lower().asAsciiByteArray();
1887+
return asAsciiString().withModifiedCodepoints(
1888+
cp -> cp < 128? Character.toLowerCase(cp) : cp
1889+
).asAsciiByteArray();
18781890
}
18791891

18801892
public PythonLikeList<PythonByteArray> splitLines() {
@@ -1892,15 +1904,27 @@ public PythonLikeList<PythonByteArray> splitLines(PythonBoolean keepEnds) {
18921904
}
18931905

18941906
public PythonByteArray swapCase() {
1895-
return asAsciiString().swapCase().asAsciiByteArray();
1907+
return asAsciiString().withModifiedCodepoints(
1908+
cp -> {
1909+
if (cp >= 128) {
1910+
return cp;
1911+
}
1912+
if (Character.isLowerCase(cp)) {
1913+
return Character.toUpperCase(cp);
1914+
}
1915+
return Character.toLowerCase(cp);
1916+
}
1917+
).asAsciiByteArray();
18961918
}
18971919

18981920
public PythonByteArray title() {
1899-
return asAsciiString().title().asAsciiByteArray();
1921+
return asAsciiString().title(cp -> cp < 128).asAsciiByteArray();
19001922
}
19011923

19021924
public PythonByteArray upper() {
1903-
return asAsciiString().upper().asAsciiByteArray();
1925+
return asAsciiString().withModifiedCodepoints(
1926+
cp -> cp < 128? Character.toUpperCase(cp) : cp
1927+
).asAsciiByteArray();
19041928
}
19051929

19061930
public PythonByteArray zfill(PythonInteger width) {

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonBytes.java

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ public static PythonBytes fromIntTuple(PythonLikeTuple tuple) {
372372
}
373373

374374
public final PythonLikeTuple asIntTuple() {
375-
return IntStream.range(0, value.length).mapToObj(index -> BYTE_TO_INT[value[index]])
375+
return IntStream.range(0, value.length).mapToObj(index -> BYTE_TO_INT[Byte.toUnsignedInt(value[index])])
376376
.collect(Collectors.toCollection(PythonLikeTuple::new));
377377
}
378378

@@ -397,7 +397,7 @@ public PythonInteger getCharAt(PythonInteger position) {
397397
throw new IndexError("position " + position + " is less than 0");
398398
}
399399

400-
return BYTE_TO_INT[value[index] & 0xFF];
400+
return BYTE_TO_INT[Byte.toUnsignedInt(value[index])];
401401
}
402402

403403
public PythonBytes getSubsequence(PythonSlice slice) {
@@ -472,7 +472,7 @@ public PythonBytes repeat(PythonInteger times) {
472472

473473
public DelegatePythonIterator<PythonInteger> getIterator() {
474474
return new DelegatePythonIterator<>(IntStream.range(0, value.length)
475-
.mapToObj(index -> BYTE_TO_INT[value[index]])
475+
.mapToObj(index -> BYTE_TO_INT[Byte.toUnsignedInt(value[index])])
476476
.iterator());
477477
}
478478

@@ -1597,7 +1597,17 @@ public PythonLikeList<PythonBytes> rightSplit(PythonNone seperator, PythonIntege
15971597
}
15981598

15991599
public PythonBytes capitalize() {
1600-
return asAsciiString().capitalize().asAsciiBytes();
1600+
var asString = asAsciiString();
1601+
if (asString.value.isEmpty()) {
1602+
return this;
1603+
}
1604+
var tail = PythonString.valueOf(asString.value.substring(1))
1605+
.withModifiedCodepoints(cp -> cp < 128? Character.toLowerCase(cp) : cp).value;
1606+
var head = asString.value.charAt(0);
1607+
if (head < 128) {
1608+
head = Character.toTitleCase(head);
1609+
}
1610+
return (PythonString.valueOf(head + tail)).asAsciiBytes();
16011611
}
16021612

16031613
public PythonBytes expandTabs() {
@@ -1646,7 +1656,9 @@ public PythonBoolean isUpper() {
16461656
}
16471657

16481658
public PythonBytes lower() {
1649-
return asAsciiString().lower().asAsciiBytes();
1659+
return asAsciiString().withModifiedCodepoints(
1660+
cp -> cp < 128? Character.toLowerCase(cp) : cp
1661+
).asAsciiBytes();
16501662
}
16511663

16521664
public PythonLikeList<PythonBytes> splitLines() {
@@ -1664,15 +1676,27 @@ public PythonLikeList<PythonBytes> splitLines(PythonBoolean keepEnds) {
16641676
}
16651677

16661678
public PythonBytes swapCase() {
1667-
return asAsciiString().swapCase().asAsciiBytes();
1679+
return asAsciiString().withModifiedCodepoints(
1680+
cp -> {
1681+
if (cp >= 128) {
1682+
return cp;
1683+
}
1684+
if (Character.isLowerCase(cp)) {
1685+
return Character.toUpperCase(cp);
1686+
}
1687+
return Character.toLowerCase(cp);
1688+
}
1689+
).asAsciiBytes();
16681690
}
16691691

16701692
public PythonBytes title() {
1671-
return asAsciiString().title().asAsciiBytes();
1693+
return asAsciiString().title(cp -> cp < 128).asAsciiBytes();
16721694
}
16731695

16741696
public PythonBytes upper() {
1675-
return asAsciiString().upper().asAsciiBytes();
1697+
return asAsciiString().withModifiedCodepoints(
1698+
cp -> cp < 128? Character.toUpperCase(cp) : cp
1699+
).asAsciiBytes();
16761700
}
16771701

16781702
public PythonBytes zfill(PythonInteger width) {

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonString.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.List;
1212
import java.util.Map;
1313
import java.util.function.IntPredicate;
14+
import java.util.function.IntUnaryOperator;
1415
import java.util.regex.Matcher;
1516
import java.util.regex.Pattern;
1617
import java.util.stream.Collectors;
@@ -293,8 +294,9 @@ public final PythonBytes asAsciiBytes() {
293294
outIndex++;
294295
} else {
295296
out[outIndex] = (byte) ((charDatum & 0xFF00) >> 8);
297+
outIndex++;
296298
out[outIndex] = (byte) (charDatum & 0x00FF);
297-
outIndex += 2;
299+
outIndex++;
298300
}
299301
}
300302
return new PythonBytes(out);
@@ -447,6 +449,10 @@ public PythonString capitalize() {
447449
}
448450

449451
public PythonString title() {
452+
return title(ignored -> true);
453+
}
454+
455+
public PythonString title(IntPredicate predicate) {
450456
if (value.isEmpty()) {
451457
return this;
452458
}
@@ -458,10 +464,14 @@ public PythonString title() {
458464
for (int i = 0; i < length; i++) {
459465
char character = value.charAt(i);
460466

461-
if (previousIsWordBoundary) {
462-
out.append(Character.toTitleCase(character));
467+
if (predicate.test(character)) {
468+
if (previousIsWordBoundary) {
469+
out.append(Character.toTitleCase(character));
470+
} else {
471+
out.append(Character.toLowerCase(character));
472+
}
463473
} else {
464-
out.append(Character.toLowerCase(character));
474+
out.append(character);
465475
}
466476

467477
previousIsWordBoundary = !Character.isAlphabetic(character);
@@ -476,11 +486,7 @@ public PythonString casefold() {
476486
}
477487

478488
public PythonString swapCase() {
479-
return PythonString.valueOf(value.codePoints()
480-
.map(CharacterCase::swapCase)
481-
.collect(StringBuilder::new,
482-
StringBuilder::appendCodePoint, StringBuilder::append)
483-
.toString());
489+
return withModifiedCodepoints(CharacterCase::swapCase);
484490
}
485491

486492
public PythonString lower() {
@@ -491,6 +497,14 @@ public PythonString upper() {
491497
return PythonString.valueOf(value.toUpperCase());
492498
}
493499

500+
public PythonString withModifiedCodepoints(IntUnaryOperator modifier) {
501+
return PythonString.valueOf(value.codePoints()
502+
.map(modifier)
503+
.collect(StringBuilder::new,
504+
StringBuilder::appendCodePoint, StringBuilder::append)
505+
.toString());
506+
}
507+
494508
public PythonString center(PythonInteger width) {
495509
return center(width, PythonString.valueOf(" "));
496510
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package ai.timefold.jpyinterpreter.types;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import static org.assertj.core.api.Assertions.assertThat;
6+
import static org.junit.jupiter.api.Assertions.*;
7+
8+
class PythonStringTest {
9+
10+
// Other methods are tested in test_str.py
11+
// These methods are tested here since they are internal,
12+
// and have edge cases CPython won't hit
13+
14+
@Test
15+
void asAsciiBytes() {
16+
var simple = PythonString.valueOf("abc");
17+
assertThat(simple.asAsciiBytes().asByteArray()).isEqualTo(new byte[] { 'a', 'b', 'c'});
18+
19+
var unicode = PythonString.valueOf("π");
20+
// UTF-16 encoding
21+
assertThat(unicode.asAsciiBytes().asByteArray()).isEqualTo(new byte[] { (byte) 0x03, (byte) 0xC0 });
22+
23+
var mixed = PythonString.valueOf("aπc");
24+
// UTF-16 encoding
25+
assertThat(mixed.asAsciiBytes().asByteArray()).isEqualTo(new byte[] { 'a', (byte) 0x03, (byte) 0xC0, 'c' });
26+
}
27+
28+
@Test
29+
void asAsciiByteArray() {
30+
var simple = PythonString.valueOf("abc");
31+
assertThat(simple.asAsciiByteArray().asByteArray()).isEqualTo(new byte[] { 'a', 'b', 'c'});
32+
33+
var unicode = PythonString.valueOf("π");
34+
// UTF-16 encoding
35+
assertThat(unicode.asAsciiByteArray().asByteArray()).isEqualTo(new byte[] { (byte) 0x03, (byte) 0xC0 });
36+
37+
var mixed = PythonString.valueOf("aπc");
38+
// UTF-16 encoding
39+
assertThat(mixed.asAsciiByteArray().asByteArray()).isEqualTo(new byte[] { 'a', (byte) 0x03, (byte) 0xC0, 'c' });
40+
}
41+
}

jpyinterpreter/tests/test_bytearray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def capitalize(tested: bytearray) -> bytearray:
547547
capitalize_verifier.verify(bytearray(b'hello world'), expected_result=bytearray(b'Hello world'))
548548
capitalize_verifier.verify(bytearray(b'Hello World'), expected_result=bytearray(b'Hello world'))
549549
capitalize_verifier.verify(bytearray(b'HELLO WORLD'), expected_result=bytearray(b'Hello world'))
550+
capitalize_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
550551

551552

552553
def test_center():
@@ -915,6 +916,7 @@ def lower(tested: bytearray) -> bytearray:
915916
lower_verifier.verify(bytearray(b'[]'), expected_result=bytearray(b'[]'))
916917
lower_verifier.verify(bytearray(b'-'), expected_result=bytearray(b'-'))
917918
lower_verifier.verify(bytearray(b'%'), expected_result=bytearray(b'%'))
919+
lower_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
918920
lower_verifier.verify(bytearray(b'\n'), expected_result=bytearray(b'\n'))
919921
lower_verifier.verify(bytearray(b'\t'), expected_result=bytearray(b'\t'))
920922
lower_verifier.verify(bytearray(b' '), expected_result=bytearray(b' '))
@@ -1273,6 +1275,7 @@ def swapcase(tested: bytearray) -> bytearray:
12731275
swapcase_verifier.verify(bytearray(b'[]'), expected_result=bytearray(b'[]'))
12741276
swapcase_verifier.verify(bytearray(b'-'), expected_result=bytearray(b'-'))
12751277
swapcase_verifier.verify(bytearray(b'%'), expected_result=bytearray(b'%'))
1278+
swapcase_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
12761279
swapcase_verifier.verify(bytearray(b'\n'), expected_result=bytearray(b'\n'))
12771280
swapcase_verifier.verify(bytearray(b'\t'), expected_result=bytearray(b'\t'))
12781281
swapcase_verifier.verify(bytearray(b' '), expected_result=bytearray(b' '))
@@ -1297,6 +1300,7 @@ def title(tested: bytearray) -> bytearray:
12971300
title_verifier.verify(bytearray(b'[]'), expected_result=bytearray(b'[]'))
12981301
title_verifier.verify(bytearray(b'-'), expected_result=bytearray(b'-'))
12991302
title_verifier.verify(bytearray(b'%'), expected_result=bytearray(b'%'))
1303+
title_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
13001304
title_verifier.verify(bytearray(b'\n'), expected_result=bytearray(b'\n'))
13011305
title_verifier.verify(bytearray(b'\t'), expected_result=bytearray(b'\t'))
13021306
title_verifier.verify(bytearray(b' '), expected_result=bytearray(b' '))

jpyinterpreter/tests/test_bytes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def capitalize(tested: bytes) -> bytes:
279279
capitalize_verifier.verify(b'hello world', expected_result=b'Hello world')
280280
capitalize_verifier.verify(b'Hello World', expected_result=b'Hello world')
281281
capitalize_verifier.verify(b'HELLO WORLD', expected_result=b'Hello world')
282+
capitalize_verifier.verify('π'.encode(), expected_result='π'.encode())
282283

283284

284285
def test_center():
@@ -647,6 +648,7 @@ def lower(tested: bytes) -> bytes:
647648
lower_verifier.verify(b'[]', expected_result=b'[]')
648649
lower_verifier.verify(b'-', expected_result=b'-')
649650
lower_verifier.verify(b'%', expected_result=b'%')
651+
lower_verifier.verify('π'.encode(), expected_result='π'.encode())
650652
lower_verifier.verify(b'\n', expected_result=b'\n')
651653
lower_verifier.verify(b'\t', expected_result=b'\t')
652654
lower_verifier.verify(b' ', expected_result=b' ')
@@ -1005,6 +1007,7 @@ def swapcase(tested: bytes) -> bytes:
10051007
swapcase_verifier.verify(b'[]', expected_result=b'[]')
10061008
swapcase_verifier.verify(b'-', expected_result=b'-')
10071009
swapcase_verifier.verify(b'%', expected_result=b'%')
1010+
swapcase_verifier.verify('π'.encode(), expected_result='π'.encode())
10081011
swapcase_verifier.verify(b'\n', expected_result=b'\n')
10091012
swapcase_verifier.verify(b'\t', expected_result=b'\t')
10101013
swapcase_verifier.verify(b' ', expected_result=b' ')
@@ -1029,6 +1032,7 @@ def title(tested: bytes) -> bytes:
10291032
title_verifier.verify(b'[]', expected_result=b'[]')
10301033
title_verifier.verify(b'-', expected_result=b'-')
10311034
title_verifier.verify(b'%', expected_result=b'%')
1035+
title_verifier.verify('π'.encode(), expected_result='π'.encode())
10321036
title_verifier.verify(b'\n', expected_result=b'\n')
10331037
title_verifier.verify(b'\t', expected_result=b'\t')
10341038
title_verifier.verify(b' ', expected_result=b' ')
@@ -1061,6 +1065,7 @@ def upper(tested: bytes) -> bytes:
10611065
upper_verifier.verify(b'[]', expected_result=b'[]')
10621066
upper_verifier.verify(b'-', expected_result=b'-')
10631067
upper_verifier.verify(b'%', expected_result=b'%')
1068+
upper_verifier.verify('π'.encode(), expected_result='π'.encode())
10641069
upper_verifier.verify(b'\n', expected_result=b'\n')
10651070
upper_verifier.verify(b'\t', expected_result=b'\t')
10661071
upper_verifier.verify(b' ', expected_result=b' ')

0 commit comments

Comments
 (0)