Skip to content

Commit a573b12

Browse files
fix: Load globals that are subclasses or instances of a class being compiled lazily (#67)
Previously, the entire global dict was loaded whenever a function was compiled. However, this causes issues if there is a global that is an instance or subclass of a class being compiled, and the class being compiled reference that global. Now, when a class is being compiled, instances and subclasses of that class is excluded and loaded at runtime.
1 parent dc5fc78 commit a573b12

File tree

5 files changed

+75
-8
lines changed

5 files changed

+75
-8
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/CPythonBackedPythonInterpreter.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public class CPythonBackedPythonInterpreter implements PythonInterpreter {
4646
public static BiFunction<OpaquePythonReference, String, PythonLikeObject> lookupAttributeOnPythonReferencePythonFunction;
4747
public static BiFunction<OpaquePythonReference, String, OpaquePythonReference> lookupPointerForAttributeOnPythonReferencePythonFunction;
4848
public static BiFunction<OpaquePythonReference, String, OpaquePythonReference[]> lookupPointerArrayForAttributeOnPythonReferencePythonFunction;
49+
public static BiConsumer<Map<String, PythonLikeObject>, String> loadObjectFromPythonGlobalDict;
4950

5051
public static TriFunction<OpaquePythonReference, String, Map<Number, PythonLikeObject>, PythonLikeObject> lookupAttributeOnPythonReferenceWithMapPythonFunction;
5152
public static QuadConsumer<OpaquePythonReference, OpaquePythonReference, String, Object> setAttributeOnPythonReferencePythonFunction;
@@ -154,6 +155,11 @@ public void setPythonReference(PythonLikeObject instance, OpaquePythonReference
154155

155156
@Override
156157
public PythonLikeObject getGlobal(Map<String, PythonLikeObject> globalsMap, String name) {
158+
if (!globalsMap.containsKey(name)) {
159+
// This will put 'null' in the map if it doesn't exist, so we don't
160+
// do an expensive CPython lookup everytime we are getting an attribute
161+
loadObjectFromPythonGlobalDict.accept(globalsMap, name);
162+
}
157163
PythonLikeObject out = globalsMap.get(name);
158164
if (out == null) {
159165
return GlobalBuiltins.lookupOrError(this, name);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package ai.timefold.jpyinterpreter.util;
2+
3+
import java.util.HashMap;
4+
5+
import ai.timefold.jpyinterpreter.PythonLikeObject;
6+
7+
public class PythonGlobalsBackedMap extends HashMap<String, PythonLikeObject> {
8+
private final long pythonGlobalsId;
9+
10+
public PythonGlobalsBackedMap(long pythonGlobalsId) {
11+
this.pythonGlobalsId = pythonGlobalsId;
12+
}
13+
14+
public long getPythonGlobalsId() {
15+
return pythonGlobalsId;
16+
}
17+
}

jpyinterpreter/src/main/python/jvm_setup.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def init(*args, path: List[str] = None, include_translator_jars: bool = True,
4141
CPythonBackedPythonInterpreter.lookupPythonReferenceIdPythonFunction = GetPythonObjectId()
4242
CPythonBackedPythonInterpreter.lookupPythonReferenceTypePythonFunction = GetPythonObjectType()
4343
CPythonBackedPythonInterpreter.lookupAttributeOnPythonReferencePythonFunction = GetAttributeOnPythonObject()
44+
CPythonBackedPythonInterpreter.loadObjectFromPythonGlobalDict = GetNameFromGlobals()
4445
CPythonBackedPythonInterpreter.lookupPointerForAttributeOnPythonReferencePythonFunction = \
4546
GetAttributePointerOnPythonObject()
4647
CPythonBackedPythonInterpreter.lookupPointerArrayForAttributeOnPythonReferencePythonFunction = \
@@ -55,6 +56,25 @@ def init(*args, path: List[str] = None, include_translator_jars: bool = True,
5556
CPythonBackedPythonInterpreter.importModuleFunction = ImportModule()
5657

5758

59+
@jpype.JImplements('java.util.function.BiConsumer', deferred=True)
60+
class GetNameFromGlobals:
61+
@jpype.JOverride()
62+
def accept(self, java_globals, name):
63+
from .translator import java_globals_to_python_globals
64+
from .conversions import convert_to_java_python_like_object
65+
from ai.timefold.jpyinterpreter.util import PythonGlobalsBackedMap
66+
67+
if not isinstance(java_globals, PythonGlobalsBackedMap):
68+
return
69+
70+
python_globals = java_globals_to_python_globals[java_globals.getPythonGlobalsId()]
71+
try:
72+
python_object = python_globals[name]
73+
java_globals.put(name, convert_to_java_python_like_object(python_object))
74+
except KeyError:
75+
java_globals.put(name, None)
76+
77+
5878
@jpype.JImplements('java.util.function.Function', deferred=True)
5979
class GetPythonObjectId:
6080
@jpype.JOverride()

jpyinterpreter/src/main/python/translator.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
global_dict_to_instance = dict()
1515
global_dict_to_key_set = dict()
16+
java_globals_to_python_globals = dict()
17+
1618
type_to_compiled_java_class = dict()
1719
type_to_annotations = dict()
1820
type_to_java_interfaces = dict()
@@ -136,26 +138,33 @@ def copy_closure(closure):
136138
return out
137139

138140

139-
def copy_globals(globals_dict, co_names):
141+
def copy_globals(globals_dict, co_names, python_class):
140142
global global_dict_to_instance
141143
global global_dict_to_key_set
142144
from .conversions import convert_to_java_python_like_object
143-
from java.util import HashMap
145+
from ai.timefold.jpyinterpreter.util import PythonGlobalsBackedMap
144146
from ai.timefold.jpyinterpreter import CPythonBackedPythonInterpreter
145147

146148
globals_dict_key = id(globals_dict)
147149
if globals_dict_key in global_dict_to_instance:
148150
out = global_dict_to_instance[globals_dict_key]
149151
key_set = global_dict_to_key_set[globals_dict_key]
150152
else:
151-
out = HashMap()
153+
out = PythonGlobalsBackedMap(globals_dict_key)
152154
key_set = set()
153155
global_dict_to_instance[globals_dict_key] = out
154156
global_dict_to_key_set[globals_dict_key] = key_set
157+
java_globals_to_python_globals[globals_dict_key] = globals_dict
155158

156159
instance_map = CPythonBackedPythonInterpreter.pythonObjectIdToConvertedObjectMap
157160
for key, value in globals_dict.items():
158161
if key not in key_set and key in co_names:
162+
if python_class is not None:
163+
if isinstance(value, type):
164+
if issubclass(value, python_class):
165+
continue
166+
elif isinstance(value, python_class):
167+
continue
159168
key_set.add(key)
160169
out.put(key, convert_to_java_python_like_object(value, instance_map))
161170
return out
@@ -216,7 +225,7 @@ def get_python_exception_table(python_code):
216225
return out
217226

218227

219-
def get_function_bytecode_object(python_function):
228+
def get_function_bytecode_object(python_function, python_class: type = None):
220229
from .annotations import copy_type_annotations
221230
from .conversions import copy_iterable, init_type_to_compiled_java_class, convert_to_java_python_like_object
222231
from java.util import ArrayList
@@ -251,7 +260,8 @@ def get_function_bytecode_object(python_function):
251260
python_compiled_function.co_argcount = python_function.__code__.co_argcount
252261
python_compiled_function.co_kwonlyargcount = python_function.__code__.co_kwonlyargcount
253262
python_compiled_function.closure = copy_closure(python_function.__closure__)
254-
python_compiled_function.globalsMap = copy_globals(python_function.__globals__, python_function.__code__.co_names)
263+
python_compiled_function.globalsMap = copy_globals(python_function.__globals__, python_function.__code__.co_names,
264+
python_class)
255265
python_compiled_function.typeAnnotations = copy_type_annotations(python_function,
256266
get_default_args(python_function),
257267
inspect.getfullargspec(python_function).varargs,
@@ -267,7 +277,7 @@ def get_function_bytecode_object(python_function):
267277

268278

269279
def get_static_function_bytecode_object(the_class, python_function):
270-
return get_function_bytecode_object(python_function.__get__(the_class))
280+
return get_function_bytecode_object(python_function.__get__(the_class), python_class=the_class)
271281

272282

273283
def copy_variable_names(iterable):
@@ -673,7 +683,7 @@ def translate_python_class_to_java_class(python_class):
673683

674684
instance_method_map = HashMap()
675685
for method in instance_methods:
676-
instance_method_map.put(method[0], get_function_bytecode_object(method[1]))
686+
instance_method_map.put(method[0], get_function_bytecode_object(method[1], python_class=python_class))
677687

678688
static_attributes_map = HashMap()
679689
static_attributes_to_class_instance_map = HashMap()
@@ -683,7 +693,7 @@ def translate_python_class_to_java_class(python_class):
683693

684694
for attribute in static_attributes:
685695
attribute_type = type(attribute[1])
686-
if attribute_type == python_class:
696+
if issubclass(attribute_type, python_class):
687697
static_attributes_to_class_instance_map.put(attribute[0],
688698
JProxy(OpaquePythonReference,
689699
inst=attribute[1], convert=True))

jpyinterpreter/tests/test_classes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,20 @@ def my_function_instance(a, b, c):
896896
verifier.verify(1, 1, 1, expected_result=13)
897897

898898

899+
def test_enum_translate_to_class():
900+
from enum import Enum
901+
from jpyinterpreter import translate_python_class_to_java_class
902+
from ai.timefold.jpyinterpreter.types.wrappers import CPythonType
903+
904+
class Color(Enum):
905+
RED = 'RED'
906+
GREEN = 'GREEN'
907+
BLUE = 'BLUE'
908+
909+
translated_class = translate_python_class_to_java_class(Color)
910+
assert not isinstance(translated_class, CPythonType)
911+
912+
899913
def test_class_annotations():
900914
from typing import Annotated
901915
from java.lang import Deprecated

0 commit comments

Comments
 (0)