Skip to content

Commit ba56c9a

Browse files
feat: Add SolverConfigOverride support (#36)
- Added missing type info to Solver API Classes - Unlike Java, the SolverConfig in Python is Generic, since: - Python users do not need to specify types of variables, and no warnings are emitted for using a raw type - Allows a smart enough type checker to deduce the generic type of a SolverFactory, SolverManager, etc. from the generic type of the SolverConfig
1 parent 244e1d9 commit ba56c9a

File tree

7 files changed

+188
-52
lines changed

7 files changed

+188
-52
lines changed

tests/test_solver_factory.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from timefold.solver.api import *
2+
from timefold.solver.annotation import *
3+
from timefold.solver.config import *
4+
from timefold.solver.constraint import *
5+
from timefold.solver.score import *
6+
7+
from dataclasses import dataclass, field
8+
from typing import Annotated, List
9+
10+
11+
@planning_entity
12+
@dataclass
13+
class Entity:
14+
code: Annotated[str, PlanningId]
15+
value: Annotated[int, PlanningVariable] = field(default=None, compare=False)
16+
17+
18+
@constraint_provider
19+
def my_constraints(constraint_factory: ConstraintFactory):
20+
return [
21+
constraint_factory.for_each(Entity)
22+
.reward(SimpleScore.ONE, lambda entity: entity.value)
23+
.as_constraint('Maximize value'),
24+
]
25+
26+
27+
@planning_solution
28+
@dataclass
29+
class Solution:
30+
entities: Annotated[List[Entity], PlanningEntityCollectionProperty]
31+
value_range: Annotated[List[int], ValueRangeProvider]
32+
score: Annotated[SimpleScore, PlanningScore] = field(default=None)
33+
34+
def __str__(self) -> str:
35+
return str(self.entities)
36+
37+
38+
def test_solver_config_override():
39+
solver_config = SolverConfig(
40+
solution_class=Solution,
41+
entity_class_list=[Entity],
42+
score_director_factory_config=ScoreDirectorFactoryConfig(
43+
constraint_provider_function=my_constraints,
44+
),
45+
termination_config=TerminationConfig(
46+
best_score_limit='9'
47+
)
48+
)
49+
solver_factory = SolverFactory.create(solver_config)
50+
solver = solver_factory.build_solver(SolverConfigOverride(
51+
termination_config=TerminationConfig(
52+
best_score_limit='3'
53+
)
54+
))
55+
problem = Solution([Entity('A')], [1, 2, 3])
56+
solution = solver.solve(problem)
57+
assert solution.score.score() == 3

tests/test_solver_manager.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,59 @@ def my_exception_handler(problem_id, exception):
272272

273273
assert the_problem_id == 1
274274
assert the_exception is not None
275+
276+
277+
def test_solver_config_override():
278+
@dataclass
279+
class Value:
280+
value: Annotated[int, PlanningId]
281+
282+
@planning_entity
283+
@dataclass
284+
class Entity:
285+
code: Annotated[str, PlanningId]
286+
value: Annotated[Value, PlanningVariable] = field(default=None)
287+
288+
@constraint_provider
289+
def my_constraints(constraint_factory: ConstraintFactory):
290+
return [
291+
constraint_factory.for_each(Entity)
292+
.reward(SimpleScore.ONE, lambda entity: entity.value.value)
293+
.as_constraint('Maximize Value')
294+
]
295+
296+
@planning_solution
297+
@dataclass
298+
class Solution:
299+
entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty]
300+
value_list: Annotated[List[Value],
301+
DeepPlanningClone,
302+
ProblemFactCollectionProperty,
303+
ValueRangeProvider]
304+
score: Annotated[SimpleScore, PlanningScore] = field(default=None)
305+
306+
solver_config = SolverConfig(
307+
solution_class=Solution,
308+
entity_class_list=[Entity],
309+
score_director_factory_config=ScoreDirectorFactoryConfig(
310+
constraint_provider_function=my_constraints
311+
),
312+
termination_config=TerminationConfig(
313+
best_score_limit='9'
314+
)
315+
)
316+
problem: Solution = Solution([Entity('A')], [Value(1), Value(2), Value(3)],
317+
SimpleScore.ONE)
318+
with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
319+
solver_job = (solver_manager.solve_builder()
320+
.with_problem_id(1)
321+
.with_problem(problem)
322+
.with_config_override(SolverConfigOverride(
323+
termination_config=TerminationConfig(
324+
best_score_limit='3'
325+
)
326+
))
327+
.run())
328+
329+
solution = solver_job.get_final_best_solution()
330+
assert solution.score.score() == 3

timefold-solver-python-core/src/main/python/api/_solution_manager.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ._solver_factory import SolverFactory
22
from .._timefold_java_interop import get_class
33

4-
from typing import TypeVar, Union, TYPE_CHECKING
4+
from typing import TypeVar, Generic, Union, TYPE_CHECKING
55

66

77
if TYPE_CHECKING:
@@ -15,41 +15,43 @@
1515
ProblemId_ = TypeVar('ProblemId_')
1616

1717

18-
class SolutionManager:
18+
class SolutionManager(Generic[Solution_]):
1919
_delegate: '_JavaSolutionManager'
2020

2121
def __init__(self, delegate: '_JavaSolutionManager'):
2222
self._delegate = delegate
2323

2424
@staticmethod
25-
def create(solver_factory: 'SolverFactory'):
25+
def create(solver_factory: 'SolverFactory[Solution_]') -> 'SolutionManager[Solution_]':
2626
from ai.timefold.solver.core.api.solver import SolutionManager as JavaSolutionManager
2727
return SolutionManager(JavaSolutionManager.create(solver_factory._delegate))
2828

29-
def update(self, solution, solution_update_policy=None) -> 'Score':
29+
def update(self, solution: Solution_, solution_update_policy=None) -> 'Score':
3030
# TODO handle solution_update_policy
3131
from jpyinterpreter import convert_to_java_python_like_object, update_python_object_from_java
3232
java_solution = convert_to_java_python_like_object(solution)
3333
out = self._delegate.update(java_solution)
3434
update_python_object_from_java(java_solution)
3535
return out
3636

37-
def analyze(self, solution, score_analysis_fetch_policy=None, solution_update_policy=None) -> 'ScoreAnalysis':
37+
def analyze(self, solution: Solution_, score_analysis_fetch_policy=None, solution_update_policy=None) \
38+
-> 'ScoreAnalysis':
3839
# TODO handle policies
3940
from jpyinterpreter import convert_to_java_python_like_object
4041
return ScoreAnalysis(self._delegate.analyze(convert_to_java_python_like_object(solution)))
4142

42-
def explain(self, solution, solution_update_policy=None) -> 'ScoreExplanation':
43+
def explain(self, solution: Solution_, solution_update_policy=None) -> 'ScoreExplanation':
4344
# TODO handle policies
4445
from jpyinterpreter import convert_to_java_python_like_object
4546
return ScoreExplanation(self._delegate.explain(convert_to_java_python_like_object(solution)))
4647

47-
def recommend_fit(self, solution, entity_or_element, proposition_function, score_analysis_fetch_policy=None):
48+
def recommend_fit(self, solution: Solution_, entity_or_element, proposition_function,
49+
score_analysis_fetch_policy=None):
4850
# TODO
4951
raise NotImplementedError
5052

5153

52-
class ScoreExplanation:
54+
class ScoreExplanation(Generic[Solution_]):
5355
_delegate: '_JavaScoreExplanation'
5456

5557
def __init__(self, delegate: '_JavaScoreExplanation'):
@@ -70,7 +72,7 @@ def get_justification_list(self, justification_type=None):
7072
def get_score(self) -> 'Score':
7173
return self._delegate.getScore()
7274

73-
def get_solution(self):
75+
def get_solution(self) -> Solution_:
7476
from jpyinterpreter import unwrap_python_like_object
7577
return unwrap_python_like_object(self._delegate.getSolution())
7678

timefold-solver-python-core/src/main/python/api/_solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def terminate_early(self) -> bool:
5959
def is_terminate_early(self) -> bool:
6060
return self._delegate.isTerminateEarly()
6161

62-
def add_problem_change(self, problem_change: ProblemChange) -> None:
62+
def add_problem_change(self, problem_change: ProblemChange[Solution_]) -> None:
6363
self._delegate.addProblemChange(ProblemChangeWrapper(problem_change)) # noqa
6464

65-
def add_problem_changes(self, problem_changes: List[ProblemChange]) -> None:
65+
def add_problem_changes(self, problem_changes: List[ProblemChange[Solution_]]) -> None:
6666
self._delegate.addProblemChanges([ProblemChangeWrapper(problem_change) for problem_change in problem_changes]) # noqa
6767

6868
def is_every_problem_change_processed(self) -> bool:
Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ._solver import Solver
2-
from ..config import SolverConfig
2+
from ..config import SolverConfig, SolverConfigOverride
33

4-
from typing import TypeVar, TYPE_CHECKING
4+
from typing import TypeVar, Generic, TYPE_CHECKING
55
from jpype import JClass
66

77
if TYPE_CHECKING:
@@ -12,7 +12,7 @@
1212
Solution_ = TypeVar('Solution_')
1313

1414

15-
class SolverFactory:
15+
class SolverFactory(Generic[Solution_]):
1616
_delegate: '_JavaSolverFactory'
1717
_solution_class: JClass
1818

@@ -21,14 +21,18 @@ def __init__(self, delegate: '_JavaSolverFactory', solution_class: JClass):
2121
self._solution_class = solution_class
2222

2323
@staticmethod
24-
def create(solver_config: SolverConfig):
24+
def create(solver_config: SolverConfig[Solution_]) -> 'SolverFactory[Solution_]':
2525
from ai.timefold.solver.core.api.solver import SolverFactory as JavaSolverFactory
2626
solver_config = solver_config._to_java_solver_config()
2727
delegate = JavaSolverFactory.create(solver_config) # noqa
2828
return SolverFactory(delegate, solver_config.getSolutionClass()) # noqa
2929

30-
def build_solver(self):
31-
return Solver(self._delegate.buildSolver(), self._solution_class)
30+
def build_solver(self, solver_config_override: SolverConfigOverride = None) -> Solver[Solution_]:
31+
if solver_config_override is None:
32+
return Solver(self._delegate.buildSolver(), self._solution_class)
33+
else:
34+
return Solver(self._delegate.buildSolver(solver_config_override._to_java_solver_config_override()),
35+
self._solution_class)
3236

3337

3438
__all__ = ['SolverFactory']

0 commit comments

Comments
 (0)