Skip to content

Commit feced0c

Browse files
committed
PythonJob: automatically serialize the inputs into AiiDA data (#85)
First search the AiiDA data in the entry point based on the module name and class name, for example, `ase.atoms.Atoms`, if find a entry point, use it to serialize the value, if not found, use `GeneralData` to seralize the value. Add more data entry point: int, float, str, bool, list and dict
1 parent d00e989 commit feced0c

File tree

11 files changed

+483
-259
lines changed

11 files changed

+483
-259
lines changed

aiida_workgraph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
from .decorator import node, build_node
44

55

6-
__version__ = "0.2.5"
6+
__version__ = "0.2.6"
77

88
__all__ = ["WorkGraph", "Node", "node", "build_node"]

aiida_workgraph/calculations/python.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
)
1919

2020

21-
from .general_data import GeneralData
22-
2321
__all__ = ("PythonJob",)
2422

2523

@@ -55,7 +53,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
5553
spec.input(
5654
"function_name", valid_type=Str, serializer=to_aiida_type, required=False
5755
)
58-
spec.input_namespace("kwargs", valid_type=Data, required=False)
56+
spec.input_namespace(
57+
"kwargs", valid_type=Data, required=False
58+
) # , serializer=general_serializer)
5959
spec.input(
6060
"output_name_list",
6161
valid_type=List,
@@ -188,12 +188,14 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
188188
# create pickle file for the inputs
189189
input_values = {}
190190
for key, value in inputs.items():
191-
if isinstance(value, GeneralData):
191+
if isinstance(value, Data) and hasattr(value, "value"):
192192
# get the value of the pickled data
193193
input_values[key] = value.value
194194
else:
195-
raise ValueError(f"Unsupported data type: {type(value)}")
196-
# save the value as a pickle file, the path is absolute
195+
raise ValueError(
196+
f"Input data {value} is not supported. Only AiiDA data Node with a value attribute is allowed. "
197+
)
198+
# save the value as a pickle file, the path is absolute
197199
filename = "inputs.pickle"
198200
with folder.open(filename, "wb") as handle:
199201
pickle.dump(input_values, handle)

aiida_workgraph/calculations/python_parser.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Parser for an `PythonJob` job."""
22
from aiida.parsers.parser import Parser
3-
from .general_data import GeneralData
3+
from aiida_workgraph.orm import general_serializer
44

55

66
class PythonParser(Parser):
@@ -14,18 +14,36 @@ def parse(self, **kwargs):
1414
with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
1515
results = pickle.load(handle)
1616
output_name_list = self.node.inputs.output_name_list.get_list()
17+
# output_name_list exclude ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved']
18+
output_name_list = [
19+
name
20+
for name in output_name_list
21+
if name
22+
not in [
23+
"_wait",
24+
"_outputs",
25+
"remote_folder",
26+
"remote_stash",
27+
"retrieved",
28+
]
29+
]
30+
outputs = {}
1731
if isinstance(results, tuple):
1832
if len(output_name_list) != len(results):
1933
raise ValueError(
2034
"The number of results does not match the number of output_name_list."
2135
)
2236
for i in range(len(output_name_list)):
23-
self.out(output_name_list[i].name, GeneralData(results[i]))
24-
elif isinstance(results, dict):
25-
for key, value in results.items():
26-
self.out(key, GeneralData(value))
37+
outputs[output_name_list[i].name] = results[i]
38+
outputs = general_serializer(outputs)
39+
elif isinstance(results, dict) and len(results) == len(
40+
output_name_list
41+
):
42+
outputs = general_serializer(results)
2743
else:
28-
self.out("result", GeneralData(results))
44+
outputs = general_serializer({"result": results})
45+
for key, value in outputs.items():
46+
self.out(key, value)
2947
except OSError:
3048
return self.exit_codes.ERROR_READING_OUTPUT_FILE
3149
except ValueError:

aiida_workgraph/engine/workgraph.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
825825
self.to_context(**{name: process})
826826
elif node["metadata"]["node_type"].upper() in ["PYTHONJOB"]:
827827
from aiida_workgraph.calculations.python import PythonJob
828-
from aiida_workgraph.calculations.general_data import GeneralData
828+
from aiida_workgraph.orm.serializer import general_serializer
829829
from aiida_workgraph.utils import get_or_create_code
830830

831831
print("node type: Python.")
@@ -849,20 +849,10 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
849849
# get the source code of the function
850850
function_name = executor.__name__
851851
function_source_code = node["executor"]["function_source_code"]
852-
inputs = {}
853-
# save all kwargs to inputs port
854-
for key, value in kwargs.items():
855-
if isinstance(value, orm.Node):
856-
if not hasattr(value, "value"):
857-
raise ValueError(
858-
"Only AiiDA data Node with a value attribute is allowed."
859-
)
860-
inputs[key] = value
861-
else:
862-
inputs[key] = GeneralData(value)
863852
# outputs
864853
output_name_list = [output["name"] for output in node["outputs"]]
865-
854+
# serialize the kwargs into AiiDA Data
855+
inputs = general_serializer(kwargs)
866856
# transfer the args to kwargs
867857
process = self.submit(
868858
PythonJob,

aiida_workgraph/orm/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .general_data import GeneralData
2+
from .serializer import general_serializer
3+
4+
__all__ = (
5+
"GeneralData",
6+
"general_serializer",
7+
)

aiida_workgraph/orm/atoms.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# -*- coding: utf-8 -*-
2+
###########################################################################
3+
# Copyright (c), The AiiDA team. All rights reserved. #
4+
# This file is part of the AiiDA code. #
5+
# #
6+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
7+
# For further information on the license, see the LICENSE.txt file #
8+
# For further information please visit http://www.aiida.net #
9+
###########################################################################
10+
"""`Data` sub class to represent a list."""
11+
12+
from aiida.orm import Data
13+
from ase import Atoms
14+
15+
__all__ = ("AtomsData",)
16+
17+
18+
class AtomsData(Data):
19+
"""`Data to represent a ASE Atoms."""
20+
21+
_cached_atoms = None
22+
23+
def __init__(self, value=None, **kwargs):
24+
"""Initialise a ``List`` node instance.
25+
26+
:param value: list to initialise the ``List`` node from
27+
"""
28+
data = value or kwargs.pop("atoms", Atoms())
29+
super().__init__(**kwargs)
30+
self.set_atoms(data)
31+
32+
@property
33+
def value(self):
34+
return self.get_atoms()
35+
36+
def initialize(self):
37+
super().initialize()
38+
self._cached_atoms = None
39+
40+
def __getitem__(self, item):
41+
return self.get_atoms()[item]
42+
43+
def __setitem__(self, key, value):
44+
data = self.get_atoms()
45+
data[key] = value
46+
if not self._using_atoms_reference():
47+
self.set_atoms(data)
48+
49+
def __delitem__(self, key):
50+
data = self.get_atoms()
51+
del data[key]
52+
if not self._using_atoms_reference():
53+
self.set_atoms(data)
54+
55+
def __len__(self):
56+
return len(self.get_atoms())
57+
58+
def __str__(self):
59+
return f"{super().__str__()} : {self.get_atoms()}"
60+
61+
def __eq__(self, other):
62+
if isinstance(other, Atoms):
63+
return self.get_atoms() == other.get_atoms()
64+
return self.get_atoms() == other
65+
66+
def append(self, value):
67+
data = self.get_atoms()
68+
data.append(value)
69+
if not self._using_atoms_reference():
70+
self.set_atoms(data)
71+
72+
def extend(self, value): # pylint: disable=arguments-renamed
73+
data = self.get_atoms()
74+
data.extend(value)
75+
if not self._using_atoms_reference():
76+
self.set_atoms(data)
77+
78+
def get_atoms(self):
79+
"""Return the contents of this node.
80+
81+
:return: a Atoms
82+
"""
83+
import pickle
84+
85+
def get_atoms_from_file(self):
86+
filename = "atoms.pkl"
87+
# Open a handle in binary read mode as the arrays are written as binary files as well
88+
with self.base.repository.open(filename, mode="rb") as f:
89+
return pickle.loads(f.read()) # pylint: disable=unexpected-keyword-arg
90+
91+
# Return with proper caching if the node is stored, otherwise always re-read from disk
92+
if not self.is_stored:
93+
return get_atoms_from_file(self)
94+
95+
if self._cached_atoms is None:
96+
self._cached_atoms = get_atoms_from_file(self)
97+
98+
return self._cached_atoms
99+
100+
def set_atoms(self, atoms):
101+
"""Set the contents of this node.
102+
103+
:param atoms: the atoms to set
104+
"""
105+
import pickle
106+
107+
if not isinstance(atoms, Atoms):
108+
raise TypeError("Must supply Atoms type")
109+
self.base.repository.put_object_from_bytes(pickle.dumps(atoms), "atoms.pkl")
110+
formula = atoms.get_chemical_formula()
111+
# Store the array name and shape for querying purposes
112+
self.base.attributes.set("formula", formula)
113+
114+
def _using_atoms_reference(self):
115+
"""
116+
This function tells the class if we are using a list reference. This
117+
means that calls to self.get_atoms return a reference rather than a copy
118+
of the underlying list and therefore self.set_atoms need not be called.
119+
This knwoledge is essential to make sure this class is performant.
120+
121+
Currently the implementation assumes that if the node needs to be
122+
stored then it is using the attributes cache which is a reference.
123+
124+
:return: True if using self.get_atoms returns a reference to the
125+
underlying sequence. False otherwise.
126+
:rtype: bool
127+
"""
128+
return self.is_stored

aiida_workgraph/calculations/general_data.py renamed to aiida_workgraph/orm/general_data.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
"""`Data` sub class to represent any data using pickle."""
22

3-
from aiida.orm import Data
3+
from aiida import orm
44

5-
__all__ = ("GeneralData",)
5+
6+
class Dict(orm.Dict):
7+
@property
8+
def value(self):
9+
return self.get_dict()
10+
11+
12+
class List(orm.List):
13+
@property
14+
def value(self):
15+
return self.get_list()
616

717

8-
class GeneralData(Data):
18+
class GeneralData(orm.Data):
919
"""`Data to represent a pickled value."""
1020

1121
def __init__(self, value=None, **kwargs):

aiida_workgraph/orm/serializer.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from .general_data import GeneralData
2+
from aiida import orm
3+
from importlib.metadata import entry_points
4+
5+
6+
# Retrieve the entry points for 'aiida.data' and store them in a dictionary
7+
eps = {ep.name: ep for ep in entry_points().get("aiida.data", [])}
8+
9+
10+
def general_serializer(inputs):
11+
"""Serialize the inputs to a dictionary of AiiDA data nodes.
12+
13+
Args:
14+
inputs (dict): The inputs to be serialized.
15+
16+
Returns:
17+
dict: The serialized inputs.
18+
"""
19+
new_inputs = {}
20+
# save all kwargs to inputs port
21+
for key, value in inputs.items():
22+
if isinstance(value, orm.Data):
23+
if not hasattr(value, "value"):
24+
raise ValueError(
25+
"Only AiiDA data Node with a value attribute is allowed."
26+
)
27+
new_inputs[key] = value
28+
# if value is a class instance, get its __module__ and class name as a string
29+
# for example, an Atoms will have ase.atoms.Atoms
30+
else:
31+
# try to get the serializer from the entry points
32+
value_type = type(value)
33+
ep_key = f"{value_type.__module__}.{value_type.__name__}"
34+
# search for the key in the entry points
35+
if ep_key in eps:
36+
try:
37+
new_inputs[key] = eps[ep_key].load()(value)
38+
except Exception as e:
39+
raise ValueError(f"Error in serializing {key}: {e}")
40+
else:
41+
# try to serialize the value as a GeneralData
42+
try:
43+
new_inputs[key] = GeneralData(value)
44+
except Exception as e:
45+
raise ValueError(f"Error in serializing {key}: {e}")
46+
47+
return new_inputs

0 commit comments

Comments
 (0)