
Commit ff466b7

v0.3.0: Greynet, an incremental (Rete-like) rule engine, laying the groundwork for future true incremental score calculation in GreyJack Solver
Parent: 2351597 · Commit: ff466b7
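The "incremental (Rete-like)" wording refers to propagating individual insertions and retractions through a network of nodes, so only the affected groups are recomputed instead of the whole fact set. As a rough illustration of that principle only (not Greynet's actual data structures or API), a per-key count that stays consistent under insert/retract might look like this:

# Illustrative sketch only: the general idea behind incremental aggregation,
# not Greynet's implementation or API.
from collections import defaultdict

class IncrementalCount:
    """Keeps a per-key count up to date on insert/retract without rescanning all facts."""
    def __init__(self, key_func):
        self.key_func = key_func
        self.counts = defaultdict(int)

    def insert(self, fact):
        # Only the affected group changes; nothing else is recomputed.
        self.counts[self.key_func(fact)] += 1

    def retract(self, fact):
        key = self.key_func(fact)
        self.counts[key] -= 1
        if self.counts[key] == 0:
            del self.counts[key]

# Usage sketch: counter = IncrementalCount(lambda sale: sale.product_id)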


93 files changed (+8295, -9 lines changed)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 /greyjack/target
 /data
 /examples/object_oriented/project_job_shop
+/.idea

 # Byte-compiled / optimized / DLL files
 __pycache__/
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# In main_example_fixed.py, after the existing code

# --- 4. New Usage Example for Aggregations ---
from dataclasses import dataclass
from greyjack.score_calculation.greynet.builder import ConstraintBuilder, Collectors
from greyjack.score_calculation.scores.SimpleScore import SimpleScore

@dataclass(frozen=True)
class SalesTransaction:
    region: str
    amount: float

builder = ConstraintBuilder(name="temporal-security", score_class=SimpleScore)

# Define a constraint to analyze sales data per region.
# We use a penalty of 0 because the goal is data extraction, not scoring.
@builder.constraint("Sales Regional Analysis")
def sales_analysis():
    return (
        builder.from_facts(SalesTransaction)
        .group_by(
            lambda tx: tx.region,
            Collectors.compose({
                "min_sale": Collectors.min(lambda tx: tx.amount),
                "max_sale": Collectors.max(lambda tx: tx.amount),
                "avg_sale": Collectors.avg(lambda tx: tx.amount),
                "stddev_sale": Collectors.stddev(lambda tx: tx.amount),
                "total_sales": Collectors.sum(lambda tx: tx.amount),
                "num_sales": Collectors.count()
            })
        )
        .penalize_simple(0)
    )

# --- Execute the new analysis ---
print("\n" + "="*50)
print("--- Starting Sales Aggregation Example ---")
print("="*50)

# Re-use the same builder to create a new session if needed, or add to the existing one.
# For a clean test, we'll build it again.
sales_session = builder.build()

transactions = [
    SalesTransaction("North", 110.0),
    SalesTransaction("North", 150.0),
    SalesTransaction("North", 195.5),
    SalesTransaction("South", 500.0),
    SalesTransaction("South", 600.0),
    SalesTransaction("West", 300.0),
]

sales_session.insert_batch(transactions)
sales_session.flush()

print("\n--- Sales Aggregation Results ---")
matches = sales_session.get_constraint_matches()
for constraint_id, violations in matches.items():
    if constraint_id == "Sales Regional Analysis":
        print(f"\nAnalysis: '{constraint_id}'")
        # Sort results by region for consistent output
        sorted_violations = sorted(violations, key=lambda v: v[1].fact_a)
        for _, facts_tuple in sorted_violations:
            region = facts_tuple.fact_a
            stats = facts_tuple.fact_b
            print(f"  - Region: {region}")
            print(f"    - Count: {stats['num_sales']}")
            print(f"    - Total Sales: ${stats['total_sales']:.2f}")
            print(f"    - Average Sale: ${stats['avg_sale']:.2f}")
            print(f"    - Min Sale: ${stats['min_sale']:.2f}")
            print(f"    - Max Sale: ${stats['max_sale']:.2f}")
            print(f"    - Std Dev: ${stats['stddev_sale']:.2f}")

print("\n--- Sales Example Complete ---")
Lines changed: 303 additions & 0 deletions
@@ -0,0 +1,303 @@
# main_example.py

from __future__ import annotations
import dataclasses
from typing import Type, Callable, List
from datetime import datetime, timedelta, timezone
from greyjack.score_calculation.greynet.builder import ConstraintBuilder, Collectors
from greyjack.score_calculation.scores.SimpleScore import SimpleScore


# 1. Data Class Definitions (Facts)
# =================================

@dataclasses.dataclass(frozen=True)
class Sale:
    sale_id: str
    product_id: str
    customer_id: str
    price: float
    quantity: int
    timestamp: datetime

@dataclasses.dataclass(frozen=True)
class Shipment:
    order_id: str
    shipment_id: str
    shipment_no: int

@dataclasses.dataclass(frozen=True)
class Maintenance:
    machine_id: str
    start_time: datetime
    end_time: datetime

@dataclasses.dataclass(frozen=True)
class UserEvent:
    user_id: str
    event_type: str
    value: float  # e.g., transaction amount
    timestamp: datetime


# 2. Constraint and Collector Definitions
# =======================================

# Initialize the constraint builder
cb = ConstraintBuilder(name="collector_showcase", score_class=SimpleScore)

@cb.constraint("count_total_sales_per_product")
def count_collector_example():
    """Demonstrates: CountCollector
    Counts the number of sales transactions for each product. Penalizes if a product has more than 3 sales.
    """
    return (cb.from_facts(Sale)
            .group_by(lambda s: s.product_id, Collectors.count())
            .filter(lambda product_id, count: count > 3)
            .penalize_simple(lambda product_id, count: count)
            )

@cb.constraint("sum_revenue_per_product")
def sum_collector_example():
    """Demonstrates: SumCollector
    Calculates the total revenue (price * quantity) for each product.
    """
    return (cb.from_facts(Sale)
            .group_by(lambda s: s.product_id, Collectors.sum(lambda s: s.price * s.quantity))
            .filter(lambda product_id, total_revenue: total_revenue > 0)
            .penalize_simple(lambda product_id, total_revenue: 0)  # Use penalty 0 to just report
            )

@cb.constraint("basic_price_stats_per_product")
def min_max_avg_collectors_example():
    """Demonstrates: MinCollector, MaxCollector, AvgCollector
    Finds the minimum, maximum, and average sale price for each product.
    """
    return (cb.from_facts(Sale)
            .group_by(lambda s: s.product_id, Collectors.compose({
                "min_price": Collectors.min(lambda s: s.price),
                "max_price": Collectors.max(lambda s: s.price),
                "avg_price": Collectors.avg(lambda s: s.price)
            }))
            .filter(lambda product_id, stats: stats["max_price"] > 1.0)
            .penalize_simple(lambda product_id, stats: 0)  # Reporting only
            )

@cb.constraint("advanced_price_stats_per_product")
def stddev_variance_collectors_example():
    """Demonstrates: StdDevCollector, VarianceCollector
    Calculates the standard deviation and variance of prices for each product.
    """
    return (cb.from_facts(Sale)
            .group_by(lambda s: s.product_id, Collectors.compose({
                "price_stddev": Collectors.stddev(lambda s: s.price),
                "price_variance": Collectors.variance(lambda s: s.price)
            }))
            .filter(lambda product_id, stats: stats["price_stddev"] > 0)
            .penalize_simple(lambda product_id, stats: 0)  # Reporting only
            )

@cb.constraint("list_of_sales_per_product")
def list_collector_example():
    """Demonstrates: ListCollector
    Collects all `Sale` objects for each product into a list.
    """
    return (cb.from_facts(Sale)
            .group_by(lambda s: s.product_id, Collectors.to_list())
            .filter(lambda product_id, sales_list: len(sales_list) > 0)
            .penalize_simple(lambda product_id, sales_list: 0)  # Reporting only
            )

@cb.constraint("set_of_customers_per_product")
def set_collector_example():
    """Demonstrates: SetCollector and MappingCollector
    Collects the unique set of customer IDs for each product.
    """
    return (cb.from_facts(Sale)
            .group_by(
                lambda s: s.product_id,
                Collectors.mapping(
                    lambda s: s.customer_id,
                    Collectors.to_set()
                )
            )
            .filter(lambda product_id, customer_set: len(customer_set) > 0)
            .penalize_simple(lambda product_id, customer_set: 0)  # Reporting only
            )

@cb.constraint("distinct_list_of_customers_per_product")
def distinct_collector_example():
    """Demonstrates: DistinctCollector
    Collects a list of unique customer IDs for each product, preserving insertion order.
    """
    return (cb.from_facts(Sale)
            .group_by(
                lambda s: s.product_id,
                Collectors.mapping(
                    lambda s: s.customer_id,
                    Collectors.distinct()
                )
            )
            .filter(lambda product_id, customer_list: len(customer_list) > 0)
            .penalize_simple(lambda product_id, customer_list: 0)  # Reporting only
            )

@cb.constraint("count_high_quantity_sales")
def filtering_collector_example():
    """Demonstrates: FilteringCollector
    Counts only the sales where the quantity is greater than 2.
    """
    return (cb.from_facts(Sale)
            .group_by(
                lambda s: s.product_id,
                Collectors.filtering(
                    lambda s: s.quantity > 2,
                    Collectors.count()
                )
            )
            .filter(lambda product_id, count: count > 0)
            .penalize_simple(lambda product_id, count: 0)  # Reporting only
            )

@cb.constraint("find_consecutive_shipments")
def consecutive_sequences_collector_example():
    """Demonstrates: consecutive_sequences
    Finds consecutive sequences of shipment numbers for each order.
    """
    return (cb.from_facts(Shipment)
            .group_by(
                lambda s: s.order_id,
                Collectors.consecutive_sequences(lambda s: s.shipment_no)
            )
            .filter(lambda order_id, sequences: any(seq.length > 1 for seq in sequences))
            .penalize_simple(lambda order_id, sequences: 0)  # Reporting only
            )

@cb.constraint("find_overlapping_maintenance")
def connected_ranges_collector_example():
    """Demonstrates: connected_ranges
    Finds groups of overlapping or adjacent maintenance windows for each machine.
    """
    return (cb.from_facts(Maintenance)
            .group_by(
                lambda m: m.machine_id,
                Collectors.connected_ranges(
                    start_func=lambda m: m.start_time,
                    end_func=lambda m: m.end_time
                )
            )
            .filter(lambda machine_id, ranges: any(len(r.data) > 1 for r in ranges))
            .penalize_simple(lambda machine_id, ranges: 0)  # Reporting only
            )

@cb.constraint("tumbling_window_events")
def tumbling_window_example():
    """Demonstrates: TumblingWindowCollector for aggregation
    Groups events into 10-second, non-overlapping ("tumbling") windows
    and calculates the average transaction value for each window.
    """
    # Define a key function to map timestamps to a 10-second window start time
    def get_window_key(timestamp: datetime) -> datetime:
        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
        window_size_sec = 10
        elapsed_sec = (timestamp - epoch).total_seconds()
        window_index = int(elapsed_sec // window_size_sec)
        window_start_ts = epoch.timestamp() + window_index * window_size_sec
        return datetime.fromtimestamp(window_start_ts, tz=timezone.utc)

    return (cb.from_facts(UserEvent)
            .group_by(
                group_key_function=lambda e: get_window_key(e.timestamp),
                collector_supplier=Collectors.avg(lambda e: e.value)
            )
            .filter(lambda window_start, avg_value: avg_value > 0)
            .penalize_simple(lambda window_start, avg_value: 0)  # Reporting only
            )

# 3. Main Execution Block
# =======================

def run_demonstration():
    """Builds the session, inserts data, and prints the results."""

    # --- Sample Data ---
    now = datetime.now(timezone.utc)
    sales_data = [
        Sale("s1", "prod-a", "cust-1", 10.0, 1, now),
        Sale("s2", "prod-b", "cust-1", 25.5, 2, now + timedelta(seconds=1)),
        Sale("s3", "prod-a", "cust-2", 12.0, 5, now + timedelta(seconds=2)),
        Sale("s4", "prod-a", "cust-1", 11.5, 2, now + timedelta(seconds=3)),
        Sale("s5", "prod-b", "cust-3", 24.0, 1, now + timedelta(seconds=4)),
        Sale("s6", "prod-a", "cust-3", 12.5, 3, now + timedelta(seconds=5)),
    ]

    shipments_data = [
        Shipment("order-1", "sh-101", 1),
        Shipment("order-1", "sh-102", 2),
        Shipment("order-2", "sh-201", 1),
        Shipment("order-1", "sh-104", 4),  # Gap in sequence
        Shipment("order-1", "sh-103", 3),
    ]

    maintenance_data = [
        Maintenance("m1", now, now + timedelta(hours=2)),
        Maintenance("m2", now, now + timedelta(hours=1)),
        Maintenance("m1", now + timedelta(hours=1), now + timedelta(hours=3)),  # Overlaps with the first
        Maintenance("m1", now + timedelta(hours=4), now + timedelta(hours=5)),  # Adjacent
    ]

    user_events_data = [
        UserEvent("u1", "tx", 100, now),
        UserEvent("u2", "tx", 150, now + timedelta(seconds=2)),
        UserEvent("u1", "tx", 50, now + timedelta(seconds=8)),
        UserEvent("u3", "tx", 200, now + timedelta(seconds=11)),  # New window
        UserEvent("u2", "tx", 300, now + timedelta(seconds=15)),
    ]

    # --- Build and Run Session ---
    session = cb.build()

    print("## [INITIAL STATE] Inserting all facts...")
    session.insert_batch(sales_data)
    session.insert_batch(shipments_data)
    session.insert_batch(maintenance_data)
    session.insert_batch(user_events_data)

    matches = session.get_constraint_matches()
    print_results(matches)

    # --- Demonstrate Retraction ---
    print("\n\n## [RETRACTION] Retracting one sale (s6) and one shipment (sh-103)...")
    sale_to_retract = sales_data[-1]  # Sale("s6", "prod-a", ...)
    shipment_to_retract = shipments_data[-1]  # Shipment("order-1", "sh-103", 3)

    session.retract(sale_to_retract)
    session.retract(shipment_to_retract)

    matches_after_retract = session.get_constraint_matches()
    print("## Results after retraction:")
    print_results(matches_after_retract)

def print_results(matches):
    """Helper function to print constraint matches in a structured way."""
    if not matches:
        print("  No constraint matches found.")
        return

    for constraint_id, match_list in matches.items():
        print(f"\n### Constraint: `{constraint_id}`")
        print("-" * (len(constraint_id) + 18))
        for score_obj, match_tuple in match_list:
            facts = [f for f in [
                getattr(match_tuple, 'fact_a', None),
                getattr(match_tuple, 'fact_b', None),
            ] if f is not None]

            print(f"  - Match: {facts}")
            print(f"    Score: {score_obj}")
        print("-" * (len(constraint_id) + 18))


if __name__ == "__main__":
    run_demonstration()
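For readers unfamiliar with the idea behind connected_ranges in find_overlapping_maintenance, the core operation is grouping time ranges that overlap into connected spans. The sketch below only illustrates that interval-merging idea on machine m1's windows (expressed as hour offsets from `now`); it is not Greynet's implementation, and whether ranges that merely touch at an endpoint count as connected is an assumption of the sketch.

# Illustrative sketch of merging overlapping ranges, similar in spirit to
# Collectors.connected_ranges (not Greynet's implementation).
def group_connected(ranges):
    """Merge (start, end) pairs that overlap or touch into consolidated spans."""
    groups = []
    for start, end in sorted(ranges):
        if groups and start <= groups[-1][1]:  # connects to the current group
            groups[-1] = (groups[-1][0], max(groups[-1][1], end))
        else:
            groups.append((start, end))
    return groups

# Machine m1's maintenance windows from the example, as hour offsets from `now`:
print(group_connected([(0, 2), (1, 3), (4, 5)]))   # -> [(0, 3), (4, 5)]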

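Similarly, the 10-second tumbling window in tumbling_window_example boils down to flooring each timestamp to the start of its bucket. Here is a standalone sketch of that bucketing arithmetic, with fixed timestamps chosen purely for illustration:

# Standalone check of the 10-second tumbling-window key used in
# tumbling_window_example(); timestamps are illustrative only.
from datetime import datetime, timezone

def get_window_key(timestamp: datetime, window_size_sec: int = 10) -> datetime:
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    elapsed_sec = (timestamp - epoch).total_seconds()
    window_index = int(elapsed_sec // window_size_sec)
    return datetime.fromtimestamp(window_index * window_size_sec, tz=timezone.utc)

for ts in (datetime(2024, 1, 1, 12, 0, 3, tzinfo=timezone.utc),
           datetime(2024, 1, 1, 12, 0, 9, tzinfo=timezone.utc),
           datetime(2024, 1, 1, 12, 0, 11, tzinfo=timezone.utc)):
    print(ts.time(), "->", get_window_key(ts).time())
# 12:00:03 and 12:00:09 share the 12:00:00 window; 12:00:11 falls into the 12:00:10 window.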