Skip to content

Commit 30a13b9

Browse files
feat: Add support for nearby selection (#32)
- Add automatic type coercion between int and float to prevent ClassCastExceptions when a function expecting float is given int and vice-versa - Add register_java_class to set field holding the generated class and register the class in the SolverConfig classloader - Rename field holding the generated class from __timefold_java_class to _timefold_java_class to prevent name-mangling - Remove unused code
1 parent 158aca7 commit 30a13b9

File tree

14 files changed

+172
-81
lines changed

14 files changed

+172
-81
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ai.timefold.jpyinterpreter.implementors.JavaEqualsImplementor;
2424
import ai.timefold.jpyinterpreter.implementors.JavaHashCodeImplementor;
2525
import ai.timefold.jpyinterpreter.implementors.JavaInterfaceImplementor;
26+
import ai.timefold.jpyinterpreter.implementors.JavaPythonTypeConversionImplementor;
2627
import ai.timefold.jpyinterpreter.implementors.PythonConstantsImplementor;
2728
import ai.timefold.jpyinterpreter.opcodes.AbstractOpcode;
2829
import ai.timefold.jpyinterpreter.opcodes.Opcode;
@@ -1096,6 +1097,13 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter
10961097
getUnwrappedJavaObject(methodVisitor, type);
10971098
typeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
10981099
} else {
1100+
methodVisitor.visitLdcInsn(Type.getType(type.getJavaTypeDescriptor()));
1101+
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC,
1102+
Type.getInternalName(JavaPythonTypeConversionImplementor.class),
1103+
"coerceToType", Type.getMethodDescriptor(Type.getType(Object.class),
1104+
Type.getType(PythonLikeObject.class),
1105+
Type.getType(Class.class)),
1106+
false);
10991107
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type.getJavaTypeInternalName());
11001108
}
11011109
methodVisitor.visitFieldInsn(Opcodes.PUTFIELD, classInternalName, getJavaFieldName(field),
@@ -1231,6 +1239,12 @@ public static void createReadFromCPythonReference(ClassWriter classWriter, Strin
12311239
}
12321240

12331241
var attributeType = attributeNameToType.get(field);
1242+
methodVisitor.visitLdcInsn(Type.getType(attributeType.getJavaTypeDescriptor()));
1243+
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class),
1244+
"coerceToType", Type.getMethodDescriptor(Type.getType(Object.class),
1245+
Type.getType(PythonLikeObject.class),
1246+
Type.getType(Class.class)),
1247+
false);
12341248
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, attributeNameToType.get(field).getJavaTypeInternalName());
12351249

12361250
if (attributeType.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/JavaPythonTypeConversionImplementor.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ai.timefold.jpyinterpreter.PythonLikeObject;
1515
import ai.timefold.jpyinterpreter.StackMetadata;
1616
import ai.timefold.jpyinterpreter.types.BuiltinTypes;
17+
import ai.timefold.jpyinterpreter.types.Coercible;
1718
import ai.timefold.jpyinterpreter.types.PythonByteArray;
1819
import ai.timefold.jpyinterpreter.types.PythonBytes;
1920
import ai.timefold.jpyinterpreter.types.PythonCode;
@@ -447,6 +448,31 @@ public static void returnValue(MethodVisitor methodVisitor, MethodDescriptor met
447448
methodVisitor.visitInsn(Opcodes.ARETURN);
448449
}
449450

451+
/**
452+
* Coerce a value to a given type
453+
*/
454+
public static <T> T coerceToType(PythonLikeObject value, Class<T> type) {
455+
if (value == null) {
456+
return null;
457+
}
458+
459+
if (type.isAssignableFrom(value.getClass())) {
460+
return (T) value;
461+
}
462+
463+
if (value instanceof Coercible coercible) {
464+
var out = coercible.coerce(type);
465+
if (out == null) {
466+
throw new TypeError("%s cannot be coerced to %s."
467+
.formatted(value.$getType(), type));
468+
}
469+
return out;
470+
}
471+
472+
throw new TypeError("%s cannot be coerced to %s."
473+
.formatted(value.$getType(), type));
474+
}
475+
450476
/**
451477
* Convert the {@code parameterIndex} Java parameter to its Python equivalent and store it into
452478
* the corresponding Python parameter local variable slot.

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/KnownCallImplementor.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ public static void callMethod(PythonFunctionSignature pythonFunctionSignature, M
113113
// Now load and typecheck the local variables
114114
for (int i = 0; i < Math.min(specPositionalArgumentCount, argumentCount); i++) {
115115
localVariableHelper.readTemp(methodVisitor, Type.getType(PythonLikeObject.class), argumentLocals[i]);
116+
methodVisitor.visitLdcInsn(Type.getType(pythonFunctionSignature.getArgumentSpec().getArgumentType(i)));
117+
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class),
118+
"coerceToType", Type.getMethodDescriptor(Type.getType(Object.class),
119+
Type.getType(PythonLikeObject.class),
120+
Type.getType(Class.class)),
121+
false);
116122
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST,
117123
Type.getInternalName(pythonFunctionSignature.getArgumentSpec().getArgumentType(i)));
118124
}
@@ -279,6 +285,12 @@ public static void callPython311andAbove(PythonFunctionSignature pythonFunctionS
279285
// Load arguments in proper order and typecast them
280286
for (int i = 0; i < specTotalArgumentCount; i++) {
281287
localVariableHelper.readTemp(methodVisitor, Type.getType(PythonLikeObject.class), argumentLocals[i]);
288+
methodVisitor.visitLdcInsn(Type.getType(pythonFunctionSignature.getArgumentSpec().getArgumentType(i)));
289+
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class),
290+
"coerceToType", Type.getMethodDescriptor(Type.getType(Object.class),
291+
Type.getType(PythonLikeObject.class),
292+
Type.getType(Class.class)),
293+
false);
282294
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST,
283295
Type.getInternalName(pythonFunctionSignature.getArgumentSpec().getArgumentType(i)));
284296
}
@@ -480,6 +492,12 @@ public static void callUnpackListAndMap(Class<?> defaultArgumentHolderClass, Met
480492
methodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(List.class),
481493
"get", Type.getMethodDescriptor(Type.getType(Object.class), Type.INT_TYPE),
482494
true);
495+
methodVisitor.visitLdcInsn(descriptorParameterTypes[i]);
496+
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class),
497+
"coerceToType", Type.getMethodDescriptor(Type.getType(Object.class),
498+
Type.getType(PythonLikeObject.class),
499+
Type.getType(Class.class)),
500+
false);
483501
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, descriptorParameterTypes[i].getInternalName());
484502
methodVisitor.visitInsn(Opcodes.SWAP);
485503
}

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/ObjectImplementor.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ public static void setAttribute(FunctionMetadata functionMetadata, MethodVisitor
140140
FieldDescriptor fieldDescriptor = maybeFieldDescriptor.get();
141141
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, fieldDescriptor.declaringClassInternalName());
142142
StackManipulationImplementor.swap(methodVisitor);
143+
methodVisitor.visitLdcInsn(Type.getType(fieldDescriptor.fieldPythonLikeType().getJavaTypeDescriptor()));
144+
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class),
145+
"coerceToType", Type.getMethodDescriptor(Type.getType(Object.class),
146+
Type.getType(PythonLikeObject.class),
147+
Type.getType(Class.class)),
148+
false);
143149
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, fieldDescriptor.fieldPythonLikeType().getJavaTypeInternalName());
144150
if (fieldDescriptor.isJavaType()) {
145151
// Need to unwrap the object
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package ai.timefold.jpyinterpreter.types;
2+
3+
public interface Coercible {
4+
<T> T coerce(Class<T> targetType);
5+
}

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonFloat.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import ai.timefold.jpyinterpreter.PythonUnaryOperator;
1616
import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject;
1717
import ai.timefold.jpyinterpreter.types.BuiltinTypes;
18+
import ai.timefold.jpyinterpreter.types.Coercible;
1819
import ai.timefold.jpyinterpreter.types.NotImplemented;
1920
import ai.timefold.jpyinterpreter.types.PythonLikeComparable;
2021
import ai.timefold.jpyinterpreter.types.PythonLikeFunction;
@@ -29,7 +30,8 @@
2930
import ai.timefold.jpyinterpreter.util.StringFormatter;
3031
import ai.timefold.solver.core.impl.domain.solution.cloner.PlanningImmutable;
3132

32-
public class PythonFloat extends AbstractPythonLikeObject implements PythonNumber, PlanningImmutable {
33+
public class PythonFloat extends AbstractPythonLikeObject implements PythonNumber, PlanningImmutable,
34+
Coercible {
3335
public final double value;
3436

3537
static {
@@ -793,4 +795,12 @@ private DecimalFormat getNumberFormat(DefaultFormatSpec formatSpec) {
793795
StringFormatter.align(out, formatSpec, DefaultFormatSpec.AlignmentOption.RIGHT_ALIGN);
794796
return PythonString.valueOf(out.toString());
795797
}
798+
799+
@Override
800+
public <T> T coerce(Class<T> targetType) {
801+
if (targetType.equals(PythonInteger.class)) {
802+
return (T) PythonInteger.valueOf((long) value);
803+
}
804+
return null;
805+
}
796806
}

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/numeric/PythonInteger.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import ai.timefold.jpyinterpreter.PythonUnaryOperator;
1414
import ai.timefold.jpyinterpreter.types.AbstractPythonLikeObject;
1515
import ai.timefold.jpyinterpreter.types.BuiltinTypes;
16+
import ai.timefold.jpyinterpreter.types.Coercible;
1617
import ai.timefold.jpyinterpreter.types.NotImplemented;
1718
import ai.timefold.jpyinterpreter.types.PythonLikeFunction;
1819
import ai.timefold.jpyinterpreter.types.PythonLikeType;
@@ -26,7 +27,8 @@
2627
import ai.timefold.jpyinterpreter.util.StringFormatter;
2728
import ai.timefold.solver.core.impl.domain.solution.cloner.PlanningImmutable;
2829

29-
public class PythonInteger extends AbstractPythonLikeObject implements PythonNumber, PlanningImmutable {
30+
public class PythonInteger extends AbstractPythonLikeObject implements PythonNumber,
31+
PlanningImmutable, Coercible {
3032
private static final BigInteger MIN_BYTE = BigInteger.valueOf(0);
3133
private static final BigInteger MAX_BYTE = BigInteger.valueOf(255);
3234

@@ -840,4 +842,12 @@ public PythonString asString() {
840842

841843
return PythonString.valueOf(out.toString());
842844
}
845+
846+
@Override
847+
public <T> T coerce(Class<T> targetType) {
848+
if (targetType.equals(PythonFloat.class)) {
849+
return (T) PythonFloat.valueOf(value.doubleValue());
850+
}
851+
return null;
852+
}
843853
}

jpyinterpreter/tests/test_classes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,28 @@ def create_instance(x: int) -> A:
2424
verifier.verify(3, expected_result=A(3))
2525

2626

27+
def test_type_coercing():
28+
class A:
29+
value: float
30+
31+
def __init__(self, value):
32+
self.value = value
33+
34+
def __eq__(self, other):
35+
if not isinstance(other, A):
36+
return False
37+
return self.value == other.value
38+
39+
def create_instance(x: int) -> A:
40+
return A(x)
41+
42+
verifier = verifier_for(create_instance)
43+
44+
verifier.verify(1, expected_result=A(1))
45+
verifier.verify(2, expected_result=A(2))
46+
verifier.verify(3, expected_result=A(3))
47+
48+
2749
def test_deleted_field():
2850
class A:
2951
value: int

tests/test_solver_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_load_from_solver_config_file():
5555
assert entity_class_list.size() == 1
5656
assert entity_class_list.get(0) == get_java_type_for_python_type(Entity).getJavaClass()
5757
assert solver_config.getScoreDirectorFactoryConfig().getConstraintProviderClass() == \
58-
my_constraints.__timefold_java_class # noqa
58+
my_constraints._timefold_java_class # noqa
5959
assert solver_config.getTerminationConfig().getBestScoreLimit() == '0hard/0soft'
6060

6161

tests/test_user_error.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,12 @@ def test_missing_enterprise():
111111
solver_config = SolverConfig(
112112
move_thread_count=MoveThreadCount.AUTO
113113
)._to_java_solver_config()
114+
115+
@nearby_distance_meter
116+
def my_distance_meter(entity: Entity, value: str) -> float:
117+
return 0.0
118+
119+
with pytest.raises(RequiresEnterpriseError, match=re.escape('nearby selection')):
120+
solver_config = SolverConfig(
121+
nearby_distance_meter_function=my_distance_meter
122+
)._to_java_solver_config()

0 commit comments

Comments
 (0)