Skip to content

Commit e60cd28

Browse files
committed
chore: add if(Not)Exists overloads
1 parent efecb6a commit e60cd28

File tree

2 files changed

+302
-121
lines changed

2 files changed

+302
-121
lines changed

tests/test_constraint_streams.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,87 @@ def define_constraints(constraint_factory: ConstraintFactory):
265265
assert score_manager.explain(problem).score.score == 8
266266

267267

268+
def test_if_exists_uni():
269+
@constraint_provider
270+
def define_constraints(constraint_factory: ConstraintFactory):
271+
return [
272+
constraint_factory.for_each(Entity)
273+
.if_exists(Entity, Joiners.equal(lambda entity: entity.code))
274+
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
275+
.as_constraint('Count')
276+
]
277+
278+
score_manager = create_score_manager(define_constraints)
279+
entity_a1: Entity = Entity('A')
280+
entity_a2: Entity = Entity('A')
281+
entity_b1: Entity = Entity('B')
282+
entity_b2: Entity = Entity('B')
283+
284+
value_1 = Value(1)
285+
value_2 = Value(2)
286+
287+
problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])
288+
289+
entity_a1.value = value_1
290+
291+
# With itself
292+
assert score_manager.explain(problem).score.score == 1
293+
294+
entity_a1.value = value_1
295+
entity_a2.value = value_1
296+
297+
entity_b1.value = value_2
298+
entity_b2.value = value_2
299+
300+
# 1 + 2 + 1 + 2
301+
assert score_manager.explain(problem).score.score == 6
302+
303+
entity_a1.value = value_2
304+
entity_b1.value = value_1
305+
306+
# 1 + 2 + 1 + 2
307+
assert score_manager.explain(problem).score.score == 6
308+
309+
310+
def test_if_not_exists_uni():
311+
@constraint_provider
312+
def define_constraints(constraint_factory: ConstraintFactory):
313+
return [
314+
constraint_factory.for_each(Entity)
315+
.if_not_exists(Entity, Joiners.equal(lambda entity: entity.code))
316+
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
317+
.as_constraint('Count')
318+
]
319+
320+
score_manager = create_score_manager(define_constraints)
321+
entity_a1: Entity = Entity('A')
322+
entity_a2: Entity = Entity('A')
323+
entity_b1: Entity = Entity('B')
324+
entity_b2: Entity = Entity('B')
325+
326+
value_1 = Value(1)
327+
value_2 = Value(2)
328+
329+
problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])
330+
331+
entity_a1.value = value_1
332+
333+
assert score_manager.explain(problem).score.score == 0
334+
335+
entity_a1.value = value_1
336+
entity_a2.value = value_1
337+
338+
entity_b1.value = value_2
339+
entity_b2.value = value_2
340+
341+
assert score_manager.explain(problem).score.score == 0
342+
343+
entity_a1.value = value_2
344+
entity_b1.value = value_1
345+
346+
assert score_manager.explain(problem).score.score == 0
347+
348+
268349
def test_map():
269350
@constraint_provider
270351
def define_constraints(constraint_factory: ConstraintFactory):

0 commit comments

Comments
 (0)