Skip to content

Commit 5e8c0bd

Browse files
chore: Implement IncrementalScoreCalculator using classes instead of decorators (#41)
- Since there can only be one function signature in Python, and Java allows many, it might be the case that the top function signature in Python does not match its parent's function signature. Since the interface calls the parent's function signature, the wrong method would be called. To prevent this, we need to look up the 'canonical' method of the type, which is conveniently stored as an attribute on the type. - Fix a bug in function __get__ descriptor; in particular, when called on a type, it should return the unbounded function instead of binding the function to the type. - Make the ABC check less strict. In particular, only collections.abc and Protocol are banned, since collections.abc contain classes that should be Protocols but are instead ABC, and Protocols only define the structure and do not play a part in type hierarchy.
1 parent aa61c30 commit 5e8c0bd

File tree

6 files changed

+239
-65
lines changed

6 files changed

+239
-65
lines changed

jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/builtins/FunctionBuiltinOperations.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
import ai.timefold.jpyinterpreter.types.BoundPythonLikeFunction;
55
import ai.timefold.jpyinterpreter.types.PythonLikeFunction;
66
import ai.timefold.jpyinterpreter.types.PythonLikeType;
7+
import ai.timefold.jpyinterpreter.types.PythonNone;
78

89
public class FunctionBuiltinOperations {
910
public static PythonLikeObject bindFunctionToInstance(final PythonLikeFunction function, final PythonLikeObject instance,
1011
final PythonLikeType type) {
12+
if (instance == PythonNone.INSTANCE) {
13+
return function;
14+
}
1115
return new BoundPythonLikeFunction(instance, function);
1216
}
1317

jpyinterpreter/src/main/python/translator.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import inspect
44
import sys
55
import abc
6+
from typing import Protocol
7+
68
from jpype import JInt, JBoolean, JProxy, JClass, JArray
79

810

@@ -505,6 +507,7 @@ def force_update_type(python_type, java_type):
505507

506508

507509
def translate_python_class_to_java_class(python_class):
510+
import collections.abc as collections_abc
508511
from .annotations import erase_generic_args, convert_java_annotation, copy_type_annotations
509512
from .conversions import (
510513
init_type_to_compiled_java_class, is_banned_module, is_c_native, convert_to_java_python_like_object
@@ -523,16 +526,21 @@ def translate_python_class_to_java_class(python_class):
523526
if raw_type in type_to_compiled_java_class:
524527
return type_to_compiled_java_class[raw_type]
525528

526-
if python_class == abc.ABC or inspect.isabstract(python_class): # TODO: Implement a class for interfaces?
529+
if Protocol in python_class.__bases__:
527530
python_class_java_type = BuiltinTypes.BASE_TYPE
528531
type_to_compiled_java_class[python_class] = python_class_java_type
529532
return python_class_java_type
530533

531-
if hasattr(python_class, '__module__') and python_class.__module__ is not None and \
532-
is_banned_module(python_class.__module__):
533-
python_class_java_type = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True))
534-
type_to_compiled_java_class[python_class] = python_class_java_type
535-
return python_class_java_type
534+
if hasattr(python_class, '__module__') and python_class.__module__ is not None:
535+
if python_class.__module__ == collections_abc.Collection.__module__:
536+
python_class_java_type = BuiltinTypes.BASE_TYPE
537+
type_to_compiled_java_class[python_class] = python_class_java_type
538+
return python_class_java_type
539+
540+
if is_banned_module(python_class.__module__):
541+
python_class_java_type = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True))
542+
type_to_compiled_java_class[python_class] = python_class_java_type
543+
return python_class_java_type
536544

537545
if isinstance(python_class, JArray):
538546
python_class_java_type = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True))

tests/test_incremental_score_calculator.py

Lines changed: 48 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@ class Queen:
1616
column: int
1717
row: Annotated[Optional[int], PlanningVariable] = field(default=None)
1818

19-
def getColumnIndex(self):
19+
def get_column_index(self):
2020
return self.column
2121

22-
def getRowIndex(self):
22+
def get_row_index(self):
2323
if self.row is None:
2424
return -1
2525
return self.row
2626

27-
def getAscendingDiagonalIndex(self):
28-
return self.getColumnIndex() + self.getRowIndex()
27+
def get_ascending_diagonal_index(self):
28+
return self.get_column_index() + self.get_row_index()
2929

30-
def getDescendingDiagonalIndex(self):
31-
return self.getColumnIndex() - self.getRowIndex()
30+
def get_descending_diagonal_index(self):
31+
return self.get_column_index() - self.get_row_index()
3232

3333
def __eq__(self, other):
3434
return self.code == other.code
@@ -48,14 +48,13 @@ class Solution:
4848

4949

5050
def test_constraint_match_disabled_incremental_score_calculator():
51-
@incremental_score_calculator
52-
class IncrementalScoreCalculator:
51+
class NQueensIncrementalScoreCalculator(IncrementalScoreCalculator):
5352
score: int
5453
row_index_map: dict
5554
ascending_diagonal_index_map: dict
5655
descending_diagonal_index_map: dict
5756

58-
def resetWorkingSolution(self, working_solution: Solution):
57+
def reset_working_solution(self, working_solution: Solution):
5958
n = working_solution.n
6059
self.row_index_map = dict()
6160
self.ascending_diagonal_index_map = dict()
@@ -71,22 +70,22 @@ def resetWorkingSolution(self, working_solution: Solution):
7170
for queen in working_solution.queen_list:
7271
self.insert(queen)
7372

74-
def beforeEntityAdded(self, entity: any):
73+
def before_entity_added(self, entity: any):
7574
pass
7675

77-
def afterEntityAdded(self, entity: any):
76+
def after_entity_added(self, entity: any):
7877
self.insert(entity)
7978

80-
def beforeVariableChanged(self, entity: any, variableName: str):
79+
def before_variable_changed(self, entity: any, variableName: str):
8180
self.retract(entity)
8281

83-
def afterVariableChanged(self, entity: any, variableName: str):
82+
def after_variable_changed(self, entity: any, variableName: str):
8483
self.insert(entity)
8584

86-
def beforeEntityRemoved(self, entity: any):
85+
def before_entity_removed(self, entity: any):
8786
self.retract(entity)
8887

89-
def afterEntityRemoved(self, entity: any):
88+
def after_entity_removed(self, entity: any):
9089
pass
9190

9291
def insert(self, queen: Queen):
@@ -95,10 +94,10 @@ def insert(self, queen: Queen):
9594
row_index_list = self.row_index_map[row_index]
9695
self.score -= len(row_index_list)
9796
row_index_list.append(queen)
98-
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
97+
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
9998
self.score -= len(ascending_diagonal_index_list)
10099
ascending_diagonal_index_list.append(queen)
101-
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
100+
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
102101
self.score -= len(descending_diagonal_index_list)
103102
descending_diagonal_index_list.append(queen)
104103

@@ -108,21 +107,21 @@ def retract(self, queen: Queen):
108107
row_index_list = self.row_index_map[row_index]
109108
row_index_list.remove(queen)
110109
self.score += len(row_index_list)
111-
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
110+
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
112111
ascending_diagonal_index_list.remove(queen)
113112
self.score += len(ascending_diagonal_index_list)
114-
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
113+
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
115114
descending_diagonal_index_list.remove(queen)
116115
self.score += len(descending_diagonal_index_list)
117116

118-
def calculateScore(self) -> HardSoftScore:
117+
def calculate_score(self) -> HardSoftScore:
119118
return SimpleScore.of(self.score)
120119

121120
solver_config = SolverConfig(
122121
solution_class=Solution,
123122
entity_class_list=[Queen],
124123
score_director_factory_config=ScoreDirectorFactoryConfig(
125-
incremental_score_calculator_class=IncrementalScoreCalculator
124+
incremental_score_calculator_class=NQueensIncrementalScoreCalculator
126125
),
127126
termination_config=TerminationConfig(
128127
best_score_limit='0'
@@ -141,22 +140,22 @@ def calculateScore(self) -> HardSoftScore:
141140
right_queen = solution.queen_list[j]
142141
assert left_queen.row is not None and right_queen.row is not None
143142
assert left_queen.row != right_queen.row
144-
assert left_queen.getAscendingDiagonalIndex() != right_queen.getAscendingDiagonalIndex()
145-
assert left_queen.getDescendingDiagonalIndex() != right_queen.getDescendingDiagonalIndex()
143+
assert left_queen.get_ascending_diagonal_index() != right_queen.get_ascending_diagonal_index()
144+
assert left_queen.get_descending_diagonal_index() != right_queen.get_descending_diagonal_index()
146145

147146

148147
@pytest.mark.skip(reason="Special case where you want to convert all items of the list before returning."
149148
"Doing this for all conversions would be expensive."
150149
"This feature is not that important, so skipping for now.")
151150
def test_constraint_match_enabled_incremental_score_calculator():
152151
@incremental_score_calculator
153-
class IncrementalScoreCalculator:
152+
class NQueensIncrementalScoreCalculator(ConstraintMatchAwareIncrementalScoreCalculator):
154153
score: int
155154
row_index_map: dict
156155
ascending_diagonal_index_map: dict
157156
descending_diagonal_index_map: dict
158157

159-
def resetWorkingSolution(self, working_solution: Solution, constraint_match_enabled=False):
158+
def reset_working_solution(self, working_solution: Solution, constraint_match_enabled=False):
160159
n = working_solution.n
161160
self.row_index_map = dict()
162161
self.ascending_diagonal_index_map = dict()
@@ -172,22 +171,22 @@ def resetWorkingSolution(self, working_solution: Solution, constraint_match_enab
172171
for queen in working_solution.queen_list:
173172
self.insert(queen)
174173

175-
def beforeEntityAdded(self, entity: any):
174+
def before_entity_added(self, entity: any):
176175
pass
177176

178-
def afterEntityAdded(self, entity: any):
177+
def after_entity_added(self, entity: any):
179178
self.insert(entity)
180179

181-
def beforeVariableChanged(self, entity: any, variableName: str):
180+
def before_variable_changed(self, entity: any, variableName: str):
182181
self.retract(entity)
183182

184-
def afterVariableChanged(self, entity: any, variableName: str):
183+
def after_variable_changed(self, entity: any, variableName: str):
185184
self.insert(entity)
186185

187-
def beforeEntityRemoved(self, entity: any):
186+
def before_entity_removed(self, entity: any):
188187
self.retract(entity)
189188

190-
def afterEntityRemoved(self, entity: any):
189+
def after_entity_removed(self, entity: any):
191190
pass
192191

193192
def insert(self, queen: Queen):
@@ -197,10 +196,10 @@ def insert(self, queen: Queen):
197196
row_index_list = self.row_index_map[row_index]
198197
self.score -= len(row_index_list)
199198
row_index_list.append(queen)
200-
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
199+
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
201200
self.score -= len(ascending_diagonal_index_list)
202201
ascending_diagonal_index_list.append(queen)
203-
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
202+
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
204203
self.score -= len(descending_diagonal_index_list)
205204
descending_diagonal_index_list.append(queen)
206205

@@ -211,17 +210,17 @@ def retract(self, queen: Queen):
211210
row_index_list = self.row_index_map[row_index]
212211
row_index_list.remove(queen)
213212
self.score += len(row_index_list)
214-
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
213+
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
215214
ascending_diagonal_index_list.remove(queen)
216215
self.score += len(ascending_diagonal_index_list)
217-
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
216+
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
218217
descending_diagonal_index_list.remove(queen)
219218
self.score += len(descending_diagonal_index_list)
220219

221-
def calculateScore(self) -> HardSoftScore:
220+
def calculate_score(self) -> HardSoftScore:
222221
return SimpleScore.of(self.score)
223222

224-
def getConstraintMatchTotals(self):
223+
def get_constraint_match_totals(self):
225224
row_conflict_constraint_match_total = DefaultConstraintMatchTotal(
226225
'NQueens',
227226
'Row Conflict',
@@ -255,14 +254,14 @@ def getConstraintMatchTotals(self):
255254
descending_diagonal_constraint_match_total
256255
]
257256

258-
def getIndictmentMap(self):
257+
def get_indictment_map(self):
259258
return None
260259

261260
solver_config = SolverConfig(
262261
solution_class=Solution,
263262
entity_class_list=[Queen],
264263
score_director_factory_config=ScoreDirectorFactoryConfig(
265-
incremental_score_calculator_class=IncrementalScoreCalculator
264+
incremental_score_calculator_class=NQueensIncrementalScoreCalculator
266265
),
267266
termination_config=TerminationConfig(
268267
best_score_limit='0'
@@ -282,8 +281,8 @@ def getIndictmentMap(self):
282281
right_queen = solution.queen_list[j]
283282
assert left_queen.row is not None and right_queen.row is not None
284283
assert left_queen.row != right_queen.row
285-
assert left_queen.getAscendingDiagonalIndex() != right_queen.getAscendingDiagonalIndex()
286-
assert left_queen.getDescendingDiagonalIndex() != right_queen.getDescendingDiagonalIndex()
284+
assert left_queen.get_ascending_diagonal_index() != right_queen.get_ascending_diagonal_index()
285+
assert left_queen.get_descending_diagonal_index() != right_queen.get_descending_diagonal_index()
287286

288287
score_manager = SolutionManager.create(solver_factory)
289288
constraint_match_total_map = score_manager.explain(solution).constraint_match_total_map
@@ -315,21 +314,19 @@ def getIndictmentMap(self):
315314

316315

317316
def test_error_message_for_missing_methods():
318-
with pytest.raises(ValueError, match=(
319-
f"The following required methods are missing from @incremental_score_calculator class "
320-
f".*IncrementalScoreCalculatorMissingMethods.*: "
321-
f"\\['resetWorkingSolution', 'beforeEntityRemoved', 'afterEntityRemoved', 'calculateScore'\\]"
322-
)):
317+
with pytest.raises(TypeError): # Exact error message from ABC changes between versions
323318
@incremental_score_calculator
324-
class IncrementalScoreCalculatorMissingMethods:
325-
def beforeEntityAdded(self, entity: any):
319+
class IncrementalScoreCalculatorMissingMethods(IncrementalScoreCalculator):
320+
def before_entity_added(self, entity):
326321
pass
327322

328-
def afterEntityAdded(self, entity: any):
323+
def after_entity_added(self, entity):
329324
pass
330325

331-
def beforeVariableChanged(self, entity: any, variableName: str):
326+
def before_variable_changed(self, entity, variable_name: str):
332327
pass
333328

334-
def afterVariableChanged(self, entity: any, variableName: str):
329+
def after_variable_changed(self, entity, variable_name: str):
335330
pass
331+
332+
score_calculator = IncrementalScoreCalculatorMissingMethods()

timefold-solver-python-core/src/main/python/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from ._solution_manager import *
66
from ._score_director import *
77
from ._variable_listener import *
8+
from ._incremental_score_calculator import *

0 commit comments

Comments
 (0)