Skip to content

Commit d6ad7a1

Browse files
fix: Generate getters/setters that replace null with None (or vice versa) (#29)
- Timefold considers None to be initialized, although it represents an uninitialized value - Instead of annotating the fields directly, we could annotate getters/setters instead - If None is assignable to the field type, the getter is `return this.field != None? this.field : null`; otherwise it's `return this.field` - If None is assignable to the field type, the setter is `this.field =(value != null)? value : None`; otherwise it's `this.field = value` - Assign the `__class__` field of a new class to itself to match CPython.
1 parent 850ff44 commit d6ad7a1

File tree

5 files changed

+133
-63
lines changed

5 files changed

+133
-63
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java

Lines changed: 105 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ai.timefold.jpyinterpreter.implementors.JavaEqualsImplementor;
2424
import ai.timefold.jpyinterpreter.implementors.JavaHashCodeImplementor;
2525
import ai.timefold.jpyinterpreter.implementors.JavaInterfaceImplementor;
26+
import ai.timefold.jpyinterpreter.implementors.PythonConstantsImplementor;
2627
import ai.timefold.jpyinterpreter.opcodes.AbstractOpcode;
2728
import ai.timefold.jpyinterpreter.opcodes.Opcode;
2829
import ai.timefold.jpyinterpreter.opcodes.SelfOpcodeWithoutSource;
@@ -205,14 +206,14 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
205206
attributeNameToTypeMap.put(attributeName, type);
206207
FieldVisitor fieldVisitor;
207208
String javaFieldTypeDescriptor;
209+
String signature = null;
208210
boolean isJavaType;
209211
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
210212
javaFieldTypeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
211213
fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
212214
null, null);
213215
isJavaType = true;
214216
} else {
215-
String signature = null;
216217
if (typeHint.genericArgs() != null) {
217218
var signatureWriter = new SignatureWriter();
218219
visitSignature(typeHint, signatureWriter);
@@ -223,10 +224,12 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
223224
signature, null);
224225
isJavaType = false;
225226
}
226-
for (var annotation : typeHint.annotationList()) {
227-
annotation.addAnnotationTo(fieldVisitor);
228-
}
229227
fieldVisitor.visitEnd();
228+
createJavaGetterSetter(classWriter, preparedClassInfo,
229+
attributeName,
230+
Type.getType(javaFieldTypeDescriptor),
231+
signature,
232+
typeHint);
230233
FieldDescriptor fieldDescriptor =
231234
new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName,
232235
javaFieldTypeDescriptor, type, true, isJavaType);
@@ -761,6 +764,85 @@ private static PythonLikeFunction createConstructor(String classInternalName,
761764
}
762765
}
763766

767+
private static void createJavaGetterSetter(ClassWriter classWriter,
768+
PreparedClassInfo preparedClassInfo,
769+
String attributeName, Type attributeType,
770+
String signature,
771+
TypeHint typeHint) {
772+
createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
773+
createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
774+
}
775+
776+
private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
777+
Type attributeType, String signature, TypeHint typeHint) {
778+
var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
779+
if (signature != null) {
780+
signature = "()" + signature;
781+
}
782+
var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(attributeType),
783+
signature, null);
784+
var maxStack = 1;
785+
786+
for (var annotation : typeHint.annotationList()) {
787+
annotation.addAnnotationTo(getterVisitor);
788+
}
789+
790+
getterVisitor.visitCode();
791+
getterVisitor.visitVarInsn(Opcodes.ALOAD, 0);
792+
getterVisitor.visitFieldInsn(Opcodes.GETFIELD, preparedClassInfo.classInternalName,
793+
attributeName, attributeType.getDescriptor());
794+
if (typeHint.type().isInstance(PythonNone.INSTANCE)) {
795+
maxStack = 3;
796+
getterVisitor.visitInsn(Opcodes.DUP);
797+
PythonConstantsImplementor.loadNone(getterVisitor);
798+
Label returnLabel = new Label();
799+
getterVisitor.visitJumpInsn(Opcodes.IF_ACMPNE, returnLabel);
800+
// field is None, so we want Java to see it as null
801+
getterVisitor.visitInsn(Opcodes.POP);
802+
getterVisitor.visitInsn(Opcodes.ACONST_NULL);
803+
getterVisitor.visitLabel(returnLabel);
804+
// If branch is taken, stack is field
805+
// If branch is not taken, stack is null
806+
}
807+
getterVisitor.visitInsn(Opcodes.ARETURN);
808+
getterVisitor.visitMaxs(maxStack, 0);
809+
getterVisitor.visitEnd();
810+
}
811+
812+
private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
813+
Type attributeType, String signature, TypeHint typeHint) {
814+
var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
815+
if (signature != null) {
816+
signature = "(" + signature + ")V";
817+
}
818+
var setterVisitor = classWriter.visitMethod(Modifier.PUBLIC, setterName, Type.getMethodDescriptor(Type.VOID_TYPE,
819+
attributeType),
820+
signature, null);
821+
var maxStack = 2;
822+
setterVisitor.visitCode();
823+
setterVisitor.visitVarInsn(Opcodes.ALOAD, 0);
824+
setterVisitor.visitVarInsn(Opcodes.ALOAD, 1);
825+
if (typeHint.type().isInstance(PythonNone.INSTANCE)) {
826+
maxStack = 4;
827+
// We want to replace null with None
828+
setterVisitor.visitInsn(Opcodes.DUP);
829+
setterVisitor.visitInsn(Opcodes.ACONST_NULL);
830+
Label setFieldLabel = new Label();
831+
setterVisitor.visitJumpInsn(Opcodes.IF_ACMPNE, setFieldLabel);
832+
// set value is null, so we want Python to see it as None
833+
setterVisitor.visitInsn(Opcodes.POP);
834+
PythonConstantsImplementor.loadNone(setterVisitor);
835+
setterVisitor.visitLabel(setFieldLabel);
836+
// If branch is taken, stack is (non-null instance)
837+
// If branch is not taken, stack is None
838+
}
839+
setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName,
840+
attributeName, attributeType.getDescriptor());
841+
setterVisitor.visitInsn(Opcodes.RETURN);
842+
setterVisitor.visitMaxs(maxStack, 0);
843+
setterVisitor.visitEnd();
844+
}
845+
764846
private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) {
765847
var returnTypeHint = function.typeAnnotations.get("return");
766848
if (returnTypeHint != null) {
@@ -956,15 +1038,9 @@ public static void createGetAttribute(ClassWriter classWriter, String classInter
9561038
methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
9571039
var type = fieldToType.get(field);
9581040
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
959-
Class<?> fieldType = type.getJavaObjectWrapperType();
9601041
methodVisitor.visitFieldInsn(Opcodes.GETFIELD, classInternalName, getJavaFieldName(field),
961-
Type.getDescriptor(fieldType));
962-
methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(JavaObjectWrapper.class));
963-
methodVisitor.visitInsn(Opcodes.DUP_X1);
964-
methodVisitor.visitInsn(Opcodes.DUP_X1);
965-
methodVisitor.visitInsn(Opcodes.POP);
966-
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(JavaObjectWrapper.class),
967-
"<init>", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Object.class)), false);
1042+
Type.getDescriptor(type.getJavaObjectWrapperType()));
1043+
getWrappedJavaObject(methodVisitor);
9681044
} else {
9691045
methodVisitor.visitFieldInsn(Opcodes.GETFIELD, classInternalName, getJavaFieldName(field),
9701046
'L' + type.getJavaTypeInternalName() + ';');
@@ -984,6 +1060,15 @@ public static void createGetAttribute(ClassWriter classWriter, String classInter
9841060
methodVisitor.visitEnd();
9851061
}
9861062

1063+
private static void getWrappedJavaObject(MethodVisitor methodVisitor) {
1064+
methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(JavaObjectWrapper.class));
1065+
methodVisitor.visitInsn(Opcodes.DUP_X1);
1066+
methodVisitor.visitInsn(Opcodes.DUP_X1);
1067+
methodVisitor.visitInsn(Opcodes.POP);
1068+
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(JavaObjectWrapper.class),
1069+
"<init>", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Object.class)), false);
1070+
}
1071+
9871072
public static void createSetAttribute(ClassWriter classWriter, String classInternalName, String superInternalName,
9881073
Collection<String> instanceAttributes,
9891074
Map<String, PythonLikeType> fieldToType) {
@@ -1008,10 +1093,7 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter
10081093
String typeDescriptor = type.getJavaTypeDescriptor();
10091094
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
10101095
// Need to unwrap the object
1011-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(JavaObjectWrapper.class));
1012-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(JavaObjectWrapper.class),
1013-
"getWrappedObject", Type.getMethodDescriptor(Type.getType(Object.class)), false);
1014-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getType(type.getJavaObjectWrapperType()).getInternalName());
1096+
getUnwrappedJavaObject(methodVisitor, type);
10151097
typeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
10161098
} else {
10171099
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type.getJavaTypeInternalName());
@@ -1035,6 +1117,13 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter
10351117
methodVisitor.visitEnd();
10361118
}
10371119

1120+
private static void getUnwrappedJavaObject(MethodVisitor methodVisitor, PythonLikeType type) {
1121+
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(JavaObjectWrapper.class));
1122+
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(JavaObjectWrapper.class),
1123+
"getWrappedObject", Type.getMethodDescriptor(Type.getType(Object.class)), false);
1124+
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getType(type.getJavaObjectWrapperType()).getInternalName());
1125+
}
1126+
10381127
public static void createDeleteAttribute(ClassWriter classWriter, String classInternalName, String superInternalName,
10391128
Collection<String> instanceAttributes,
10401129
Map<String, PythonLikeType> fieldToType) {

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonLikeType.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ private PythonLikeType(String typeName, String internalName) {
107107
}
108108

109109
public static PythonLikeType getTypeForNewClass(String typeName, String internalName) {
110-
return new PythonLikeType(typeName, internalName);
110+
var out = new PythonLikeType(typeName, internalName);
111+
out.__dir__.put("__class__", out);
112+
return out;
111113
}
112114

113115
public void initializeNewType(List<PythonLikeType> superClassTypes) {

jpyinterpreter/tests/test_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,7 @@ def my_method(self) -> Annotated[str, 'extra', JavaAnnotation(Deprecated, {
902902
assert annotations[0].forRemoval()
903903
assert annotations[0].since() == '0.0.0'
904904

905-
annotations = translated_class.getField('my_field').getAnnotations()
905+
annotations = translated_class.getMethod('getMy_field').getAnnotations()
906906
assert len(annotations) == 2
907907
assert isinstance(annotations[0], Deprecated)
908908
assert annotations[0].forRemoval()

tests/test_custom_shadow_variables.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,16 @@ class MyPlanningEntity:
4848
def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory):
4949
return [
5050
constraint_factory.for_each(MyPlanningEntity)
51-
.filter(lambda entity: entity.value is None)
52-
.penalize('Unassigned value', timefold.solver.score.HardSoftScore.ONE_HARD),
53-
constraint_factory.for_each(MyPlanningEntity)
54-
.filter(lambda entity: entity.value is not None and entity.value * 2 == entity.value_squared)
55-
.reward('Double value is value squared', timefold.solver.score.HardSoftScore.ONE_SOFT)
51+
.filter(lambda entity: entity.value * 2 == entity.value_squared)
52+
.reward('Double value is value squared', timefold.solver.score.SimpleScore.ONE)
5653
]
5754

5855
@timefold.solver.planning_solution
5956
@dataclass
6057
class MySolution:
6158
entity_list: Annotated[List[MyPlanningEntity], timefold.solver.PlanningEntityCollectionProperty]
6259
value_list: Annotated[List[int], timefold.solver.ValueRangeProvider]
63-
score: Annotated[timefold.solver.score.HardSoftScore, timefold.solver.PlanningScore] = field(default=None)
60+
score: Annotated[timefold.solver.score.SimpleScore, timefold.solver.PlanningScore] = field(default=None)
6461

6562
solver_config = timefold.solver.config.SolverConfig(
6663
solution_class=MySolution,
@@ -69,16 +66,15 @@ class MySolution:
6966
constraint_provider_function=my_constraints
7067
),
7168
termination_config=timefold.solver.config.TerminationConfig(
72-
best_score_limit='0hard/1soft'
69+
best_score_limit='1'
7370
)
7471
)
7572

7673
solver_factory = timefold.solver.SolverFactory.create(solver_config)
7774
solver = solver_factory.build_solver()
7875
problem = MySolution([MyPlanningEntity()], [1, 2, 3])
7976
solution: MySolution = solver.solve(problem)
80-
assert solution.score.hard_score() == 0
81-
assert solution.score.soft_score() == 1
77+
assert solution.score.score() == 1
8278
assert solution.entity_list[0].value == 2
8379
assert solution.entity_list[0].value_squared == 4
8480

@@ -127,20 +123,17 @@ class MyPlanningEntity:
127123
@timefold.solver.constraint_provider
128124
def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory):
129125
return [
130-
constraint_factory.for_each(MyPlanningEntity)
131-
.filter(lambda entity: entity.value is None)
132-
.penalize('Unassigned value', timefold.solver.score.HardSoftScore.ONE_HARD),
133126
constraint_factory.for_each(MyPlanningEntity)
134127
.filter(lambda entity: entity.twice_value == entity.value_squared)
135-
.reward('Double value is value squared', timefold.solver.score.HardSoftScore.ONE_SOFT)
128+
.reward('Double value is value squared', timefold.solver.score.SimpleScore.ONE)
136129
]
137130

138131
@timefold.solver.planning_solution
139132
@dataclass
140133
class MySolution:
141134
entity_list: Annotated[List[MyPlanningEntity], PlanningEntityCollectionProperty]
142135
value_list: Annotated[List[int], ValueRangeProvider]
143-
score: Annotated[timefold.solver.score.HardSoftScore, PlanningScore] = field(default=None)
136+
score: Annotated[timefold.solver.score.SimpleScore, PlanningScore] = field(default=None)
144137

145138
solver_config = timefold.solver.config.SolverConfig(
146139
solution_class=MySolution,
@@ -149,15 +142,14 @@ class MySolution:
149142
constraint_provider_function=my_constraints
150143
),
151144
termination_config=timefold.solver.config.TerminationConfig(
152-
best_score_limit='0hard/1soft'
145+
best_score_limit='1'
153146
)
154147
)
155148

156149
solver_factory = timefold.solver.SolverFactory.create(solver_config)
157150
solver = solver_factory.build_solver()
158151
problem = MySolution([MyPlanningEntity()], [1, 2, 3])
159152
solution: MySolution = solver.solve(problem)
160-
assert solution.score.hard_score() == 0
161-
assert solution.score.soft_score() == 1
153+
assert solution.score.score() == 1
162154
assert solution.entity_list[0].value == 2
163155
assert solution.entity_list[0].value_squared == 4

0 commit comments

Comments
 (0)