Skip to content

Commit 45132a1

Browse files
feat: Add support for custom justifications (#39)
- Added a way to specify a class implements a Java interface.
1 parent d0a5811 commit 45132a1

File tree

13 files changed

+567
-193
lines changed

13 files changed

+567
-193
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/InterfaceProxyGenerator.java

Lines changed: 6 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@
22

33
import java.lang.reflect.Method;
44
import java.lang.reflect.Modifier;
5-
import java.util.Collections;
65
import java.util.HashSet;
76
import java.util.IdentityHashMap;
8-
import java.util.List;
9-
import java.util.Map;
107
import java.util.Set;
118

9+
import ai.timefold.jpyinterpreter.implementors.DelegatingInterfaceImplementor;
1210
import ai.timefold.jpyinterpreter.implementors.JavaPythonTypeConversionImplementor;
1311
import ai.timefold.jpyinterpreter.types.BuiltinTypes;
1412
import ai.timefold.jpyinterpreter.types.PythonLikeType;
1513
import ai.timefold.jpyinterpreter.util.MethodVisitorAdapters;
1614
import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec;
1715

1816
import org.objectweb.asm.ClassWriter;
19-
import org.objectweb.asm.MethodVisitor;
2017
import org.objectweb.asm.Opcodes;
2118
import org.objectweb.asm.Type;
2219

@@ -251,64 +248,26 @@ private static void createMethodDelegate(ClassWriter classWriter,
251248
interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, wrapperInternalName,
252249
"argumentSpec$" + interfaceMethod.getName(),
253250
Type.getDescriptor(ArgumentSpec.class));
254-
interfaceMethodVisitor.visitLdcInsn(interfaceMethod.getParameterCount());
255-
interfaceMethodVisitor.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(PythonLikeObject.class));
256-
interfaceMethodVisitor.visitVarInsn(Opcodes.ASTORE, interfaceMethod.getParameterCount() + 2);
257-
for (int i = 0; i < interfaceMethod.getParameterCount(); i++) {
258-
var parameterType = interfaceMethod.getParameterTypes()[i];
259-
interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 2);
260-
interfaceMethodVisitor.visitLdcInsn(i);
261-
interfaceMethodVisitor.visitVarInsn(Type.getType(parameterType).getOpcode(Opcodes.ILOAD),
262-
i + 1);
263-
if (parameterType.isPrimitive()) {
264-
convertPrimitiveToObjectType(parameterType, interfaceMethodVisitor);
265-
}
266-
interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 1);
267-
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC,
268-
Type.getInternalName(JavaPythonTypeConversionImplementor.class),
269-
"wrapJavaObject",
270-
Type.getMethodDescriptor(Type.getType(PythonLikeObject.class), Type.getType(Object.class), Type.getType(
271-
Map.class)),
272-
false);
273-
interfaceMethodVisitor.visitInsn(Opcodes.AASTORE);
274-
}
275251

276252
var functionSignature = delegateType.getMethodType(interfaceMethod.getName())
277253
.orElseThrow(() -> new IllegalArgumentException(
278254
"Type %s cannot implement interface %s because it missing method %s."
279255
.formatted(delegateType, interfaceMethod.getDeclaringClass(), interfaceMethod)))
280256
.getDefaultFunctionSignature()
281257
.orElseThrow();
282-
283-
interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 2);
284-
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(List.class),
285-
"of", Type.getMethodDescriptor(Type.getType(List.class), Type.getType(Object[].class)),
286-
true);
287-
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Collections.class),
288-
"emptyMap", Type.getMethodDescriptor(Type.getType(Map.class)), false);
289-
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(ArgumentSpec.class),
290-
"extractArgumentList", Type.getMethodDescriptor(
291-
Type.getType(List.class), Type.getType(List.class), Type.getType(Map.class)),
258+
DelegatingInterfaceImplementor.prepareParametersForMethodCallFromArgumentSpec(
259+
interfaceMethod, interfaceMethodVisitor, functionSignature.getParameterTypes().length,
260+
Type.getType(functionSignature.getMethodDescriptor().getMethodDescriptor()),
292261
false);
293262

294-
for (int i = 0; i < functionSignature.getParameterTypes().length; i++) {
295-
interfaceMethodVisitor.visitInsn(Opcodes.DUP);
296-
interfaceMethodVisitor.visitLdcInsn(i);
297-
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(List.class),
298-
"get", Type.getMethodDescriptor(Type.getType(Object.class), Type.INT_TYPE), true);
299-
interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST,
300-
functionSignature.getParameterTypes()[i].getJavaTypeInternalName());
301-
interfaceMethodVisitor.visitInsn(Opcodes.SWAP);
302-
}
303-
interfaceMethodVisitor.visitInsn(Opcodes.POP);
304263
functionSignature.getMethodDescriptor().callMethod(interfaceMethodVisitor);
305264

306265
var returnType = interfaceMethod.getReturnType();
307266
if (returnType.equals(void.class)) {
308267
interfaceMethodVisitor.visitInsn(Opcodes.RETURN);
309268
} else {
310269
if (returnType.isPrimitive()) {
311-
loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor);
270+
DelegatingInterfaceImplementor.loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor);
312271
} else {
313272
interfaceMethodVisitor.visitLdcInsn(Type.getType(returnType));
314273
}
@@ -320,7 +279,7 @@ private static void createMethodDelegate(ClassWriter classWriter,
320279
PythonLikeObject.class)),
321280
false);
322281
if (returnType.isPrimitive()) {
323-
unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor);
282+
DelegatingInterfaceImplementor.unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor);
324283
interfaceMethodVisitor.visitInsn(Type.getType(returnType).getOpcode(Opcodes.IRETURN));
325284
} else {
326285
interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType));
@@ -330,94 +289,4 @@ private static void createMethodDelegate(ClassWriter classWriter,
330289
interfaceMethodVisitor.visitMaxs(interfaceMethod.getParameterCount() + 2, 1);
331290
interfaceMethodVisitor.visitEnd();
332291
}
333-
334-
private static void convertPrimitiveToObjectType(Class<?> primitiveType, MethodVisitor methodVisitor) {
335-
if (primitiveType.equals(boolean.class)) {
336-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Boolean.class),
337-
"valueOf", Type.getMethodDescriptor(Type.getType(Boolean.class), Type.BOOLEAN_TYPE), false);
338-
} else if (primitiveType.equals(byte.class)) {
339-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Byte.class),
340-
"valueOf", Type.getMethodDescriptor(Type.getType(Byte.class), Type.BYTE_TYPE), false);
341-
} else if (primitiveType.equals(char.class)) {
342-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Character.class),
343-
"valueOf", Type.getMethodDescriptor(Type.getType(Character.class), Type.CHAR_TYPE), false);
344-
} else if (primitiveType.equals(short.class)) {
345-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Short.class),
346-
"valueOf", Type.getMethodDescriptor(Type.getType(Short.class), Type.SHORT_TYPE), false);
347-
} else if (primitiveType.equals(int.class)) {
348-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Integer.class),
349-
"valueOf", Type.getMethodDescriptor(Type.getType(Integer.class), Type.INT_TYPE), false);
350-
} else if (primitiveType.equals(long.class)) {
351-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Long.class),
352-
"valueOf", Type.getMethodDescriptor(Type.getType(Long.class), Type.LONG_TYPE), false);
353-
} else if (primitiveType.equals(float.class)) {
354-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Float.class),
355-
"valueOf", Type.getMethodDescriptor(Type.getType(Float.class), Type.FLOAT_TYPE), false);
356-
} else if (primitiveType.equals(double.class)) {
357-
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Double.class),
358-
"valueOf", Type.getMethodDescriptor(Type.getType(Double.class), Type.DOUBLE_TYPE), false);
359-
} else {
360-
throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType));
361-
}
362-
}
363-
364-
private static void loadBoxedPrimitiveTypeClass(Class<?> primitiveType, MethodVisitor methodVisitor) {
365-
if (primitiveType.equals(boolean.class)) {
366-
methodVisitor.visitLdcInsn(Type.getType(Boolean.class));
367-
} else if (primitiveType.equals(byte.class)) {
368-
methodVisitor.visitLdcInsn(Type.getType(Byte.class));
369-
} else if (primitiveType.equals(char.class)) {
370-
methodVisitor.visitLdcInsn(Type.getType(Character.class));
371-
} else if (primitiveType.equals(short.class)) {
372-
methodVisitor.visitLdcInsn(Type.getType(Short.class));
373-
} else if (primitiveType.equals(int.class)) {
374-
methodVisitor.visitLdcInsn(Type.getType(Integer.class));
375-
} else if (primitiveType.equals(long.class)) {
376-
methodVisitor.visitLdcInsn(Type.getType(Long.class));
377-
} else if (primitiveType.equals(float.class)) {
378-
methodVisitor.visitLdcInsn(Type.getType(Float.class));
379-
} else if (primitiveType.equals(double.class)) {
380-
methodVisitor.visitLdcInsn(Type.getType(Double.class));
381-
} else {
382-
throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType));
383-
}
384-
}
385-
386-
private static void unboxBoxedPrimitiveType(Class<?> primitiveType, MethodVisitor methodVisitor) {
387-
if (primitiveType.equals(boolean.class)) {
388-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class));
389-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class),
390-
"booleanValue", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false);
391-
} else if (primitiveType.equals(byte.class)) {
392-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class));
393-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class),
394-
"byteValue", Type.getMethodDescriptor(Type.BYTE_TYPE), false);
395-
} else if (primitiveType.equals(char.class)) {
396-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class));
397-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class),
398-
"charValue", Type.getMethodDescriptor(Type.CHAR_TYPE), false);
399-
} else if (primitiveType.equals(short.class)) {
400-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class));
401-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class),
402-
"shortValue", Type.getMethodDescriptor(Type.SHORT_TYPE), false);
403-
} else if (primitiveType.equals(int.class)) {
404-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class));
405-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class),
406-
"intValue", Type.getMethodDescriptor(Type.INT_TYPE), false);
407-
} else if (primitiveType.equals(long.class)) {
408-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class));
409-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class),
410-
"longValue", Type.getMethodDescriptor(Type.LONG_TYPE), false);
411-
} else if (primitiveType.equals(float.class)) {
412-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class));
413-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class),
414-
"floatValue", Type.getMethodDescriptor(Type.FLOAT_TYPE), false);
415-
} else if (primitiveType.equals(double.class)) {
416-
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class));
417-
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class),
418-
"doubleValue", Type.getMethodDescriptor(Type.DOUBLE_TYPE), false);
419-
} else {
420-
throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType));
421-
}
422-
}
423292
}

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.stream.Collectors;
2020

2121
import ai.timefold.jpyinterpreter.dag.FlowGraph;
22+
import ai.timefold.jpyinterpreter.implementors.DelegatingInterfaceImplementor;
2223
import ai.timefold.jpyinterpreter.implementors.JavaComparableImplementor;
2324
import ai.timefold.jpyinterpreter.implementors.JavaEqualsImplementor;
2425
import ai.timefold.jpyinterpreter.implementors.JavaHashCodeImplementor;
@@ -94,6 +95,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
9495
var className = preparedClassInfo.className;
9596
var internalClassName = preparedClassInfo.classInternalName;
9697

98+
Map<String, InterfaceDeclaration> instanceMethodNameToMethodDescriptor = new HashMap<>();
9799
Set<PythonLikeType> superTypeSet;
98100
Set<JavaInterfaceImplementor> javaInterfaceImplementorSet = new HashSet<>();
99101

@@ -118,6 +120,11 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
118120
}
119121
}
120122

123+
for (Class<?> javaInterface : pythonCompiledClass.javaInterfaces) {
124+
javaInterfaceImplementorSet.add(
125+
new DelegatingInterfaceImplementor(internalClassName, javaInterface, instanceMethodNameToMethodDescriptor));
126+
}
127+
121128
if (pythonCompiledClass.superclassList.isEmpty()) {
122129
superTypeSet = Set.of(CPythonBackedPythonLikeObject.CPYTHON_BACKED_OBJECT_TYPE);
123130
} else {
@@ -159,7 +166,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
159166

160167
List<JavaInterfaceImplementor> nonObjectInterfaceImplementors = javaInterfaceImplementorSet.stream()
161168
.filter(implementor -> !Object.class.equals(implementor.getInterfaceClass()))
162-
.collect(Collectors.toList());
169+
.toList();
170+
163171
String[] interfaces = new String[nonObjectInterfaceImplementors.size()];
164172
for (int i = 0; i < nonObjectInterfaceImplementors.size(); i++) {
165173
interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass());
@@ -294,7 +302,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
294302
.entrySet()) {
295303
instanceMethodEntry.getValue().methodKind = PythonMethodKind.VIRTUAL_METHOD;
296304
createInstanceMethod(pythonLikeType, classWriter, internalClassName, instanceMethodEntry.getKey(),
297-
instanceMethodEntry.getValue());
305+
instanceMethodEntry.getValue(), instanceMethodNameToMethodDescriptor);
298306
}
299307

300308
for (Map.Entry<String, PythonCompiledFunction> staticMethodEntry : pythonCompiledClass.staticFunctionNameToPythonBytecode
@@ -854,13 +862,15 @@ private static void addAnnotationsToMethod(PythonCompiledFunction function, Meth
854862
}
855863

856864
private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWriter classWriter, String internalClassName,
857-
String methodName, PythonCompiledFunction function) {
865+
String methodName, PythonCompiledFunction function,
866+
Map<String, InterfaceDeclaration> instanceMethodNameToMethodDescriptor) {
858867
InterfaceDeclaration interfaceDeclaration = getInterfaceForInstancePythonFunction(internalClassName, function);
859-
String interfaceDescriptor = 'L' + interfaceDeclaration.interfaceName + ';';
868+
String interfaceDescriptor = interfaceDeclaration.descriptor();
860869
String javaMethodName = getJavaMethodName(methodName);
861870

862871
classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor,
863872
null, null);
873+
instanceMethodNameToMethodDescriptor.put(methodName, interfaceDeclaration);
864874
Type returnType = getVirtualFunctionReturnType(function);
865875

866876
List<PythonLikeType> parameterPythonTypeList = function.getParameterTypes();
@@ -1555,30 +1565,13 @@ public static PythonLikeType getPythonReturnTypeOfFunction(PythonCompiledFunctio
15551565
}
15561566
}
15571567

1558-
public static class InterfaceDeclaration {
1559-
final String interfaceName;
1560-
final String methodDescriptor;
1561-
1562-
public InterfaceDeclaration(String interfaceName, String methodDescriptor) {
1563-
this.interfaceName = interfaceName;
1564-
this.methodDescriptor = methodDescriptor;
1565-
}
1566-
1567-
@Override
1568-
public boolean equals(Object o) {
1569-
if (this == o) {
1570-
return true;
1571-
}
1572-
if (o == null || getClass() != o.getClass()) {
1573-
return false;
1574-
}
1575-
InterfaceDeclaration that = (InterfaceDeclaration) o;
1576-
return interfaceName.equals(that.interfaceName) && methodDescriptor.equals(that.methodDescriptor);
1568+
public record InterfaceDeclaration(String interfaceName, String methodDescriptor) {
1569+
public String descriptor() {
1570+
return "L" + interfaceName + ";";
15771571
}
15781572

1579-
@Override
1580-
public int hashCode() {
1581-
return Objects.hash(interfaceName, methodDescriptor);
1573+
public Type methodType() {
1574+
return Type.getMethodType(methodDescriptor);
15821575
}
15831576
}
15841577

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ public class PythonCompiledClass {
3636
*/
3737
public Map<String, TypeHint> typeAnnotations;
3838

39+
/**
40+
* Java interfaces the class implement
41+
*/
42+
public List<Class<?>> javaInterfaces;
43+
3944
/**
4045
* The binary type of this PythonCompiledClass;
4146
* typically {@link CPythonType}. Used when methods

0 commit comments

Comments
 (0)