
Commit 64ea4d0

Refactor the syntax of decorator parameters for inputs, outputs and properties. (#123)
Using a dictionary with keys and values to replace the list approach.
1 parent f10596a commit 64ea4d0
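
The gist of the refactor, as a minimal before/after sketch (the function and socket names below are illustrative, not taken from this commit; the @task decorator import path is assumed): input, output and property specifications move from positional lists to explicit dictionaries.

from aiida_workgraph import task

# old list-based specs: position 0 is the identifier, position 1 the name
@task(
    inputs=[["General", "x"], ["General", "y"]],
    outputs=[["General", "sum"]],
)
def add(x, y):
    return {"sum": x + y}

# new dict-based specs introduced by this commit
@task(
    inputs=[
        {"identifier": "General", "name": "x"},
        {"identifier": "General", "name": "y"},
    ],
    outputs=[{"identifier": "General", "name": "sum"}],
)
def add(x, y):
    return {"sum": x + y}

The dict form makes each field self-describing and leaves room for optional keys such as "property" or "link_limit" without relying on positional slots.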

13 files changed: +218 -161 lines changed

aiida_workgraph/collection.py

Lines changed: 5 additions & 5 deletions
@@ -18,8 +18,8 @@ def new(
     ) -> Any:
         from aiida_workgraph.decorator import (
             build_task_from_callable,
-            build_python_task,
-            build_shell_task,
+            build_pythonjob_task,
+            build_shelljob_task,
         )

         # build the task on the fly if the identifier is a callable
@@ -31,13 +31,13 @@ def new(
                     "GraphBuilder task cannot be run remotely. Please set run_remotely=False."
                 )
             # this is a PythonJob
-            identifier, _ = build_python_task(identifier)
+            identifier, _ = build_pythonjob_task(identifier)
             return super().new(identifier, name, uuid, **kwargs)
         if isinstance(identifier, str) and identifier.upper() == "PythonJob":
-            identifier, _ = build_python_task(kwargs.pop("function"))
+            identifier, _ = build_pythonjob_task(kwargs.pop("function"))
             return super().new(identifier, name, uuid, **kwargs)
         if isinstance(identifier, str) and identifier.upper() == "SHELLJOB":
-            identifier, _, links = build_shell_task(
+            identifier, _, links = build_shelljob_task(
                 nodes=kwargs.get("nodes", {}),
                 outputs=kwargs.get("outputs", None),
                 parser_outputs=kwargs.pop("parser_outputs", None),
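
Only the builder helpers are renamed here (build_python_task to build_pythonjob_task, build_shell_task to build_shelljob_task); the user-facing entry point is unchanged. A hedged usage sketch of the dispatch paths above (the task name, function body, and shell command are illustrative):

from aiida_workgraph import WorkGraph, task

@task()
def double(x):
    return 2 * x

wg = WorkGraph("demo")
# dispatches to build_pythonjob_task internally
wg.tasks.new("PythonJob", function=double, name="double")
# dispatches to build_shelljob_task internally
wg.tasks.new("ShellJob", command="date", name="date")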

aiida_workgraph/decorator.py

Lines changed: 77 additions & 120 deletions
@@ -1,5 +1,5 @@
 from typing import Any, Callable, Dict, List, Optional, Union, Tuple
-from aiida_workgraph.utils import get_executor
+from aiida_workgraph.utils import get_executor, serialize_function
 from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
 from aiida import orm
 from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode
@@ -45,14 +45,18 @@ def add_input_recursive(
     else:
         port_name = f"{prefix}.{port.name}"
     required = port.required and required
-    input_names = [input[1] for input in inputs]
+    input_names = [input["name"] for input in inputs]
     if isinstance(port, PortNamespace):
         # TODO the default value is {} could cause problem, because the address of the dict is the same,
         # so if you change the value of one port, the value of all the ports of other tasks will be changed
         # consider to use None as default value
         if port_name not in input_names:
             inputs.append(
-                ["General", port_name, {"property": ["General", {"default": {}}]}]
+                {
+                    "identifier": "General",
+                    "name": port_name,
+                    "property": {"identifier": "General", "default": {}},
+                }
             )
         if required:
             args.append(port_name)
@@ -72,7 +76,7 @@ def add_input_recursive(
             socket_type = aiida_socket_maping.get(port.valid_type[0], "General")
         else:
             socket_type = aiida_socket_maping.get(port.valid_type, "General")
-        inputs.append([socket_type, port_name])
+        inputs.append({"identifier": socket_type, "name": port_name})
         if required:
             args.append(port_name)
     else:
@@ -92,18 +96,18 @@ def add_output_recursive(
     else:
         port_name = f"{prefix}.{port.name}"
     required = port.required and required
-    output_names = [output[1] for output in outputs]
+    output_names = [output["name"] for output in outputs]
     if isinstance(port, PortNamespace):
         # TODO the default value is {} could cause problem, because the address of the dict is the same,
         # so if you change the value of one port, the value of all the ports of other tasks will be changed
         # consider to use None as default value
         if port_name not in output_names:
-            outputs.append(["General", port_name])
+            outputs.append({"identifier": "General", "name": port_name})
         for value in port.values():
             add_output_recursive(outputs, value, prefix=port_name, required=required)
     else:
         if port_name not in output_names:
-            outputs.append(["General", port_name])
+            outputs.append({"identifier": "General", "name": port_name})
     return outputs

@@ -213,35 +217,45 @@ def build_task_from_AiiDA(
             or executor.process_class._var_positional
         )
         tdata["var_kwargs"] = name
-        inputs.append(["General", name, {"property": ["General", {"default": {}}]}])
+        inputs.append(
+            {
+                "identifier": "General",
+                "name": name,
+                "property": {"identifier": "General", "default": {}},
+            }
+        )
+    # TODO In order to reload the WorkGraph from process, "is_pickle" should be True
+    # so I pickled the function here, but this is not necessary
+    # we need to update the node_graph to support the path and name of the function
+    tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
+    tdata["executor"] = {
+        "executor": pickle.dumps(executor),
+        "type": tdata["task_type"],
+        "is_pickle": True,
+    }
     if tdata["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
-        outputs = [["General", "result"]] if not outputs else outputs
+        outputs = (
+            [{"identifier": "General", "name": "result"}] if not outputs else outputs
+        )
+        # get the source code of the function
+        tdata["executor"] = serialize_function(executor)
+        # tdata["executor"]["type"] = tdata["task_type"]
     # print("kwargs: ", kwargs)
     # add built-in sockets
-    outputs.append(["General", "_outputs"])
-    outputs.append(["General", "_wait"])
-    inputs.append(["General", "_wait", {"link_limit": 1e6}])
+    outputs.append({"identifier": "General", "name": "_outputs"})
+    outputs.append({"identifier": "General", "name": "_wait"})
+    inputs.append({"identifier": "General", "name": "_wait", "link_limit": 1e6})
     tdata["node_class"] = Task
     tdata["args"] = args
     tdata["kwargs"] = kwargs
     tdata["inputs"] = inputs
     tdata["outputs"] = outputs
-    tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
-    # TODO In order to reload the WorkGraph from process, "is_pickle" should be True
-    # so I pickled the function here, but this is not necessary
-    # we need to update the node_graph to support the path and name of the function
-    executor = {
-        "executor": pickle.dumps(executor),
-        "type": tdata["task_type"],
-        "is_pickle": True,
-    }
-    tdata["executor"] = executor
     task = create_task(tdata)
     task.is_aiida_component = True
     return task, tdata


-def build_python_task(func: Callable) -> Task:
+def build_pythonjob_task(func: Callable) -> Task:
     """Build PythonJob task from function."""
     from aiida_workgraph.calculations.python import PythonJob
     from copy import deepcopy
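
Net effect of this hunk: the pickled-executor payload is now assigned before the CALCFUNCTION/WORKFUNCTION branch, which then overwrites it with a source-based payload. A condensed sketch of the resulting control flow (names exactly as in the diff above):

tdata["executor"] = {
    "executor": pickle.dumps(executor),  # cloudpickle payload for AiiDA components
    "type": tdata["task_type"],
    "is_pickle": True,
}
if tdata["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
    # plain functions are serialized from source instead, via the
    # serialize_function helper now imported from aiida_workgraph.utils
    tdata["executor"] = serialize_function(executor)
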
@@ -254,10 +268,10 @@ def build_python_task(func: Callable) -> Task:
     inputs = tdata["inputs"]
     inputs.extend(
         [
-            ["String", "computer"],
-            ["String", "code_label"],
-            ["String", "code_path"],
-            ["String", "prepend_text"],
+            {"identifier": "String", "name": "computer"},
+            {"identifier": "String", "name": "code_label"},
+            {"identifier": "String", "name": "code_path"},
+            {"identifier": "String", "name": "prepend_text"},
         ]
     )
     outputs = tdata["outputs"]
@@ -269,8 +283,8 @@ def build_python_task(func: Callable) -> Task:
             outputs.append(output)
     # change "copy_files" link_limit to 1e6
     for input in inputs:
-        if input[1] == "copy_files":
-            input[2].update({"link_limit": 1e6})
+        if input["name"] == "copy_files":
+            input["link_limit"] = 1e6
     # append the kwargs of the PythonJob task to the function task
     kwargs = tdata["kwargs"]
     kwargs.extend(["computer", "code_label", "code_path", "prepend_text"])
@@ -284,7 +298,7 @@ def build_python_task(func: Callable) -> Task:
     return task, tdata


-def build_shell_task(
+def build_shelljob_task(
     nodes: dict = None, outputs: list = None, parser_outputs: list = None
 ) -> Task:
     """Build ShellJob with custom inputs and outputs."""
@@ -300,7 +314,7 @@ def build_shell_task(
     nodes = {} if nodes is None else nodes
     keys = list(nodes.keys())
     for key in keys:
-        inputs.append(["General", f"nodes.{key}"])
+        inputs.append({"identifier": "General", "name": f"nodes.{key}"})
         # input is a output of another task, we make a link
         if isinstance(nodes[key], NodeSocket):
             links[f"nodes.{key}"] = nodes[key]
@@ -309,12 +323,20 @@ def build_shell_task(
     for input in inputs:
         if input not in tdata["inputs"]:
             tdata["inputs"].append(input)
-            tdata["kwargs"].append(input[1])
+            tdata["kwargs"].append(input["name"])
     # Extend the outputs
-    tdata["outputs"].extend([["General", "stdout"], ["General", "stderr"]])
+    tdata["outputs"].extend(
+        [
+            {"identifier": "General", "name": "stdout"},
+            {"identifier": "General", "name": "stderr"},
+        ]
+    )
     outputs = [] if outputs is None else outputs
     parser_outputs = [] if parser_outputs is None else parser_outputs
-    outputs = [["General", ShellParser.format_link_label(output)] for output in outputs]
+    outputs = [
+        {"identifier": "General", "name": ShellParser.format_link_label(output)}
+        for output in outputs
+    ]
     outputs.extend(parser_outputs)
     # add user defined outputs
     for output in outputs:
@@ -324,8 +346,8 @@ def build_shell_task(
     tdata["identifier"] = "ShellJob"
     tdata["inputs"].extend(
         [
-            ["General", "command"],
-            ["General", "resolve_command"],
+            {"identifier": "General", "name": "command"},
+            {"identifier": "General", "name": "resolve_command"},
         ]
     )
     tdata["kwargs"].extend(["command", "resolve_command"])
@@ -346,25 +368,29 @@ def build_task_from_workgraph(wg: any) -> Task:
     # add all the inputs/outputs from the tasks in the workgraph
     for task in wg.tasks:
         # inputs
-        inputs.append(["General", f"{task.name}"])
+        inputs.append({"identifier": "General", "name": f"{task.name}"})
         for socket in task.inputs:
             if socket.name == "_wait":
                 continue
-            inputs.append(["General", f"{task.name}.{socket.name}"])
+            inputs.append(
+                {"identifier": "General", "name": f"{task.name}.{socket.name}"}
+            )
         # outputs
-        outputs.append(["General", f"{task.name}"])
+        outputs.append({"identifier": "General", "name": f"{task.name}"})
         for socket in task.outputs:
             if socket.name in ["_wait", "_outputs"]:
                 continue
-            outputs.append(["General", f"{task.name}.{socket.name}"])
+            outputs.append(
+                {"identifier": "General", "name": f"{task.name}.{socket.name}"}
+            )
             group_outputs.append(
                 [f"{task.name}.{socket.name}", f"{task.name}.{socket.name}"]
             )
-    kwargs = [input[1] for input in inputs]
+    kwargs = [input["name"] for input in inputs]
     # add built-in sockets
-    outputs.append(["General", "_outputs"])
-    outputs.append(["General", "_wait"])
-    inputs.append(["General", "_wait", {"link_limit": 1e6}])
+    outputs.append({"identifier": "General", "name": "_outputs"})
+    outputs.append({"identifier": "General", "name": "_wait"})
+    inputs.append({"identifier": "General", "name": "_wait", "link_limit": 1e6})
     tdata["node_class"] = Task
     tdata["kwargs"] = kwargs
     tdata["inputs"] = inputs
@@ -385,77 +411,6 @@ def build_task_from_workgraph(wg: any) -> Task:
     return task


-def get_required_imports(func):
-    """Retrieve type hints and the corresponding module"""
-    from typing import get_type_hints, _SpecialForm
-
-    type_hints = get_type_hints(func)
-    imports = {}
-
-    def add_imports(type_hint):
-        if isinstance(
-            type_hint, _SpecialForm
-        ):  # Handle special forms like Any, Union, Optional
-            module_name = "typing"
-            type_name = type_hint._name or str(type_hint)
-        elif hasattr(
-            type_hint, "__origin__"
-        ):  # This checks for higher-order types like List, Dict
-            module_name = type_hint.__module__
-            type_name = type_hint._name
-            for arg in type_hint.__args__:
-                if arg is type(None):  # noqa: E721
-                    continue
-                add_imports(arg)  # Recursively add imports for each argument
-        elif hasattr(type_hint, "__module__"):
-            module_name = type_hint.__module__
-            type_name = type_hint.__name__
-        else:
-            return  # If no module or origin, we can't import it, e.g., for literals
-
-        if module_name not in imports:
-            imports[module_name] = set()
-        imports[module_name].add(type_name)
-
-    for _, type_hint in type_hints.items():
-        add_imports(type_hint)
-
-    return imports
-
-
-def serialize_function(func: Callable) -> Dict[str, Any]:
-    """Serialize a function for storage or transmission."""
-    import cloudpickle as pickle
-    import inspect
-    import textwrap
-
-    source_code = inspect.getsource(func)
-    source_code_lines = source_code.split("\n")
-    # we need save the source code explicitly, because in the case of jupyter notebook,
-    # the source code is not saved in the pickle file
-    function_source_code = "\n".join(source_code_lines[1:])
-    function_source_code = textwrap.dedent(function_source_code)
-    # we also need to include the necessary imports for the types used in the type hints.
-    try:
-        required_imports = get_required_imports(func)
-    except Exception as e:
-        required_imports = {}
-        print(f"Failed to get required imports for function {func.__name__}: {e}")
-    # Generate import statements
-    import_statements = "\n".join(
-        f"from {module} import {', '.join(types)}"
-        for module, types in required_imports.items()
-    )
-    return {
-        "executor": pickle.dumps(func),
-        "type": "function",
-        "is_pickle": True,
-        "function_name": func.__name__,
-        "function_source_code": function_source_code,
-        "import_statements": import_statements,
-    }
-
-
 def generate_tdata(
     func: Callable,
     identifier: str,
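
Both helpers deleted above now live in aiida_workgraph.utils (see the import hunk at the top of this file). For reference, a sketch of the payload the removed serialize_function produced; note that engine/utils.py below now reads a function_source_code_without_decorator key, so the relocated version presumably extends this dict:

payload = serialize_function(my_func)  # my_func: any plain Python function
# {
#     "executor": <cloudpickle bytes>,
#     "type": "function",
#     "is_pickle": True,
#     "function_name": "my_func",
#     "function_source_code": "<source, first line stripped, dedented>",
#     "import_statements": "from typing import ...",
# }
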
@@ -475,9 +430,9 @@ def generate_tdata(
     )
     task_outputs = outputs
     # add built-in sockets
-    _inputs.append(["General", "_wait", {"link_limit": 1e6}])
-    task_outputs.append(["General", "_wait"])
-    task_outputs.append(["General", "_outputs"])
+    _inputs.append({"identifier": "General", "name": "_wait", "link_limit": 1e6})
+    task_outputs.append({"identifier": "General", "name": "_wait"})
+    task_outputs.append({"identifier": "General", "name": "_outputs"})
     tdata = {
         "node_class": Task,
         "identifier": identifier,
@@ -536,7 +491,7 @@ def decorator(func):
             func,
             identifier,
             inputs or [],
-            outputs or [["General", "result"]],
+            outputs or [{"identifier": "General", "name": "result"}],
             properties or [],
             catalog,
             task_type,
@@ -577,7 +532,9 @@ def decorator(func):
         # use cloudpickle to serialize function
         func.identifier = identifier

-        task_outputs = [["General", output[1]] for output in outputs]
+        task_outputs = [
+            {"identifier": "General", "name": output[1]} for output in outputs
+        ]
         # print(task_inputs, task_outputs)
         #
         task_type = "graph_builder"

aiida_workgraph/engine/utils.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
     function_source_code = (
         task["executor"]["import_statements"]
         + "\n"
-        + task["executor"]["function_source_code"]
+        + task["executor"]["function_source_code_without_decorator"]
     )
     # outputs
     output_name_list = [output["name"] for output in task["outputs"]]
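
The engine now expects the serialized executor to carry the function source with the decorator line already stripped. A hypothetical helper illustrating how that field could be derived, mirroring the first-line-stripping logic of the serialize_function removed from decorator.py above (the real key is produced by the relocated utils version):

import inspect
import textwrap

def source_without_decorator(func):
    # drop the first line (a single-line @task(...) decorator) and dedent
    lines = inspect.getsource(func).split("\n")
    return textwrap.dedent("\n".join(lines[1:]))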
