Skip to content

Commit e3214be

Browse files
chore: Handle forward references, repeatable annotations, and use str enums (#43)
- Previously, we eagerly compile the class as soon as `@planning_entity` or `@planning_solution` is reached. This cause problem if the class contains forward references, since: - The referenced class does not exist when `@planning_entity` or `@planning_solution` is reached, since it is defined later. - This causes `get_type_hints` to raise a NameError, since it cannot find the type name in locals or globals Now, the classes are compiled when the SolverConfig is read (and thus, all the referenced classes should be defined). - In order to handle forward references in annotations and the class translator, usage of getJavaClass (which may throw an exception if the class is in the middle of being defined) need to be changed to getJavaClassInternalName (which will never throw). Thus: - ArgumentSpec now take the return type and parameters of a function as String instead of their Class. There could be a Class overload that calls the String version, but I decided against it to prevent the temptation to pass getJavaClass to it. - PythonDefaultArgumentImplementor now sets its constants inside <clinit>, which means ArgumentSpec must store itself in a static field so clinit can access it. This is because if a Class cannot be loaded until all the Classes it references are defined (and thus, we cannot set the field values from Java, since PythonDefaultArgumentImplementor might reference a class still being defined). - AnnotationMetadata now store Class values as Type instead of Class. - In order for Timefold to discover the "true" type of annotated getters, we remove NoneType from the Union of the getter return type (but keep it in the actual field type). That is, if a type is annotated `Value | None`, the getter type will be `Value`, and the field type will be `get_common_ancestor(Value, NoneType) = object`. - Use PythonClassWriter instead of ClassWriter in the ClassTranslator, so ASM will not complain about missing classes when computing frames. - In order to properly visit a repeatable annotation, we need to group the repeated annotation by type and put them as the value of their container annotation class. - Make VersionMapping throw an exception if it gets a null version mapping (otherwise, an undescriptive NPE will happen if the unsupported opcode is in the bytecode). - Added the missing PreviousElementShadowVariable and NextElementShadowVariable annotations. - Use str enums, so the enum serializes to their names instead of a number.
1 parent 5e8c0bd commit e3214be

25 files changed

+909
-356
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package ai.timefold.jpyinterpreter;
22

33
import java.lang.annotation.Annotation;
4+
import java.lang.annotation.Repeatable;
45
import java.lang.reflect.Array;
6+
import java.util.ArrayList;
7+
import java.util.LinkedHashMap;
8+
import java.util.List;
59
import java.util.Map;
610

711
import org.objectweb.asm.AnnotationVisitor;
@@ -23,6 +27,30 @@ public void addAnnotationTo(MethodVisitor methodVisitor) {
2327
visitAnnotation(methodVisitor.visitAnnotation(Type.getDescriptor(annotationType), true));
2428
}
2529

30+
public static List<AnnotationMetadata> getAnnotationListWithoutRepeatable(List<AnnotationMetadata> metadata) {
31+
List<AnnotationMetadata> out = new ArrayList<>();
32+
Map<Class<? extends Annotation>, List<AnnotationMetadata>> repeatableAnnotationMap = new LinkedHashMap<>();
33+
for (AnnotationMetadata annotation : metadata) {
34+
Repeatable repeatable = annotation.annotationType().getAnnotation(Repeatable.class);
35+
if (repeatable == null) {
36+
out.add(annotation);
37+
continue;
38+
}
39+
var annotationContainer = repeatable.value();
40+
repeatableAnnotationMap.computeIfAbsent(annotationContainer,
41+
ignored -> new ArrayList<>()).add(annotation);
42+
}
43+
for (var entry : repeatableAnnotationMap.entrySet()) {
44+
out.add(new AnnotationMetadata(entry.getKey(),
45+
Map.of("value", entry.getValue().toArray(AnnotationMetadata[]::new))));
46+
}
47+
return out;
48+
}
49+
50+
public static Type getValueAsType(String className) {
51+
return Type.getType("L" + className.replace('.', '/') + ";");
52+
}
53+
2654
private void visitAnnotation(AnnotationVisitor annotationVisitor) {
2755
for (var entry : annotationValueMap.entrySet()) {
2856
var annotationAttributeName = entry.getKey();
@@ -42,8 +70,8 @@ private void visitAnnotationAttribute(AnnotationVisitor annotationVisitor, Strin
4270
return;
4371
}
4472

45-
if (attributeValue instanceof Class<?> clazz) {
46-
annotationVisitor.visit(attributeName, Type.getType(clazz));
73+
if (attributeValue instanceof Type type) {
74+
annotationVisitor.visit(attributeName, type);
4775
return;
4876
}
4977

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple;
4747
import ai.timefold.jpyinterpreter.types.wrappers.JavaObjectWrapper;
4848
import ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference;
49+
import ai.timefold.jpyinterpreter.util.JavaPythonClassWriter;
4950
import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec;
5051

5152
import org.objectweb.asm.ClassWriter;
@@ -122,7 +123,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
122123

123124
for (Class<?> javaInterface : pythonCompiledClass.javaInterfaces) {
124125
javaInterfaceImplementorSet.add(
125-
new DelegatingInterfaceImplementor(internalClassName, javaInterface, instanceMethodNameToMethodDescriptor));
126+
new DelegatingInterfaceImplementor(internalClassName, javaInterface,
127+
instanceMethodNameToMethodDescriptor));
126128
}
127129

128130
if (pythonCompiledClass.superclassList.isEmpty()) {
@@ -173,14 +175,14 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
173175
interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass());
174176
}
175177

176-
ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
178+
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
177179

178180
classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null,
179181
superClassType.getJavaTypeInternalName(), interfaces);
180182

181183
classWriter.visitSource(pythonCompiledClass.moduleFilePath, null);
182184

183-
for (var annotation : pythonCompiledClass.annotations) {
185+
for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(pythonCompiledClass.annotations)) {
184186
annotation.addAnnotationTo(classWriter);
185187
}
186188

@@ -208,19 +210,23 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
208210
var typeHint = pythonCompiledClass.typeAnnotations.getOrDefault(attributeName,
209211
TypeHint.withoutAnnotations(BuiltinTypes.BASE_TYPE));
210212
PythonLikeType type = typeHint.type();
213+
PythonLikeType javaGetterType = typeHint.javaGetterType();
211214
if (type == null) { // null might be in __annotations__
212215
type = BuiltinTypes.BASE_TYPE;
213216
}
214217

215218
attributeNameToTypeMap.put(attributeName, type);
216219
FieldVisitor fieldVisitor;
217220
String javaFieldTypeDescriptor;
221+
String getterTypeDescriptor;
218222
String signature = null;
219223
boolean isJavaType;
220224
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
221225
javaFieldTypeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
222-
fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
223-
null, null);
226+
getterTypeDescriptor = javaFieldTypeDescriptor;
227+
fieldVisitor =
228+
classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
229+
null, null);
224230
isJavaType = true;
225231
} else {
226232
if (typeHint.genericArgs() != null) {
@@ -229,16 +235,21 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
229235
signature = signatureWriter.toString();
230236
}
231237
javaFieldTypeDescriptor = 'L' + type.getJavaTypeInternalName() + ';';
232-
fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
233-
signature, null);
238+
getterTypeDescriptor = javaGetterType.getJavaTypeDescriptor();
239+
fieldVisitor =
240+
classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
241+
signature, null);
234242
isJavaType = false;
235243
}
236244
fieldVisitor.visitEnd();
245+
237246
createJavaGetterSetter(classWriter, preparedClassInfo,
238247
attributeName,
239248
Type.getType(javaFieldTypeDescriptor),
249+
Type.getType(getterTypeDescriptor),
240250
signature,
241251
typeHint);
252+
242253
FieldDescriptor fieldDescriptor =
243254
new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName,
244255
javaFieldTypeDescriptor, type, true, isJavaType);
@@ -344,7 +355,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
344355
pythonLikeType.$setAttribute("__module__", PythonString.valueOf(pythonCompiledClass.module));
345356

346357
PythonLikeDict annotations = new PythonLikeDict();
347-
pythonCompiledClass.typeAnnotations.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type()));
358+
pythonCompiledClass.typeAnnotations
359+
.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type()));
348360
pythonLikeType.$setAttribute("__annotations__", annotations);
349361

350362
PythonLikeTuple mro = new PythonLikeTuple();
@@ -552,7 +564,7 @@ private static Class<?> createPythonWrapperMethod(String methodName, PythonCompi
552564
String className = maybeClassName;
553565
String internalClassName = className.replace('.', '/');
554566

555-
ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
567+
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
556568

557569
classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null,
558570
Type.getInternalName(Object.class), new String[] { interfaceDeclaration.interfaceName });
@@ -663,7 +675,7 @@ private static PythonLikeFunction createConstructor(String classInternalName,
663675
String constructorClassName = maybeClassName;
664676
String constructorInternalClassName = constructorClassName.replace('.', '/');
665677

666-
ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
678+
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
667679
classWriter.visit(Opcodes.V11, Modifier.PUBLIC, constructorInternalClassName, null, Type.getInternalName(Object.class),
668680
new String[] {
669681
Type.getInternalName(PythonLikeFunction.class)
@@ -775,24 +787,24 @@ private static PythonLikeFunction createConstructor(String classInternalName,
775787

776788
private static void createJavaGetterSetter(ClassWriter classWriter,
777789
PreparedClassInfo preparedClassInfo,
778-
String attributeName, Type attributeType,
790+
String attributeName, Type attributeType, Type getterType,
779791
String signature,
780792
TypeHint typeHint) {
781-
createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
782-
createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
793+
createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, getterType, signature, typeHint);
794+
createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, getterType, signature, typeHint);
783795
}
784796

785797
private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
786-
Type attributeType, String signature, TypeHint typeHint) {
798+
Type attributeType, Type getterType, String signature, TypeHint typeHint) {
787799
var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
788-
if (signature != null) {
800+
if (signature != null && Objects.equals(attributeType, getterType)) {
789801
signature = "()" + signature;
790802
}
791-
var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(attributeType),
803+
var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(getterType),
792804
signature, null);
793805
var maxStack = 1;
794806

795-
for (var annotation : typeHint.annotationList()) {
807+
for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(typeHint.annotationList())) {
796808
annotation.addAnnotationTo(getterVisitor);
797809
}
798810

@@ -813,19 +825,22 @@ private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo
813825
// If branch is taken, stack is field
814826
// If branch is not taken, stack is null
815827
}
828+
if (!Objects.equals(attributeType, getterType)) {
829+
getterVisitor.visitTypeInsn(Opcodes.CHECKCAST, getterType.getInternalName());
830+
}
816831
getterVisitor.visitInsn(Opcodes.ARETURN);
817832
getterVisitor.visitMaxs(maxStack, 0);
818833
getterVisitor.visitEnd();
819834
}
820835

821836
private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
822-
Type attributeType, String signature, TypeHint typeHint) {
837+
Type attributeType, Type setterType, String signature, TypeHint typeHint) {
823838
var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
824-
if (signature != null) {
839+
if (signature != null && Objects.equals(attributeType, setterType)) {
825840
signature = "(" + signature + ")V";
826841
}
827842
var setterVisitor = classWriter.visitMethod(Modifier.PUBLIC, setterName, Type.getMethodDescriptor(Type.VOID_TYPE,
828-
attributeType),
843+
setterType),
829844
signature, null);
830845
var maxStack = 2;
831846
setterVisitor.visitCode();
@@ -845,6 +860,9 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
845860
// If branch is taken, stack is (non-null instance)
846861
// If branch is not taken, stack is None
847862
}
863+
if (!Objects.equals(attributeType, setterType)) {
864+
setterVisitor.visitTypeInsn(Opcodes.CHECKCAST, attributeType.getInternalName());
865+
}
848866
setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName,
849867
attributeName, attributeType.getDescriptor());
850868
setterVisitor.visitInsn(Opcodes.RETURN);
@@ -855,7 +873,7 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
855873
private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) {
856874
var returnTypeHint = function.typeAnnotations.get("return");
857875
if (returnTypeHint != null) {
858-
for (var annotation : returnTypeHint.annotationList()) {
876+
for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(returnTypeHint.annotationList())) {
859877
annotation.addAnnotationTo(methodVisitor);
860878
}
861879
}
@@ -1444,7 +1462,7 @@ public static InterfaceDeclaration createInterfaceForFunctionSignature(FunctionS
14441462

14451463
String internalClassName = className.replace('.', '/');
14461464

1447-
ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
1465+
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
14481466
classWriter.visit(Opcodes.V11, Modifier.PUBLIC | Modifier.INTERFACE | Modifier.ABSTRACT, internalClassName, null,
14491467
Type.getInternalName(Object.class), null);
14501468

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,16 @@ private static <T> Class<T> getParameterJavaClass(List<PythonLikeType> parameter
208208
return (Class) parameterTypeList.get(variableIndex).getJavaClassOrDefault(PythonLikeObject.class);
209209
}
210210

211+
private static String getParameterJavaClassName(List<PythonLikeType> parameterTypeList, int variableIndex) {
212+
return parameterTypeList.get(variableIndex).getJavaTypeInternalName();
213+
}
214+
211215
@SuppressWarnings({ "unchecked", "rawtypes" })
212216
public BiFunction<PythonLikeTuple, PythonLikeDict, ArgumentSpec<PythonLikeObject>> getArgumentSpecMapper() {
213217
return (defaultPositionalArguments, defaultKeywordArguments) -> {
214218
ArgumentSpec<PythonLikeObject> out = ArgumentSpec.forFunctionReturning(qualifiedName, getReturnType()
215-
.map(type -> (Class) type.getJavaClassOrDefault(PythonLikeObject.class))
216-
.orElse(PythonLikeObject.class));
219+
.map(PythonLikeType::getJavaTypeInternalName)
220+
.orElse(PythonLikeObject.class.getName()));
217221

218222
int variableIndex = 0;
219223
int defaultPositionalStartIndex = co_argcount - defaultPositionalArguments.size();
@@ -226,23 +230,23 @@ public BiFunction<PythonLikeTuple, PythonLikeDict, ArgumentSpec<PythonLikeObject
226230
for (; variableIndex < co_posonlyargcount; variableIndex++) {
227231
if (variableIndex >= defaultPositionalStartIndex) {
228232
out = out.addPositionalOnlyArgument(co_varnames.get(variableIndex),
229-
getParameterJavaClass(parameterTypeList, variableIndex),
233+
getParameterJavaClassName(parameterTypeList, variableIndex),
230234
defaultPositionalArguments.get(
231235
variableIndex - defaultPositionalStartIndex));
232236
} else {
233237
out = out.addPositionalOnlyArgument(co_varnames.get(variableIndex),
234-
getParameterJavaClass(parameterTypeList, variableIndex));
238+
getParameterJavaClassName(parameterTypeList, variableIndex));
235239
}
236240
}
237241

238242
for (; variableIndex < co_argcount; variableIndex++) {
239243
if (variableIndex >= defaultPositionalStartIndex) {
240244
out = out.addArgument(co_varnames.get(variableIndex),
241-
getParameterJavaClass(parameterTypeList, variableIndex),
245+
getParameterJavaClassName(parameterTypeList, variableIndex),
242246
defaultPositionalArguments.get(variableIndex - defaultPositionalStartIndex));
243247
} else {
244248
out = out.addArgument(co_varnames.get(variableIndex),
245-
getParameterJavaClass(parameterTypeList, variableIndex));
249+
getParameterJavaClassName(parameterTypeList, variableIndex));
246250
}
247251
}
248252

@@ -251,11 +255,11 @@ public BiFunction<PythonLikeTuple, PythonLikeDict, ArgumentSpec<PythonLikeObject
251255
defaultKeywordArguments.get(PythonString.valueOf(co_varnames.get(variableIndex)));
252256
if (maybeDefault != null) {
253257
out = out.addKeywordOnlyArgument(co_varnames.get(variableIndex),
254-
getParameterJavaClass(parameterTypeList, variableIndex),
258+
getParameterJavaClassName(parameterTypeList, variableIndex),
255259
maybeDefault);
256260
} else {
257261
out = out.addKeywordOnlyArgument(co_varnames.get(variableIndex),
258-
getParameterJavaClass(parameterTypeList, variableIndex));
262+
getParameterJavaClassName(parameterTypeList, variableIndex));
259263
}
260264
variableIndex++;
261265
}

0 commit comments

Comments
 (0)