Skip to content

Commit aa61c30

Browse files
chore: Implement VariableListener using base classes instead of decorators (#40)
- Made all methods as optional, since 99% of the time, they do nothing
1 parent 45132a1 commit aa61c30

File tree

7 files changed

+208
-146
lines changed

7 files changed

+208
-146
lines changed

tests/test_constraint_streams.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,9 @@ def define_constraints(constraint_factory: ConstraintFactory):
589589
'rewardConfigurableLong',
590590
'rewardLong',
591591
'_handler', # JPype handler field should be ignored
592+
# Unimplemented
593+
'toConnectedRanges',
594+
'toConnectedTemporalRanges',
592595
# These methods are deprecated
593596
'from_',
594597
'fromUnfiltered',

tests/test_custom_shadow_variables.py

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,14 @@
99

1010

1111
def test_custom_shadow_variable():
12-
@variable_listener
13-
class MyVariableListener:
14-
def afterVariableChanged(self, score_director, entity):
15-
score_director.beforeVariableChanged(entity, 'value_squared')
12+
class MyVariableListener(VariableListener):
13+
def after_variable_changed(self, score_director: ScoreDirector, entity):
14+
score_director.before_variable_changed(entity, 'value_squared')
1615
if entity.value is None:
1716
entity.value_squared = None
1817
else:
1918
entity.value_squared = entity.value ** 2
20-
score_director.afterVariableChanged(entity, 'value_squared')
21-
22-
def beforeVariableChanged(self, score_director, entity):
23-
pass
24-
25-
def beforeEntityAdded(self, score_director, entity):
26-
pass
27-
28-
def afterEntityAdded(self, score_director, entity):
29-
pass
30-
31-
def beforeEntityRemoved(self, score_director, entity):
32-
pass
33-
34-
def afterEntityRemoved(self, score_director, entity):
35-
pass
19+
score_director.after_variable_changed(entity, 'value_squared')
3620

3721
@planning_entity
3822
@dataclass
@@ -79,34 +63,18 @@ class MySolution:
7963

8064

8165
def test_custom_shadow_variable_with_variable_listener_ref():
82-
@variable_listener
83-
class MyVariableListener:
84-
def afterVariableChanged(self, score_director, entity):
85-
score_director.beforeVariableChanged(entity, 'twice_value')
86-
score_director.beforeVariableChanged(entity, 'value_squared')
66+
class MyVariableListener(VariableListener):
67+
def after_variable_changed(self, score_director: ScoreDirector, entity):
68+
score_director.before_variable_changed(entity, 'twice_value')
69+
score_director.before_variable_changed(entity, 'value_squared')
8770
if entity.value is None:
8871
entity.twice_value = None
8972
entity.value_squared = None
9073
else:
9174
entity.twice_value = 2 * entity.value
9275
entity.value_squared = entity.value ** 2
93-
score_director.afterVariableChanged(entity, 'value_squared')
94-
score_director.afterVariableChanged(entity, 'twice_value')
95-
96-
def beforeVariableChanged(self, score_director, entity):
97-
pass
98-
99-
def beforeEntityAdded(self, score_director, entity):
100-
pass
101-
102-
def afterEntityAdded(self, score_director, entity):
103-
pass
104-
105-
def beforeEntityRemoved(self, score_director, entity):
106-
pass
107-
108-
def afterEntityRemoved(self, score_director, entity):
109-
pass
76+
score_director.after_variable_changed(entity, 'value_squared')
77+
score_director.after_variable_changed(entity, 'twice_value')
11078

11179
@planning_entity
11280
@dataclass
@@ -115,9 +83,8 @@ class MyPlanningEntity:
11583
field(default=None)
11684
value_squared: Annotated[Optional[int], ShadowVariable(
11785
variable_listener_class=MyVariableListener, source_variable_name='value')] = field(default=None)
118-
# TODO: Use PiggyBackShadowVariable
119-
twice_value: Annotated[Optional[int], ShadowVariable(
120-
variable_listener_class=MyVariableListener, source_variable_name='value')] = field(default=None)
86+
twice_value: Annotated[Optional[int], PiggybackShadowVariable(shadow_variable_name='value_squared')] = (
87+
field(default=None))
12188

12289
@constraint_provider
12390
def my_constraints(constraint_factory: ConstraintFactory):

timefold-solver-python-core/src/main/python/_timefold_java_interop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def get_class(python_class: Union[Type, Callable]) -> JClass:
116116
from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference
117117
from jpyinterpreter import is_c_native, get_java_type_for_python_type
118118

119+
if python_class is None:
120+
return cast(JClass, None)
119121
if isinstance(python_class, jpype.JClass):
120122
return cast(JClass, python_class).class_
121123
if isinstance(python_class, Class):

timefold-solver-python-core/src/main/python/annotation/_annotations.py

Lines changed: 26 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import jpype
22

3+
from ..api import VariableListener
34
from ..constraint import ConstraintFactory
45
from .._timefold_java_interop import ensure_init, _generate_constraint_provider_class, register_java_class
56
from jpyinterpreter import JavaAnnotation
@@ -11,8 +12,7 @@
1112
from ai.timefold.solver.core.api.score.stream import Constraint as _Constraint
1213
from ai.timefold.solver.core.api.score import Score as _Score
1314
from ai.timefold.solver.core.api.score.calculator import IncrementalScoreCalculator as _IncrementalScoreCalculator
14-
from ai.timefold.solver.core.api.domain.variable import PlanningVariableGraphType as _PlanningVariableGraphType, \
15-
VariableListener as _VariableListener
15+
from ai.timefold.solver.core.api.domain.variable import PlanningVariableGraphType as _PlanningVariableGraphType
1616

1717

1818
Solution_ = TypeVar('Solution_')
@@ -35,7 +35,7 @@ class PlanningVariable(JavaAnnotation):
3535
def __init__(self, *,
3636
value_range_provider_refs: List[str] = None,
3737
allows_unassigned: bool = False,
38-
graph_type: '_PlanningVariableGraphType' = None):
38+
graph_type=None):
3939
ensure_init()
4040
from ai.timefold.solver.core.api.domain.variable import PlanningVariable as JavaPlanningVariable
4141
super().__init__(JavaPlanningVariable,
@@ -75,19 +75,37 @@ def __init__(self, *,
7575

7676
class ShadowVariable(JavaAnnotation):
7777
def __init__(self, *,
78-
variable_listener_class: Type['_VariableListener'] = None,
78+
variable_listener_class: Type[VariableListener] = None,
7979
source_variable_name: str,
8080
source_entity_class: Type = None):
8181
ensure_init()
8282
from .._timefold_java_interop import get_class
83+
from jpyinterpreter import get_java_type_for_python_type
8384
from ai.timefold.jpyinterpreter import PythonClassTranslator
8485
from ai.timefold.solver.core.api.domain.variable import (
85-
ShadowVariable as JavaShadowVariable)
86+
ShadowVariable as JavaShadowVariable, VariableListener as JavaVariableListener)
87+
8688
super().__init__(JavaShadowVariable,
8789
{
8890
'variableListenerClass': get_class(variable_listener_class),
8991
'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name),
90-
'sourceEntityClass': source_entity_class,
92+
'sourceEntityClass': get_class(source_entity_class),
93+
})
94+
95+
96+
class PiggybackShadowVariable(JavaAnnotation):
97+
def __init__(self, *,
98+
shadow_variable_name: str,
99+
shadow_entity_class: Type = None):
100+
ensure_init()
101+
from .._timefold_java_interop import get_class
102+
from ai.timefold.jpyinterpreter import PythonClassTranslator
103+
from ai.timefold.solver.core.api.domain.variable import (
104+
PiggybackShadowVariable as JavaPiggybackShadowVariable)
105+
super().__init__(JavaPiggybackShadowVariable,
106+
{
107+
'shadowVariableName': PythonClassTranslator.getJavaFieldName(shadow_variable_name),
108+
'shadowEntityClass': get_class(shadow_entity_class),
91109
})
92110

93111

@@ -455,100 +473,6 @@ def resetWorkingSolution(self, workingSolution: Solution_, constraintMatchEnable
455473
return register_java_class(incremental_score_calculator, java_class)
456474

457475

458-
def variable_listener(variable_listener_class: Type['_VariableListener'] = None, /, *,
459-
require_unique_entity_events: bool = False) -> Type['_VariableListener']:
460-
"""Changes shadow variables when a genuine planning variable changes.
461-
Important: it must only change the shadow variable(s) for which it's configured!
462-
It should never change a genuine variable or a problem fact.
463-
It can change its shadow variable(s) on multiple entity instances
464-
(for example: an arrival_time change affects all trailing entities too).
465-
466-
It is recommended that implementations be kept stateless.
467-
If state must be implemented, implementations may need to override the default methods
468-
resetWorkingSolution(score_director: ScoreDirector) and close().
469-
470-
The following methods must exist:
471-
472-
def beforeEntityAdded(score_director: ScoreDirector[Solution_], entity: Entity_);
473-
474-
def afterEntityAdded(score_director: ScoreDirector[Solution_], entity: Entity_);
475-
476-
def beforeEntityRemoved(score_director: ScoreDirector[Solution_], entity: Entity_);
477-
478-
def afterEntityRemoved(score_director: ScoreDirector[Solution_], entity: Entity_);
479-
480-
def beforeVariableChanged(score_director: ScoreDirector[Solution_], entity: Entity_);
481-
482-
def afterVariableChanged(score_director: ScoreDirector[Solution_], entity: Entity_);
483-
484-
If the implementation is stateful, then the following methods should also be defined:
485-
486-
def resetWorkingSolution(score_director: ScoreDirector)
487-
488-
def close()
489-
490-
:param require_unique_entity_events: Set to True to guarantee that each of the before/after methods will only be
491-
called once per entity instance per operation type (add, change or remove).
492-
When set to True, this has a slight performance loss.
493-
When set to False, it's often easier to make the listener implementation
494-
correct and fast.
495-
Defaults to False
496-
497-
:type variable_listener_class: '_VariableListener'
498-
:type require_unique_entity_events: bool
499-
:rtype: Type
500-
"""
501-
ensure_init()
502-
503-
def variable_listener_wrapper(the_variable_listener_class):
504-
from jpyinterpreter import translate_python_class_to_java_class, generate_proxy_class_for_translated_class
505-
from ai.timefold.solver.core.api.domain.variable import VariableListener
506-
methods = ['beforeEntityAdded',
507-
'afterEntityAdded',
508-
'beforeVariableChanged',
509-
'afterVariableChanged',
510-
'beforeEntityRemoved',
511-
'afterEntityRemoved']
512-
513-
missing_method_list = []
514-
for method in methods:
515-
if not callable(getattr(the_variable_listener_class, method, None)):
516-
missing_method_list.append(method)
517-
if len(missing_method_list) != 0:
518-
raise ValueError(f'The following required methods are missing from @variable_listener class '
519-
f'{the_variable_listener_class}: {missing_method_list}')
520-
521-
method_on_class = getattr(the_variable_listener_class, 'requiresUniqueEntityEvents', None)
522-
if method_on_class is None:
523-
def class_requires_unique_entity_events(self):
524-
return require_unique_entity_events
525-
526-
setattr(the_variable_listener_class, 'requiresUniqueEntityEvents', class_requires_unique_entity_events)
527-
528-
method_on_class = getattr(the_variable_listener_class, 'close', None)
529-
if method_on_class is None:
530-
def close(self):
531-
pass
532-
533-
setattr(the_variable_listener_class, 'close', close)
534-
535-
method_on_class = getattr(the_variable_listener_class, 'resetWorkingSolution', None)
536-
if method_on_class is None:
537-
def reset_working_solution(self, score_director):
538-
pass
539-
540-
setattr(the_variable_listener_class, 'resetWorkingSolution', reset_working_solution)
541-
542-
translated_class = translate_python_class_to_java_class(the_variable_listener_class)
543-
java_class = generate_proxy_class_for_translated_class(VariableListener, translated_class)
544-
return register_java_class(the_variable_listener_class, java_class)
545-
546-
if variable_listener_class: # Called as @variable_listener
547-
return variable_listener_wrapper(variable_listener_class)
548-
else: # Called as @variable_listener(require_unique_entity_events=True)
549-
return variable_listener_wrapper
550-
551-
552476
def problem_change(problem_change_class: Type['_ProblemChange']) -> \
553477
Type['_ProblemChange']:
554478
"""A ProblemChange represents a change in 1 or more planning entities or problem facts of a PlanningSolution.
@@ -599,6 +523,7 @@ def wrapper_doChange(self, solution, problem_change_director):
599523

600524
__all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable',
601525
'PlanningListVariable', 'PlanningVariableReference', 'ShadowVariable',
526+
'PiggybackShadowVariable',
602527
'IndexShadowVariable', 'AnchorShadowVariable', 'InverseRelationShadowVariable',
603528
'ProblemFactProperty', 'ProblemFactCollectionProperty',
604529
'PlanningEntityProperty', 'PlanningEntityCollectionProperty',
@@ -607,4 +532,4 @@ def wrapper_doChange(self, solution, problem_change_director):
607532
'planning_entity', 'planning_solution', 'constraint_configuration',
608533
'nearby_distance_meter',
609534
'constraint_provider', 'easy_score_calculator', 'incremental_score_calculator',
610-
'variable_listener', 'problem_change']
535+
'problem_change']

timefold-solver-python-core/src/main/python/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
from ._solver_factory import *
44
from ._solver_manager import *
55
from ._solution_manager import *
6+
from ._score_director import *
7+
from ._variable_listener import *
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
class ScoreDirector:
2+
def __init__(self, delegate):
3+
self._delegate = delegate
4+
5+
def after_entity_added(self, entity) -> None:
6+
self._delegate.afterEntityAdded(entity)
7+
8+
def after_entity_removed(self, entity) -> None:
9+
self._delegate.afterEntityRemoved(entity)
10+
11+
def after_list_variable_changed(self, entity, variable_name: str, start: int, end: int) -> None:
12+
self._delegate.afterListVariableChanged(entity, variable_name, start, end)
13+
14+
def after_list_variable_element_assigned(self, entity, variable_name: str, element) -> None:
15+
self._delegate.afterListVariableElementAssigned(entity, variable_name, element)
16+
17+
def after_list_variable_element_unassigned(self, entity, variable_name: str, element) -> None:
18+
self._delegate.afterListVariableElementUnassigned(entity, variable_name, element)
19+
20+
def after_problem_fact_added(self, entity) -> None:
21+
self._delegate.afterProblemFactAdded(entity)
22+
23+
def after_problem_fact_removed(self, entity) -> None:
24+
self._delegate.afterProblemFactRemoved(entity)
25+
26+
def after_problem_property_changed(self, entity) -> None:
27+
self._delegate.afterProblemPropertyChanged(entity)
28+
29+
def after_variable_changed(self, entity, variable_name: str) -> None:
30+
self._delegate.afterVariableChanged(entity, variable_name)
31+
32+
def before_entity_added(self, entity) -> None:
33+
self._delegate.beforeEntityAdded(entity)
34+
35+
def before_entity_removed(self, entity) -> None:
36+
self._delegate.beforeEntityRemoved(entity)
37+
38+
def before_list_variable_changed(self, entity, variable_name: str, start: int, end: int) -> None:
39+
self._delegate.beforeListVariableChanged(entity, variable_name, start, end)
40+
41+
def before_list_variable_element_assigned(self, entity, variable_name: str, element) -> None:
42+
self._delegate.beforeListVariableElementAssigned(entity, variable_name, element)
43+
44+
def before_list_variable_element_unassigned(self, entity, variable_name: str, element) -> None:
45+
self._delegate.beforeListVariableElementUnassigned(entity, variable_name, element)
46+
47+
def before_problem_fact_added(self, entity) -> None:
48+
self._delegate.beforeProblemFactAdded(entity)
49+
50+
def before_problem_fact_removed(self, entity) -> None:
51+
self._delegate.beforeProblemFactRemoved(entity)
52+
53+
def before_problem_property_changed(self, entity) -> None:
54+
self._delegate.beforeProblemPropertyChanged(entity)
55+
56+
def before_variable_changed(self, entity, variable_name: str) -> None:
57+
self._delegate.beforeVariableChanged(entity, variable_name)
58+
59+
def get_working_solution(self):
60+
return self._delegate.getWorkingSolution()
61+
62+
def look_up_working_object(self, working_object):
63+
return self._delegate.lookUpWorkingObject(working_object)
64+
65+
def look_up_working_object_or_return_none(self, working_object):
66+
return self._delegate.lookUpWorkingObject(working_object)
67+
68+
def trigger_variable_listeners(self) -> None:
69+
self._delegate.triggerVariableListeners()
70+
71+
72+
__all__ = ['ScoreDirector']

0 commit comments

Comments
 (0)