Skip to content

Commit b4563ef

Browse files
Switch take to get
1 parent d50fb77 commit b4563ef

File tree

4 files changed

+24
-18
lines changed

4 files changed

+24
-18
lines changed

python/egg_smol/bindings.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ class EGraph:
1111
def __init__(self, fact_directory: str | Path | None = None, seminaive=True) -> None: ...
1212
def parse_program(self, __input: str, /) -> list[_Command]: ...
1313
def run_program(self, *commands: _Command) -> list[str]: ...
14-
def take_extract_report(self) -> Optional[ExtractReport]: ...
15-
def take_run_report(self) -> Optional[RunReport]: ...
14+
def extract_report(self) -> Optional[ExtractReport]: ...
15+
def run_report(self) -> Optional[RunReport]: ...
1616

1717
@final
1818
class EggSmolError(Exception):

python/egg_smol/egraph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _simplify(self, expr: EXPR, limit: int, ruleset: Optional[Ruleset], until: t
126126
tp, decl = expr_parts(expr)
127127
egg_expr = decl.to_egg(self._decls)
128128
self._run_program([bindings.Simplify(egg_expr, Config(limit, ruleset, until)._to_egg_config(self._decls))])
129-
extract_report = self._get_egraph().take_extract_report()
129+
extract_report = self._get_egraph().extract_report()
130130
if not extract_report:
131131
raise ValueError("No extract report saved")
132132
new_tp, new_decl = tp_and_expr_decl_from_egg(self._decls, extract_report.expr)
@@ -181,7 +181,7 @@ def run(self, limit_or_schedule: int | Schedule, /, *until: Fact) -> bindings.Ru
181181

182182
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
183183
self._run_program([bindings.RunScheduleCommand(schedule._to_egg(self._decls))])
184-
run_report = self._get_egraph().take_run_report()
184+
run_report = self._get_egraph().run_report()
185185
if not run_report:
186186
raise ValueError("No run report saved")
187187
return run_report
@@ -226,7 +226,7 @@ def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]:
226226

227227
def _run_extract(self, expr: bindings._Expr, n: int) -> bindings.ExtractReport:
228228
self._run_program([bindings.Extract(n, expr)])
229-
extract_report = self._get_egraph().take_extract_report()
229+
extract_report = self._get_egraph().extract_report()
230230
if not extract_report:
231231
raise ValueError("No extract report saved")
232232
return extract_report

python/tests/test_bindings.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_run_rules(self):
111111
)
112112
end_time = datetime.datetime.now()
113113

114-
run_report = egraph.take_run_report()
114+
run_report = egraph.run_report()
115115
assert isinstance(run_report, RunReport)
116116
total_time = run_report.search_time + run_report.apply_time + run_report.rebuild_time
117117
# Verify less than the total time (which includes time spent in Python).
@@ -126,15 +126,15 @@ def test_extract(self):
126126
Define("y", Call("Num", [Lit(Int(2))]), 1),
127127
Extract(0, Var("x")),
128128
)
129-
assert egraph.take_extract_report() == ExtractReport(6, Call("Num", [Lit(Int(1))]), [])
129+
assert egraph.extract_report() == ExtractReport(6, Call("Num", [Lit(Int(1))]), [])
130130
egraph.run_program(Extract(0, Var("y")))
131131
pytest.xfail(reason="https://github.com/mwillsey/egg-smol/issues/128")
132-
assert egraph.take_extract_report() == ExtractReport(1, Call("y", []), [])
132+
assert egraph.extract_report() == ExtractReport(1, Call("y", []), [])
133133

134134
def test_extract_string(self):
135135
egraph = EGraph()
136136
egraph.run_program(Define("x", Lit(String("hello")), None), Extract(0, Var("x")))
137-
assert egraph.take_extract_report() == ExtractReport(0, Lit(String("hello")), [])
137+
assert egraph.extract_report() == ExtractReport(0, Lit(String("hello")), [])
138138

139139
def test_sort_alias(self):
140140
# From map example
@@ -149,7 +149,7 @@ def test_sort_alias(self):
149149
Check([Eq([Lit(String("one")), Call("get", [Var("my_map1"), Lit(Int(1))])])]),
150150
Extract(0, Var("my_map2")),
151151
)
152-
assert egraph.take_extract_report() == ExtractReport(
152+
assert egraph.extract_report() == ExtractReport(
153153
0,
154154
Call(
155155
"insert",

src/egraph.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,25 @@ impl EGraph {
5050
Ok(res)
5151
}
5252

53-
/// Takes the last expressions extracted from the EGraph, if the last command
53+
/// Gets the last expressions extracted from the EGraph, if the last command
5454
/// was a Simplify or Extract command.
5555
#[pyo3(signature = ())]
56-
fn take_extract_report(&mut self) -> Option<ExtractReport> {
57-
info!("Taking last extract report");
58-
self.egraph.take_extract_report().map(|r| r.into())
56+
fn extract_report(&mut self) -> Option<ExtractReport> {
57+
info!("Getting last extract report");
58+
match self.egraph.get_extract_report() {
59+
Some(report) => Some(report.into()),
60+
None => None,
61+
}
5962
}
6063

61-
/// Takes the last run report from the EGraph, if the last command
64+
/// Gets the last run report from the EGraph, if the last command
6265
/// was a run or simplify command.
6366
#[pyo3(signature = ())]
64-
fn take_run_report(&mut self) -> Option<RunReport> {
65-
info!("Taking last run report");
66-
self.egraph.take_run_report().map(|r| r.into())
67+
fn run_report(&mut self) -> Option<RunReport> {
68+
info!("Getting last run report");
69+
match self.egraph.get_run_report() {
70+
Some(report) => Some(report.into()),
71+
None => None,
72+
}
6773
}
6874
}

0 commit comments

Comments
 (0)