Skip to content

Commit 52444e5

Browse files
authored
Merge pull request #130 from morxa/fix-function-variable-types
Use correctly typed variables when parsing functions
2 parents 61ce033 + c42ecb8 commit 52444e5

File tree

3 files changed

+71
-15
lines changed

3 files changed

+71
-15
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
pyver=`echo ${{ matrix.python-version }} | tr -d "."`
3030
tox -e py${{ matrix.python-version }}
3131
- name: Upload coverage to Codecov
32-
uses: codecov/codecov-action@v1
32+
uses: codecov/codecov-action@v5
3333
with:
3434
token: ${{ secrets.CODECOV_TOKEN }}
3535
file: ./coverage.xml

pddl/parser/domain.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -306,27 +306,24 @@ def num_effect(self, args):
306306
else:
307307
raise PDDLParsingError(f"Unrecognized assign operator: {args[1]}")
308308

309+
def _constant_or_variable(self, t):
310+
"""Get the constant or variable with the given name."""
311+
# Case where the term is a free variable (bug) or comes from a parent quantifier
312+
if not isinstance(t, Constant) and t not in self._current_parameters_by_name:
313+
return Variable(str(t), {})
314+
return t if isinstance(t, Constant) else self._current_parameters_by_name[t]
315+
309316
def atomic_formula_term(self, args):
310317
"""Process the 'atomic_formula_term' rule."""
311-
312-
def constant_or_variable(t):
313-
# Case where the term is a free variable (bug) or comes from a parent quantifier
314-
if (
315-
not isinstance(t, Constant)
316-
and t not in self._current_parameters_by_name
317-
):
318-
return Variable(str(t), {})
319-
return t if isinstance(t, Constant) else self._current_parameters_by_name[t]
320-
321318
if args[1] == Symbols.EQUAL.value:
322319
if not bool({Requirements.EQUALITY} & self._extended_requirements):
323320
raise PDDLMissingRequirementError(Requirements.EQUALITY)
324-
left = constant_or_variable(args[2])
325-
right = constant_or_variable(args[3])
321+
left = self._constant_or_variable(args[2])
322+
right = self._constant_or_variable(args[3])
326323
return EqualTo(left, right)
327324
else:
328325
predicate_name = args[1]
329-
terms = list(map(constant_or_variable, args[2:-1]))
326+
terms = list(map(self._constant_or_variable, args[2:-1]))
330327
return Predicate(predicate_name, *terms)
331328

332329
def constant(self, args):
@@ -387,7 +384,7 @@ def f_head(self, args):
387384
if len(args) == 1:
388385
return NumericFunction(args[0])
389386
function_name = args[1]
390-
variables = [Variable(x, {}) for x in args[2:-1]]
387+
variables = list(map(self._constant_or_variable, args[2:-1]))
391388
return NumericFunction(function_name, *variables)
392389

393390
def typed_list_name(self, args) -> Dict[name, Optional[name]]:

tests/test_parser/test_domain.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import lark
1717
import pytest
1818

19+
from pddl.logic.functions import BinaryFunction, Increase, NumericFunction
20+
from pddl.logic.predicates import Predicate
21+
from pddl.logic.terms import Variable
1922
from pddl.parser.domain import DomainParser
2023
from pddl.parser.symbols import Symbols
2124
from tests.conftest import TEXT_SYMBOLS
@@ -317,3 +320,59 @@ def test_check_action_costs_requirement_with_total_cost() -> None:
317320
match=r"action costs requirement is not specified, but the total-cost function is specified.",
318321
):
319322
DomainParser()(domain_str)
323+
324+
325+
def test_variable_types_in_strips_action_definition() -> None:
326+
"""Check typing for predicate variables in action preconditions and effects."""
327+
domain_str = dedent(
328+
"""
329+
(define (domain test)
330+
(:requirements :typing)
331+
(:types t1 t2)
332+
(:predicates (p ?x - t1 ?y - t2))
333+
(:action a
334+
:parameters (?x - t1 ?y - t2)
335+
:precondition (p ?x ?y)
336+
:effect (p ?x ?y)
337+
)
338+
)
339+
"""
340+
)
341+
domain = DomainParser()(domain_str)
342+
action = next(iter(domain.actions))
343+
x = Variable("x", {"t1"})
344+
y = Variable("y", {"t2"})
345+
assert action.parameters == (x, y)
346+
assert isinstance(action.precondition, Predicate)
347+
assert action.precondition.terms == (x, y)
348+
assert isinstance(action.effect, Predicate)
349+
assert action.effect.terms == (x, y)
350+
351+
352+
def test_variable_types_in_numeric_action_definition() -> None:
353+
"""Check typing for function variables in action preconditions and effects."""
354+
domain_str = dedent(
355+
"""
356+
(define (domain test)
357+
(:requirements :typing :numeric-fluents)
358+
(:types t1 t2)
359+
(:functions (f ?x - t1 ?y - t2))
360+
(:action a
361+
:parameters (?x - t1 ?y - t2)
362+
:precondition (<= 1 (f ?x ?y))
363+
:effect (increase (f ?x ?y) 1)
364+
)
365+
)
366+
"""
367+
)
368+
domain = DomainParser()(domain_str)
369+
action = next(iter(domain.actions))
370+
x = Variable("x", {"t1"})
371+
y = Variable("y", {"t2"})
372+
assert action.parameters == (x, y)
373+
assert isinstance(action.precondition, BinaryFunction)
374+
assert isinstance(action.precondition.operands[1], NumericFunction)
375+
assert action.precondition.operands[1].terms == (x, y)
376+
assert isinstance(action.effect, Increase)
377+
assert isinstance(action.effect.operands[0], NumericFunction)
378+
assert action.effect.operands[0].terms == (x, y)

0 commit comments

Comments
 (0)