Skip to content

Commit fb72d9c

Browse files
feat: Add support for problem changes (#34)
* feat: Add support for problem changes - Since the methods in ProblemChangeDirector take interfaces, ProblemChange as a whole cannot be translated to pure Java (since that requires supporting casting an arbitary callable to any Java interface, which is not suported yet (nor planned to be supported)). - Thus we goes for a more ugly approach: the ProblemChange runs in Python, and when a problem change director method is called, we compile/translate the supplied function to Java. - We do a trick where we replace the Python working solution clone with the actual Java working solution in the closure before compiling the function so changes are applied to the right object. After the method is called, we then update the Python working solution from the java working solution so changes are reflected in it too. - Users implement a ProblemChange by extending an abstract base class (which, among other things, raises an error if not all of its methods are implemented).
1 parent 30a13b9 commit fb72d9c

File tree

8 files changed

+379
-67
lines changed

8 files changed

+379
-67
lines changed

jpyinterpreter/src/main/python/conversions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,10 @@ def is_c_native(item):
117117
or module == '': # if we cannot find module, assume it is not native
118118
return False
119119

120-
return is_native_module(importlib.import_module(module))
120+
try:
121+
return is_native_module(importlib.import_module(module))
122+
except:
123+
return True
121124

122125

123126
def init_type_to_compiled_java_class():

tests/test_solver_manager.py

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Value:
2626
@dataclass
2727
class Entity:
2828
code: Annotated[str, PlanningId]
29-
value: Annotated[Value, PlanningVariable] = field(default=None)
29+
value: Annotated[Value, PlanningVariable] = field(default=None, compare=False)
3030

3131
@constraint_provider
3232
def my_constraints(constraint_factory: ConstraintFactory):
@@ -54,24 +54,22 @@ class Solution:
5454
ValueRangeProvider]
5555
score: Annotated[SimpleScore, PlanningScore] = field(default=None)
5656

57-
# TODO: Support problem changes
58-
# @Problem_change
59-
# class UseOnlyEntityAndValueProblemChange:
60-
# def __init__(self, entity, value):
61-
# self.entity = entity
62-
# self.value = value
63-
#
64-
# def doChange(self, solution: Solution, problem_change_director: timefold.solver.types.ProblemChangeDirector):
65-
# problem_facts_to_remove = solution.value_list.copy()
66-
# entities_to_remove = solution.entity_list.copy()
67-
# for problem_fact in problem_facts_to_remove:
68-
# problem_change_director.removeProblemFact(problem_fact,
69-
# lambda value: solution.value_list.remove(problem_fact))
70-
# for removed_entity in entities_to_remove:
71-
# problem_change_director.removeEntity(removed_entity,
72-
# lambda entity: solution.entity_list.remove(removed_entity))
73-
# problem_change_director.addEntity(self.entity, lambda entity: solution.entity_list.append(entity))
74-
# problem_change_director.addProblemFact(self.value, lambda value: solution.value_list.append(value))
57+
class UseOnlyEntityAndValueProblemChange(ProblemChange[Solution]):
58+
def __init__(self, entity, value):
59+
self.entity = entity
60+
self.value = value
61+
62+
def do_change(self, solution: Solution, problem_change_director: ProblemChangeDirector):
63+
problem_facts_to_remove = solution.value_list.copy()
64+
entities_to_remove = solution.entity_list.copy()
65+
for problem_fact in problem_facts_to_remove:
66+
problem_change_director.remove_problem_fact(problem_fact,
67+
lambda value: solution.value_list.remove(value))
68+
for removed_entity in entities_to_remove:
69+
problem_change_director.remove_entity(removed_entity,
70+
lambda entity: solution.entity_list.remove(entity))
71+
problem_change_director.add_entity(self.entity, lambda entity: solution.entity_list.append(entity))
72+
problem_change_director.add_problem_fact(self.value, lambda value: solution.value_list.append(value))
7573

7674
solver_config = SolverConfig(
7775
solution_class=Solution,
@@ -97,27 +95,27 @@ def assert_solver_run(solver_manager, solver_job):
9795
assert 3 in value_list
9896
assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING
9997

100-
# def assert_problem_change_solver_run(solver_manager, solver_job):
101-
# assert solver_manager.get_solver_status(1) != SolverStatus.NOT_SOLVING
102-
# solver_manager.addProblemChange(1, UseOnlyEntityAndValueProblemChange(Entity('D'), Value(6)))
103-
# lock.release()
104-
# solution = solver_job.get_final_best_solution()
105-
# assert solution.score.score() == 6
106-
# assert len(solution.entity_list) == 1
107-
# assert len(solution.value_range) == 1
108-
# assert solution.entity_list[0].code == 'D'
109-
# assert solution.entity_list[0].value.value == 6
110-
# assert solution.value_range[0].value == 6
111-
# assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING
98+
def assert_problem_change_solver_run(solver_manager, solver_job):
99+
assert solver_manager.get_solver_status(1) != SolverStatus.NOT_SOLVING
100+
solver_manager.add_problem_change(1, UseOnlyEntityAndValueProblemChange(Entity('D'), Value(6)))
101+
lock.release()
102+
solution = solver_job.get_final_best_solution()
103+
assert solution.score.score() == 6
104+
assert len(solution.entity_list) == 1
105+
assert len(solution.value_list) == 1
106+
assert solution.entity_list[0].code == 'D'
107+
assert solution.entity_list[0].value.value == 6
108+
assert solution.value_list[0].value == 6
109+
assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING
112110

113111
with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
114112
lock.acquire()
115113
solver_job = solver_manager.solve(1, problem)
116114
assert_solver_run(solver_manager, solver_job)
117115

118-
# lock.acquire()
119-
# solver_job = solver_manager.solve(1, problem)
120-
# assert_problem_change_solver_run(solver_manager, solver_job)
116+
lock.acquire()
117+
solver_job = solver_manager.solve(1, problem)
118+
assert_problem_change_solver_run(solver_manager, solver_job)
121119

122120
def get_problem(problem_id):
123121
assert problem_id == 1
@@ -129,9 +127,11 @@ def get_problem(problem_id):
129127
.with_problem_finder(get_problem)).run()
130128
assert_solver_run(solver_manager, solver_job)
131129

132-
# lock.acquire()
133-
#solver_job = solver_manager.solve(1, get_problem)
134-
#assert_problem_change_solver_run(solver_manager, solver_job)
130+
lock.acquire()
131+
solver_job = (solver_manager.solve_builder()
132+
.with_problem_id(1)
133+
.with_problem_finder(get_problem)).run()
134+
assert_problem_change_solver_run(solver_manager, solver_job)
135135

136136
solution_list = []
137137
semaphore = Semaphore(0)
@@ -150,15 +150,16 @@ def on_best_solution_changed(solution):
150150
assert semaphore.acquire(timeout=1)
151151
assert len(solution_list) == 1
152152

153-
# solution_list = []
154-
# lock.acquire()
155-
# solver_job = (solver_manager.solve_builder()
156-
# .with_problem_id(1)
157-
# .with_problem_finder(get_problem)
158-
# .with_best_solution_consumer(on_best_solution_changed)
159-
# ).run()
160-
#assert_problem_change_solver_run(solver_manager, solver_job)
161-
# assert len(solution_list) == 1
153+
solution_list = []
154+
lock.acquire()
155+
solver_job = (solver_manager.solve_builder()
156+
.with_problem_id(1)
157+
.with_problem_finder(get_problem)
158+
.with_best_solution_consumer(on_best_solution_changed)
159+
).run()
160+
assert_problem_change_solver_run(solver_manager, solver_job)
161+
assert semaphore.acquire(timeout=1)
162+
assert len(solution_list) == 1
162163

163164
solution_list = []
164165
lock.acquire()
@@ -175,16 +176,20 @@ def on_best_solution_changed(solution):
175176
assert semaphore.acquire(timeout=1)
176177
assert len(solution_list) == 2
177178

178-
# solution_list = []
179-
# lock.acquire()
180-
# solver_job = (solver_manager.solve_builder()
181-
# .with_problem_id(1)
182-
# .with_problem_finder(get_problem)
183-
# .with_best_solution_consumer(on_best_solution_changed)
184-
# .with_final_best_solution_consumer(on_best_solution_changed)
185-
# ).run()
186-
# assert_problem_change_solver_run(solver_manager, solver_job)
187-
# assert len(solution_list) == 2
179+
solution_list = []
180+
lock.acquire()
181+
solver_job = (solver_manager.solve_builder()
182+
.with_problem_id(1)
183+
.with_problem_finder(get_problem)
184+
.with_best_solution_consumer(on_best_solution_changed)
185+
.with_final_best_solution_consumer(on_best_solution_changed)
186+
).run()
187+
assert_problem_change_solver_run(solver_manager, solver_job)
188+
# Wait for 2 acquires, one for best solution consumer,
189+
# another for final best solution consumer
190+
assert semaphore.acquire(timeout=1)
191+
assert semaphore.acquire(timeout=1)
192+
assert len(solution_list) == 2
188193

189194

190195
@pytest.mark.filterwarnings("ignore:.*Exception in thread.*:pytest.PytestUnhandledThreadExceptionWarning")

tests/test_solver_problem_change.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from timefold.solver.api import *
2+
from timefold.solver.annotation import *
3+
from timefold.solver.config import *
4+
from timefold.solver.constraint import *
5+
from timefold.solver.score import *
6+
7+
from dataclasses import dataclass, field
8+
from typing import Annotated, List
9+
from threading import Thread
10+
11+
12+
@planning_entity
13+
@dataclass
14+
class Entity:
15+
code: Annotated[str, PlanningId]
16+
value: Annotated[int, PlanningVariable] = field(default=None, compare=False)
17+
18+
19+
@constraint_provider
20+
def maximize_constraints(constraint_factory: ConstraintFactory):
21+
return [
22+
constraint_factory.for_each(Entity)
23+
.reward(SimpleScore.ONE, lambda entity: entity.value)
24+
.as_constraint('Maximize value'),
25+
]
26+
27+
28+
@constraint_provider
29+
def minimize_constraints(constraint_factory: ConstraintFactory):
30+
return [
31+
constraint_factory.for_each(Entity)
32+
.penalize(SimpleScore.ONE, lambda entity: entity.value)
33+
.as_constraint('Minimize value'),
34+
]
35+
36+
37+
@planning_solution
38+
@dataclass
39+
class Solution:
40+
entities: Annotated[List[Entity], PlanningEntityCollectionProperty]
41+
value_range: Annotated[List[int], ValueRangeProvider]
42+
score: Annotated[SimpleScore, PlanningScore] = field(default=None)
43+
44+
def __str__(self) -> str:
45+
return str(self.entities)
46+
47+
48+
class AddEntity(ProblemChange[Solution]):
49+
entity: Entity
50+
51+
def __init__(self, entity: Entity):
52+
self.entity = entity
53+
54+
def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector):
55+
problem_change_director.add_entity(self.entity,
56+
lambda working_entity: working_solution.entities.append(working_entity))
57+
58+
59+
class RemoveEntity(ProblemChange[Solution]):
60+
entity: Entity
61+
62+
def __init__(self, entity: Entity):
63+
self.entity = entity
64+
65+
def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector):
66+
problem_change_director.remove_entity(self.entity,
67+
lambda working_entity: working_solution.entities.remove(working_entity))
68+
69+
70+
def test_add_entity():
71+
solver_config = SolverConfig(
72+
solution_class=Solution,
73+
entity_class_list=[Entity],
74+
score_director_factory_config=ScoreDirectorFactoryConfig(
75+
constraint_provider_function=maximize_constraints,
76+
),
77+
termination_config=TerminationConfig(
78+
best_score_limit='6'
79+
)
80+
)
81+
82+
problem: Solution = Solution([Entity('A')], [1, 2, 3])
83+
solver = SolverFactory.create(solver_config).build_solver()
84+
result: Solution | None = None
85+
86+
def do_solve(problem: Solution):
87+
nonlocal solver, result
88+
result = solver.solve(problem)
89+
90+
thread = Thread(target=do_solve, args=(problem,), daemon=True)
91+
92+
thread.start()
93+
solver.add_problem_change(AddEntity(Entity('B')))
94+
thread.join(timeout=1)
95+
96+
if thread.is_alive():
97+
raise AssertionError(f'Thread {thread} did not finish after 5 seconds')
98+
99+
assert result is not None
100+
assert len(result.entities) == 2
101+
assert result.score.score() == 6
102+
103+
104+
def test_remove_entity():
105+
solver_config = SolverConfig(
106+
solution_class=Solution,
107+
entity_class_list=[Entity],
108+
score_director_factory_config=ScoreDirectorFactoryConfig(
109+
constraint_provider_function=minimize_constraints,
110+
),
111+
termination_config=TerminationConfig(
112+
best_score_limit='-1'
113+
)
114+
)
115+
116+
problem: Solution = Solution([Entity('A'), Entity('B')], [1, 2, 3])
117+
solver = SolverFactory.create(solver_config).build_solver()
118+
result: Solution | None = None
119+
120+
def do_solve(problem: Solution):
121+
nonlocal solver, result
122+
result = solver.solve(problem)
123+
124+
thread = Thread(target=do_solve, args=(problem,), daemon=True)
125+
126+
thread.start()
127+
solver.add_problem_change(RemoveEntity(Entity('B')))
128+
thread.join(timeout=1)
129+
130+
if thread.is_alive():
131+
raise AssertionError(f'Thread {thread} did not finish after 5 seconds')
132+
133+
assert result is not None
134+
assert len(result.entities) == 1
135+
assert result.score.score() == -1

timefold-solver-python-core/src/main/python/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .problem_change import *
12
from .solver import *
23
from .solver_factory import *
34
from .solver_manager import *
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from ..jpype_type_conversions import PythonBiFunction
2+
from typing import Awaitable, TypeVar, TYPE_CHECKING
3+
from asyncio import Future, get_event_loop, CancelledError
4+
5+
if TYPE_CHECKING:
6+
from java.util.concurrent import (Future as JavaFuture,
7+
CompletableFuture as JavaCompletableFuture)
8+
9+
10+
Result = TypeVar('Result')
11+
12+
13+
def wrap_future(future: 'JavaFuture[Result]') -> Awaitable[Result]:
14+
async def get_result() -> Result:
15+
nonlocal future
16+
return future.get()
17+
18+
return get_result()
19+
20+
21+
def wrap_completable_future(future: 'JavaCompletableFuture[Result]') -> Future[Result]:
22+
loop = get_event_loop()
23+
out = loop.create_future()
24+
25+
def result_handler(result, error):
26+
nonlocal out
27+
if error is not None:
28+
out.set_exception(error)
29+
else:
30+
out.set_result(result)
31+
32+
def cancel_handler(python_future: Future):
33+
nonlocal future
34+
if isinstance(python_future.exception(), CancelledError):
35+
future.cancel(True)
36+
37+
future.handle(PythonBiFunction(result_handler))
38+
out.add_done_callback(cancel_handler)
39+
return out
40+
41+
42+
__all__ = ['wrap_future', 'wrap_completable_future']

0 commit comments

Comments
 (0)