Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.

Commit 97d496b

Browse files
committed
wip: add load_balance test
1 parent 552b307 commit 97d496b

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

tests/test_collectors.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -563,24 +563,39 @@ def define_constraints(constraint_factory: ConstraintFactory):
563563
assert score_manager.explain(problem).score == SimpleScore.of(4)
564564

565565

566-
def test_flatten_last():
566+
def test_load_balance():
567567
@constraint_provider
568568
def define_constraints(constraint_factory: ConstraintFactory):
569569
return [
570570
constraint_factory.for_each(Entity)
571-
.map(lambda entity: (1, 2, 3))
572-
.flatten_last(lambda the_tuple: the_tuple)
573-
.reward(SimpleScore.ONE)
574-
.as_constraint('Count')
571+
.group_by(ConstraintCollectors.load_balance(
572+
lambda entity: entity.value
573+
))
574+
.reward(SimpleScore.ONE,
575+
lambda balance: balance.unfairness().multiply(BigDecimal.valueOf(1000)).intValue())
576+
.as_constraint('Balanced value')
575577
]
576578

577579
score_manager = create_score_manager(define_constraints)
578580

579581
entity_a: Entity = Entity('A')
582+
entity_b: Entity = Entity('B')
583+
entity_c: Entity = Entity('C')
580584

581585
value_1 = Value(1)
586+
value_2 = Value(2)
582587

583-
problem = Solution([entity_a], [value_1])
588+
problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2])
584589
entity_a.value = value_1
590+
entity_b.value = value_1
591+
entity_c.value = value_1
585592

586-
assert score_manager.explain(problem).score == SimpleScore.of(3)
593+
assert score_manager.explain(problem).score == SimpleScore.of(0)
594+
595+
entity_c.value = value_2
596+
597+
assert score_manager.explain(problem).score == SimpleScore.of(2) // FIXME
598+
599+
entity_b.value = value_2
600+
601+
assert score_manager.explain(problem).score == SimpleScore.of(4) // FIXME

tests/test_constraint_streams.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,29 @@ def define_constraints(constraint_factory: ConstraintFactory):
223223
assert score_manager.explain(problem).score.score == 1
224224

225225

226+
def test_flatten_last():
227+
@constraint_provider
228+
def define_constraints(constraint_factory: ConstraintFactory):
229+
return [
230+
constraint_factory.for_each(Entity)
231+
.map(lambda entity: (1, 2, 3))
232+
.flatten_last(lambda the_tuple: the_tuple)
233+
.reward(SimpleScore.ONE)
234+
.as_constraint('Count')
235+
]
236+
237+
score_manager = create_score_manager(define_constraints)
238+
239+
entity_a: Entity = Entity('A')
240+
241+
value_1 = Value(1)
242+
243+
problem = Solution([entity_a], [value_1])
244+
entity_a.value = value_1
245+
246+
assert score_manager.explain(problem).score == SimpleScore.of(3)
247+
248+
226249
def test_join_uni():
227250
@constraint_provider
228251
def define_constraints(constraint_factory: ConstraintFactory):

timefold-solver-python-core/src/main/python/score/_group_by.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,18 @@ def perform_group_by(constraint_stream, package, group_by_args, *type_arguments)
122122
created_collector = extract_collector(collector_info, *type_arguments)
123123
actual_group_by_args.append(created_collector)
124124

125-
if len(group_by_args) == 1:
125+
if len(group_by_args) is 1:
126126
return UniConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package,
127127
JClass('java.lang.Object'))
128-
elif len(group_by_args) == 2:
128+
elif len(group_by_args) is 2:
129129
return BiConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package,
130130
JClass('java.lang.Object'),
131131
JClass('java.lang.Object'))
132-
elif len(group_by_args) == 3:
132+
elif len(group_by_args) is 3:
133133
return TriConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package,
134134
JClass('java.lang.Object'),
135135
JClass('java.lang.Object'), JClass('java.lang.Object'))
136-
elif len(group_by_args) == 4:
136+
elif len(group_by_args) is 4:
137137
return QuadConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package,
138138
JClass('java.lang.Object'),
139139
JClass('java.lang.Object'), JClass('java.lang.Object'),
@@ -1107,27 +1107,28 @@ def load_balance(balanced_item_function, load_function=None, initial_load_functi
11071107
11081108
Parameters
11091109
----------
1110-
balanced_item_function:
1110+
balanced_item_function: Callable[[ParameterTypes, ...], Balanced_]
11111111
The function that returns the item which should be load-balanced.
1112-
load_function:
1112+
load_function: Callable[[ParameterTypes, ...], int]
11131113
How much the item should count for in the formula.
1114-
initial_load_function:
1114+
initial_load_function: Callable[[ParameterTypes, ...], int]
11151115
The initial value of the metric, allowing to provide initial state
11161116
without requiring the entire previous planning windows in the working memory.
11171117
If this function is provided, load_function must be provided as well.
11181118
"""
1119-
if None == load_function == initial_load_function:
1119+
if load_function is None and initial_load_function is None:
11201120
return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, None,
11211121
None)
1122-
elif None == initial_load_function:
1122+
elif initial_load_function is None:
11231123
return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function,
11241124
load_function, None)
1125-
elif None == load_function:
1125+
elif load_function is None:
11261126
raise ValueError("load_function cannot be None if initial_load_function is not None")
11271127
else:
11281128
return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function,
11291129
load_function, initial_load_function)
11301130

1131+
11311132
# Must be at the bottom, constraint_stream depends on this module
11321133
from ._constraint_stream import *
11331134
from ._function_translator import *

0 commit comments

Comments
 (0)