Skip to content

Commit 8cf7b81

Browse files
authored
Decorator allow users to define the inputs manually for dynamic input (#281)
1 parent 1e0c64f commit 8cf7b81

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

aiida_workgraph/decorator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def build_task_from_AiiDA(
207207
spec = executor.spec()
208208
args = []
209209
kwargs = []
210+
user_defined_input_names = [input["name"] for input in inputs]
210211
for _key, port in spec.inputs.ports.items():
211212
add_input_recursive(inputs, port, args, kwargs, required=port.required)
212213
for _key, port in spec.outputs.ports.items():
@@ -230,6 +231,14 @@ def build_task_from_AiiDA(
230231
"property": {"identifier": "workgraph.any", "default": {}},
231232
}
232233
)
234+
# When the input is dyanmic, if user defines some input names does not included in the args and kwargs,
235+
# which means the user define the input names manually, we must add them to the kwargs
236+
for key in user_defined_input_names:
237+
if key not in args and key not in kwargs:
238+
if key == name:
239+
continue
240+
kwargs.append(key)
241+
233242
# TODO In order to reload the WorkGraph from process, "is_pickle" should be True
234243
# so I pickled the function here, but this is not necessary
235244
# we need to update the node_graph to support the path and name of the function

tests/test_calcfunction.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
2-
from aiida_workgraph import WorkGraph
2+
from aiida_workgraph import WorkGraph, task
3+
from aiida import orm
34

45

56
def test_run(wg_calcfunction: WorkGraph) -> None:
@@ -14,10 +15,16 @@ def test_run(wg_calcfunction: WorkGraph) -> None:
1415

1516

1617
@pytest.mark.usefixtures("started_daemon_client")
17-
def test_submit(wg_calcfunction: WorkGraph) -> None:
18-
"""Submit simple calcfunction."""
19-
wg = wg_calcfunction
20-
wg.name = "test_submit_calcfunction"
21-
wg.submit(wait=True)
22-
# print("results: ", results[])
23-
assert wg.tasks["sumdiff2"].outputs["sum"].value == 9
18+
def test_dynamic_inputs() -> None:
19+
"""Test dynamic inputs.
20+
For dynamic inputs, we allow the user to define the inputs manually.
21+
"""
22+
23+
@task.calcfunction(inputs=[{"name": "x"}, {"name": "y"}])
24+
def add(**kwargs):
25+
return kwargs["x"] + kwargs["y"]
26+
27+
wg = WorkGraph("test_dynamic_inputs")
28+
wg.add_task(add, name="add1", x=orm.Int(1), y=orm.Int(2))
29+
wg.run()
30+
assert wg.tasks["add1"].outputs["result"].value == 3

0 commit comments

Comments
 (0)