Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.

Commit c87d1b4

Browse files
fix: Case handling in byte and bytearray methods, converting unicode to ascii array (#80)
CPython won't ever hit the unicode-to-ASCII-array conversion path (since its char sequences always contain characters below 0xFF). Fix uppercase/lowercase handling in the bytes and bytearray methods, since Python leaves code points outside the ASCII range (128 and above) unchanged.
1 parent e7c33cc commit c87d1b4

File tree

6 files changed

+119
-28
lines changed

6 files changed

+119
-28
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonByteArray.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ public static PythonByteArray fromIntTuple(PythonLikeTuple tuple) {
339339
}
340340

341341
public final PythonLikeTuple asIntTuple() {
342-
return IntStream.range(0, valueBuffer.limit()).mapToObj(index -> PythonBytes.BYTE_TO_INT[valueBuffer.get(index) & 0xFF])
342+
return IntStream.range(0, valueBuffer.limit())
343+
.mapToObj(index -> PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(index))])
343344
.collect(Collectors.toCollection(PythonLikeTuple::new));
344345
}
345346

@@ -364,7 +365,7 @@ public PythonInteger getCharAt(PythonInteger position) {
364365
throw new IndexError("position " + position + " is less than 0");
365366
}
366367

367-
return PythonBytes.BYTE_TO_INT[valueBuffer.get(index) & 0xFF];
368+
return PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(index))];
368369
}
369370

370371
public PythonByteArray getSubsequence(PythonSlice slice) {
@@ -435,7 +436,7 @@ public PythonByteArray repeat(PythonInteger times) {
435436

436437
public DelegatePythonIterator<PythonInteger> getIterator() {
437438
return new DelegatePythonIterator<>(IntStream.range(0, valueBuffer.limit())
438-
.mapToObj(index -> PythonBytes.BYTE_TO_INT[valueBuffer.get(index)])
439+
.mapToObj(index -> PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(index))])
439440
.iterator());
440441
}
441442

@@ -707,7 +708,7 @@ public PythonInteger pop() {
707708
if (valueBuffer.limit() == 0) {
708709
throw new IndexError("pop from empty bytearray");
709710
}
710-
PythonInteger out = PythonBytes.BYTE_TO_INT[valueBuffer.get(valueBuffer.limit() - 1) & 0xFF];
711+
PythonInteger out = PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(valueBuffer.limit() - 1))];
711712
valueBuffer.limit(valueBuffer.limit() - 1);
712713
return out;
713714
}
@@ -721,7 +722,7 @@ public PythonInteger pop(PythonInteger index) {
721722
if (indexAsInt < 0 || indexAsInt > valueBuffer.limit()) {
722723
throw new IndexError("index out of range for bytearray");
723724
}
724-
PythonInteger out = PythonBytes.BYTE_TO_INT[valueBuffer.get(indexAsInt) & 0xFF];
725+
PythonInteger out = PythonBytes.BYTE_TO_INT[Byte.toUnsignedInt(valueBuffer.get(indexAsInt))];
725726
removeBytesStartingAt(indexAsInt, 1);
726727
return out;
727728
}
@@ -1824,7 +1825,17 @@ public PythonLikeList<PythonByteArray> rightSplit(PythonNone seperator, PythonIn
18241825
}
18251826

18261827
public PythonByteArray capitalize() {
1827-
return asAsciiString().capitalize().asAsciiByteArray();
1828+
var asString = asAsciiString();
1829+
if (asString.value.isEmpty()) {
1830+
return asString.asAsciiByteArray();
1831+
}
1832+
var tail = PythonString.valueOf(asString.value.substring(1))
1833+
.withModifiedCodepoints(cp -> cp < 128 ? Character.toLowerCase(cp) : cp).value;
1834+
var head = asString.value.charAt(0);
1835+
if (head < 128) {
1836+
head = Character.toTitleCase(head);
1837+
}
1838+
return (PythonString.valueOf(head + tail)).asAsciiByteArray();
18281839
}
18291840

18301841
public PythonByteArray expandTabs() {
@@ -1874,7 +1885,8 @@ public PythonBoolean isUpper() {
18741885
}
18751886

18761887
public PythonByteArray lower() {
1877-
return asAsciiString().lower().asAsciiByteArray();
1888+
return asAsciiString().withModifiedCodepoints(
1889+
cp -> cp < 128 ? Character.toLowerCase(cp) : cp).asAsciiByteArray();
18781890
}
18791891

18801892
public PythonLikeList<PythonByteArray> splitLines() {
@@ -1892,15 +1904,17 @@ public PythonLikeList<PythonByteArray> splitLines(PythonBoolean keepEnds) {
18921904
}
18931905

18941906
public PythonByteArray swapCase() {
1895-
return asAsciiString().swapCase().asAsciiByteArray();
1907+
return asAsciiString().withModifiedCodepoints(
1908+
cp -> cp < 128 ? PythonString.CharacterCase.swapCase(cp) : cp).asAsciiByteArray();
18961909
}
18971910

18981911
public PythonByteArray title() {
1899-
return asAsciiString().title().asAsciiByteArray();
1912+
return asAsciiString().title(cp -> cp < 128).asAsciiByteArray();
19001913
}
19011914

19021915
public PythonByteArray upper() {
1903-
return asAsciiString().upper().asAsciiByteArray();
1916+
return asAsciiString().withModifiedCodepoints(
1917+
cp -> cp < 128 ? Character.toUpperCase(cp) : cp).asAsciiByteArray();
19041918
}
19051919

19061920
public PythonByteArray zfill(PythonInteger width) {

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonBytes.java

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ public static PythonBytes fromIntTuple(PythonLikeTuple tuple) {
372372
}
373373

374374
public final PythonLikeTuple asIntTuple() {
375-
return IntStream.range(0, value.length).mapToObj(index -> BYTE_TO_INT[value[index]])
375+
return IntStream.range(0, value.length).mapToObj(index -> BYTE_TO_INT[Byte.toUnsignedInt(value[index])])
376376
.collect(Collectors.toCollection(PythonLikeTuple::new));
377377
}
378378

@@ -397,7 +397,7 @@ public PythonInteger getCharAt(PythonInteger position) {
397397
throw new IndexError("position " + position + " is less than 0");
398398
}
399399

400-
return BYTE_TO_INT[value[index] & 0xFF];
400+
return BYTE_TO_INT[Byte.toUnsignedInt(value[index])];
401401
}
402402

403403
public PythonBytes getSubsequence(PythonSlice slice) {
@@ -472,7 +472,7 @@ public PythonBytes repeat(PythonInteger times) {
472472

473473
public DelegatePythonIterator<PythonInteger> getIterator() {
474474
return new DelegatePythonIterator<>(IntStream.range(0, value.length)
475-
.mapToObj(index -> BYTE_TO_INT[value[index]])
475+
.mapToObj(index -> BYTE_TO_INT[Byte.toUnsignedInt(value[index])])
476476
.iterator());
477477
}
478478

@@ -1597,7 +1597,17 @@ public PythonLikeList<PythonBytes> rightSplit(PythonNone seperator, PythonIntege
15971597
}
15981598

15991599
public PythonBytes capitalize() {
1600-
return asAsciiString().capitalize().asAsciiBytes();
1600+
var asString = asAsciiString();
1601+
if (asString.value.isEmpty()) {
1602+
return this;
1603+
}
1604+
var tail = PythonString.valueOf(asString.value.substring(1))
1605+
.withModifiedCodepoints(cp -> cp < 128 ? Character.toLowerCase(cp) : cp).value;
1606+
var head = asString.value.charAt(0);
1607+
if (head < 128) {
1608+
head = Character.toTitleCase(head);
1609+
}
1610+
return (PythonString.valueOf(head + tail)).asAsciiBytes();
16011611
}
16021612

16031613
public PythonBytes expandTabs() {
@@ -1646,7 +1656,8 @@ public PythonBoolean isUpper() {
16461656
}
16471657

16481658
public PythonBytes lower() {
1649-
return asAsciiString().lower().asAsciiBytes();
1659+
return asAsciiString().withModifiedCodepoints(
1660+
cp -> cp < 128 ? Character.toLowerCase(cp) : cp).asAsciiBytes();
16501661
}
16511662

16521663
public PythonLikeList<PythonBytes> splitLines() {
@@ -1664,15 +1675,17 @@ public PythonLikeList<PythonBytes> splitLines(PythonBoolean keepEnds) {
16641675
}
16651676

16661677
public PythonBytes swapCase() {
1667-
return asAsciiString().swapCase().asAsciiBytes();
1678+
return asAsciiString().withModifiedCodepoints(
1679+
cp -> cp < 128 ? PythonString.CharacterCase.swapCase(cp) : cp).asAsciiBytes();
16681680
}
16691681

16701682
public PythonBytes title() {
1671-
return asAsciiString().title().asAsciiBytes();
1683+
return asAsciiString().title(cp -> cp < 128).asAsciiBytes();
16721684
}
16731685

16741686
public PythonBytes upper() {
1675-
return asAsciiString().upper().asAsciiBytes();
1687+
return asAsciiString().withModifiedCodepoints(
1688+
cp -> cp < 128 ? Character.toUpperCase(cp) : cp).asAsciiBytes();
16761689
}
16771690

16781691
public PythonBytes zfill(PythonInteger width) {

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonString.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.List;
1212
import java.util.Map;
1313
import java.util.function.IntPredicate;
14+
import java.util.function.IntUnaryOperator;
1415
import java.util.regex.Matcher;
1516
import java.util.regex.Pattern;
1617
import java.util.stream.Collectors;
@@ -293,8 +294,9 @@ public final PythonBytes asAsciiBytes() {
293294
outIndex++;
294295
} else {
295296
out[outIndex] = (byte) ((charDatum & 0xFF00) >> 8);
297+
outIndex++;
296298
out[outIndex] = (byte) (charDatum & 0x00FF);
297-
outIndex += 2;
299+
outIndex++;
298300
}
299301
}
300302
return new PythonBytes(out);
@@ -447,6 +449,10 @@ public PythonString capitalize() {
447449
}
448450

449451
public PythonString title() {
452+
return title(ignored -> true);
453+
}
454+
455+
public PythonString title(IntPredicate predicate) {
450456
if (value.isEmpty()) {
451457
return this;
452458
}
@@ -458,10 +464,14 @@ public PythonString title() {
458464
for (int i = 0; i < length; i++) {
459465
char character = value.charAt(i);
460466

461-
if (previousIsWordBoundary) {
462-
out.append(Character.toTitleCase(character));
467+
if (predicate.test(character)) {
468+
if (previousIsWordBoundary) {
469+
out.append(Character.toTitleCase(character));
470+
} else {
471+
out.append(Character.toLowerCase(character));
472+
}
463473
} else {
464-
out.append(Character.toLowerCase(character));
474+
out.append(character);
465475
}
466476

467477
previousIsWordBoundary = !Character.isAlphabetic(character);
@@ -476,11 +486,7 @@ public PythonString casefold() {
476486
}
477487

478488
public PythonString swapCase() {
479-
return PythonString.valueOf(value.codePoints()
480-
.map(CharacterCase::swapCase)
481-
.collect(StringBuilder::new,
482-
StringBuilder::appendCodePoint, StringBuilder::append)
483-
.toString());
489+
return withModifiedCodepoints(CharacterCase::swapCase);
484490
}
485491

486492
public PythonString lower() {
@@ -491,6 +497,14 @@ public PythonString upper() {
491497
return PythonString.valueOf(value.toUpperCase());
492498
}
493499

500+
public PythonString withModifiedCodepoints(IntUnaryOperator modifier) {
501+
return PythonString.valueOf(value.codePoints()
502+
.map(modifier)
503+
.collect(StringBuilder::new,
504+
StringBuilder::appendCodePoint, StringBuilder::append)
505+
.toString());
506+
}
507+
494508
public PythonString center(PythonInteger width) {
495509
return center(width, PythonString.valueOf(" "));
496510
}
@@ -1043,7 +1057,7 @@ public PythonBoolean isUpper() {
10431057
}
10441058
}
10451059

1046-
private enum CharacterCase {
1060+
enum CharacterCase {
10471061
UNCASED,
10481062
LOWER,
10491063
UPPER;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package ai.timefold.jpyinterpreter.types;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.junit.jupiter.api.Assertions.*;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
class PythonStringTest {
9+
10+
// Other methods are tested in test_str.py
11+
// These methods are tested here since they are internal,
12+
// and have edge cases CPython won't hit
13+
14+
@Test
15+
void asAsciiBytes() {
16+
var simple = PythonString.valueOf("abc");
17+
assertThat(simple.asAsciiBytes().asByteArray()).isEqualTo(new byte[] { 'a', 'b', 'c' });
18+
19+
var unicode = PythonString.valueOf("π");
20+
// UTF-16 encoding
21+
assertThat(unicode.asAsciiBytes().asByteArray()).isEqualTo(new byte[] { (byte) 0x03, (byte) 0xC0 });
22+
23+
var mixed = PythonString.valueOf("aπc");
24+
// UTF-16 encoding
25+
assertThat(mixed.asAsciiBytes().asByteArray()).isEqualTo(new byte[] { 'a', (byte) 0x03, (byte) 0xC0, 'c' });
26+
}
27+
28+
@Test
29+
void asAsciiByteArray() {
30+
var simple = PythonString.valueOf("abc");
31+
assertThat(simple.asAsciiByteArray().asByteArray()).isEqualTo(new byte[] { 'a', 'b', 'c' });
32+
33+
var unicode = PythonString.valueOf("π");
34+
// UTF-16 encoding
35+
assertThat(unicode.asAsciiByteArray().asByteArray()).isEqualTo(new byte[] { (byte) 0x03, (byte) 0xC0 });
36+
37+
var mixed = PythonString.valueOf("aπc");
38+
// UTF-16 encoding
39+
assertThat(mixed.asAsciiByteArray().asByteArray()).isEqualTo(new byte[] { 'a', (byte) 0x03, (byte) 0xC0, 'c' });
40+
}
41+
}

jpyinterpreter/tests/test_bytearray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def capitalize(tested: bytearray) -> bytearray:
547547
capitalize_verifier.verify(bytearray(b'hello world'), expected_result=bytearray(b'Hello world'))
548548
capitalize_verifier.verify(bytearray(b'Hello World'), expected_result=bytearray(b'Hello world'))
549549
capitalize_verifier.verify(bytearray(b'HELLO WORLD'), expected_result=bytearray(b'Hello world'))
550+
capitalize_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
550551

551552

552553
def test_center():
@@ -915,6 +916,7 @@ def lower(tested: bytearray) -> bytearray:
915916
lower_verifier.verify(bytearray(b'[]'), expected_result=bytearray(b'[]'))
916917
lower_verifier.verify(bytearray(b'-'), expected_result=bytearray(b'-'))
917918
lower_verifier.verify(bytearray(b'%'), expected_result=bytearray(b'%'))
919+
lower_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
918920
lower_verifier.verify(bytearray(b'\n'), expected_result=bytearray(b'\n'))
919921
lower_verifier.verify(bytearray(b'\t'), expected_result=bytearray(b'\t'))
920922
lower_verifier.verify(bytearray(b' '), expected_result=bytearray(b' '))
@@ -1273,6 +1275,7 @@ def swapcase(tested: bytearray) -> bytearray:
12731275
swapcase_verifier.verify(bytearray(b'[]'), expected_result=bytearray(b'[]'))
12741276
swapcase_verifier.verify(bytearray(b'-'), expected_result=bytearray(b'-'))
12751277
swapcase_verifier.verify(bytearray(b'%'), expected_result=bytearray(b'%'))
1278+
swapcase_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
12761279
swapcase_verifier.verify(bytearray(b'\n'), expected_result=bytearray(b'\n'))
12771280
swapcase_verifier.verify(bytearray(b'\t'), expected_result=bytearray(b'\t'))
12781281
swapcase_verifier.verify(bytearray(b' '), expected_result=bytearray(b' '))
@@ -1297,6 +1300,7 @@ def title(tested: bytearray) -> bytearray:
12971300
title_verifier.verify(bytearray(b'[]'), expected_result=bytearray(b'[]'))
12981301
title_verifier.verify(bytearray(b'-'), expected_result=bytearray(b'-'))
12991302
title_verifier.verify(bytearray(b'%'), expected_result=bytearray(b'%'))
1303+
title_verifier.verify(bytearray('π'.encode()), expected_result=bytearray('π'.encode()))
13001304
title_verifier.verify(bytearray(b'\n'), expected_result=bytearray(b'\n'))
13011305
title_verifier.verify(bytearray(b'\t'), expected_result=bytearray(b'\t'))
13021306
title_verifier.verify(bytearray(b' '), expected_result=bytearray(b' '))

jpyinterpreter/tests/test_bytes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def capitalize(tested: bytes) -> bytes:
279279
capitalize_verifier.verify(b'hello world', expected_result=b'Hello world')
280280
capitalize_verifier.verify(b'Hello World', expected_result=b'Hello world')
281281
capitalize_verifier.verify(b'HELLO WORLD', expected_result=b'Hello world')
282+
capitalize_verifier.verify('π'.encode(), expected_result='π'.encode())
282283

283284

284285
def test_center():
@@ -647,6 +648,7 @@ def lower(tested: bytes) -> bytes:
647648
lower_verifier.verify(b'[]', expected_result=b'[]')
648649
lower_verifier.verify(b'-', expected_result=b'-')
649650
lower_verifier.verify(b'%', expected_result=b'%')
651+
lower_verifier.verify('π'.encode(), expected_result='π'.encode())
650652
lower_verifier.verify(b'\n', expected_result=b'\n')
651653
lower_verifier.verify(b'\t', expected_result=b'\t')
652654
lower_verifier.verify(b' ', expected_result=b' ')
@@ -1005,6 +1007,7 @@ def swapcase(tested: bytes) -> bytes:
10051007
swapcase_verifier.verify(b'[]', expected_result=b'[]')
10061008
swapcase_verifier.verify(b'-', expected_result=b'-')
10071009
swapcase_verifier.verify(b'%', expected_result=b'%')
1010+
swapcase_verifier.verify('π'.encode(), expected_result='π'.encode())
10081011
swapcase_verifier.verify(b'\n', expected_result=b'\n')
10091012
swapcase_verifier.verify(b'\t', expected_result=b'\t')
10101013
swapcase_verifier.verify(b' ', expected_result=b' ')
@@ -1029,6 +1032,7 @@ def title(tested: bytes) -> bytes:
10291032
title_verifier.verify(b'[]', expected_result=b'[]')
10301033
title_verifier.verify(b'-', expected_result=b'-')
10311034
title_verifier.verify(b'%', expected_result=b'%')
1035+
title_verifier.verify('π'.encode(), expected_result='π'.encode())
10321036
title_verifier.verify(b'\n', expected_result=b'\n')
10331037
title_verifier.verify(b'\t', expected_result=b'\t')
10341038
title_verifier.verify(b' ', expected_result=b' ')
@@ -1061,6 +1065,7 @@ def upper(tested: bytes) -> bytes:
10611065
upper_verifier.verify(b'[]', expected_result=b'[]')
10621066
upper_verifier.verify(b'-', expected_result=b'-')
10631067
upper_verifier.verify(b'%', expected_result=b'%')
1068+
upper_verifier.verify('π'.encode(), expected_result='π'.encode())
10641069
upper_verifier.verify(b'\n', expected_result=b'\n')
10651070
upper_verifier.verify(b'\t', expected_result=b'\t')
10661071
upper_verifier.verify(b' ', expected_result=b' ')

0 commit comments

Comments
 (0)