from typing import Any, Callable, Dict, List, Optional, Union, Tuple
-from aiida_workgraph.utils import get_executor
+from aiida_workgraph.utils import get_executor, serialize_function
from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
from aiida import orm
from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode
@@ -45,14 +45,18 @@ def add_input_recursive(
    else:
        port_name = f"{prefix}.{port.name}"
    required = port.required and required
-    input_names = [input[1] for input in inputs]
+    input_names = [input["name"] for input in inputs]
    if isinstance(port, PortNamespace):
        # TODO the default value is {} could cause problem, because the address of the dict is the same,
        # so if you change the value of one port, the value of all the ports of other tasks will be changed
        # consider to use None as default value
        if port_name not in input_names:
            inputs.append(
-                ["General", port_name, {"property": ["General", {"default": {}}]}]
+                {
+                    "identifier": "General",
+                    "name": port_name,
+                    "property": {"identifier": "General", "default": {}},
+                }
            )
        if required:
            args.append(port_name)
@@ -72,7 +76,7 @@ def add_input_recursive(
        socket_type = aiida_socket_maping.get(port.valid_type[0], "General")
    else:
        socket_type = aiida_socket_maping.get(port.valid_type, "General")
-    inputs.append([socket_type, port_name])
+    inputs.append({"identifier": socket_type, "name": port_name})
    if required:
        args.append(port_name)
    else:
@@ -92,18 +96,18 @@ def add_output_recursive(
    else:
        port_name = f"{prefix}.{port.name}"
    required = port.required and required
-    output_names = [output[1] for output in outputs]
+    output_names = [output["name"] for output in outputs]
    if isinstance(port, PortNamespace):
        # TODO the default value is {} could cause problem, because the address of the dict is the same,
        # so if you change the value of one port, the value of all the ports of other tasks will be changed
        # consider to use None as default value
        if port_name not in output_names:
-            outputs.append(["General", port_name])
+            outputs.append({"identifier": "General", "name": port_name})
        for value in port.values():
            add_output_recursive(outputs, value, prefix=port_name, required=required)
    else:
        if port_name not in output_names:
-            outputs.append(["General", port_name])
+            outputs.append({"identifier": "General", "name": port_name})
    return outputs

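Note: the recurring change in these hunks replaces the positional, list-based socket spec with a self-describing dict. A minimal sketch of the two shapes (the socket name `add1.x` is invented for illustration):

```python
# Old list-based spec: index 0 = identifier, 1 = name, 2 = optional metadata.
old_spec = ["General", "add1.x", {"link_limit": 1e6}]

# New dict-based spec: the same information under explicit keys, with
# metadata such as "link_limit" promoted to a top-level key.
new_spec = {"identifier": "General", "name": "add1.x", "link_limit": 1e6}

# Lookups therefore switch from index-based to key-based:
assert [s["name"] for s in [new_spec]] == ["add1.x"]
```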
@@ -213,35 +217,45 @@ def build_task_from_AiiDA(
        or executor.process_class._var_positional
    )
    tdata["var_kwargs"] = name
-    inputs.append(["General", name, {"property": ["General", {"default": {}}]}])
+    inputs.append(
+        {
+            "identifier": "General",
+            "name": name,
+            "property": {"identifier": "General", "default": {}},
+        }
+    )
+    # TODO In order to reload the WorkGraph from process, "is_pickle" should be True
+    # so I pickled the function here, but this is not necessary
+    # we need to update the node_graph to support the path and name of the function
+    tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
+    tdata["executor"] = {
+        "executor": pickle.dumps(executor),
+        "type": tdata["task_type"],
+        "is_pickle": True,
+    }
    if tdata["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
-        outputs = [["General", "result"]] if not outputs else outputs
+        outputs = (
+            [{"identifier": "General", "name": "result"}] if not outputs else outputs
+        )
+        # get the source code of the function
+        tdata["executor"] = serialize_function(executor)
+        # tdata["executor"]["type"] = tdata["task_type"]
    # print("kwargs: ", kwargs)
    # add built-in sockets
-    outputs.append(["General", "_outputs"])
-    outputs.append(["General", "_wait"])
-    inputs.append(["General", "_wait", {"link_limit": 1e6}])
+    outputs.append({"identifier": "General", "name": "_outputs"})
+    outputs.append({"identifier": "General", "name": "_wait"})
+    inputs.append({"identifier": "General", "name": "_wait", "link_limit": 1e6})
    tdata["node_class"] = Task
    tdata["args"] = args
    tdata["kwargs"] = kwargs
    tdata["inputs"] = inputs
    tdata["outputs"] = outputs
-    tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
-    # TODO In order to reload the WorkGraph from process, "is_pickle" should be True
-    # so I pickled the function here, but this is not necessary
-    # we need to update the node_graph to support the path and name of the function
-    executor = {
-        "executor": pickle.dumps(executor),
-        "type": tdata["task_type"],
-        "is_pickle": True,
-    }
-    tdata["executor"] = executor
    task = create_task(tdata)
    task.is_aiida_component = True
    return task, tdata

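Note: the reordered block above now stores a cloudpickle payload for every executor up front, and calcfunctions/workfunctions then overwrite it via `serialize_function`. A rough round-trip sketch of the pickled shape, with `add` standing in for the wrapped executor:

```python
import cloudpickle

def add(x, y):  # stand-in for the executor being wrapped (assumed)
    return x + y

executor_data = {
    "executor": cloudpickle.dumps(add),
    "type": "calcfunction",  # mirrors tdata["task_type"] above
    "is_pickle": True,       # flags that reloading requires unpickling
}
# Reloading the WorkGraph later deserializes the original callable:
assert cloudpickle.loads(executor_data["executor"])(1, 2) == 3
```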
-def build_python_task(func: Callable) -> Task:
+def build_pythonjob_task(func: Callable) -> Task:
    """Build PythonJob task from function."""
    from aiida_workgraph.calculations.python import PythonJob
    from copy import deepcopy
@@ -254,10 +268,10 @@ def build_python_task(func: Callable) -> Task:
    inputs = tdata["inputs"]
    inputs.extend(
        [
-            ["String", "computer"],
-            ["String", "code_label"],
-            ["String", "code_path"],
-            ["String", "prepend_text"],
+            {"identifier": "String", "name": "computer"},
+            {"identifier": "String", "name": "code_label"},
+            {"identifier": "String", "name": "code_path"},
+            {"identifier": "String", "name": "prepend_text"},
        ]
    )
    outputs = tdata["outputs"]
@@ -269,8 +283,8 @@ def build_python_task(func: Callable) -> Task:
        outputs.append(output)
    # change "copy_files" link_limit to 1e6
    for input in inputs:
-        if input[1] == "copy_files":
-            input[2].update({"link_limit": 1e6})
+        if input["name"] == "copy_files":
+            input["link_limit"] = 1e6
    # append the kwargs of the PythonJob task to the function task
    kwargs = tdata["kwargs"]
    kwargs.extend(["computer", "code_label", "code_path", "prepend_text"])
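Note: a hedged usage sketch of the renamed builder, assuming it is importable from this module; `my_func` is a placeholder, and the membership check mirrors the `kwargs.extend(...)` line above:

```python
def my_func(x, y):  # placeholder task function (assumed)
    return x + y

task, tdata = build_pythonjob_task(my_func)
# The scheduler-side options are now ordinary task inputs:
assert {"computer", "code_label", "code_path", "prepend_text"} <= set(tdata["kwargs"])
```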
@@ -284,7 +298,7 @@ def build_python_task(func: Callable) -> Task:
    return task, tdata


-def build_shell_task(
+def build_shelljob_task(
    nodes: dict = None, outputs: list = None, parser_outputs: list = None
) -> Task:
    """Build ShellJob with custom inputs and outputs."""
@@ -300,7 +314,7 @@ def build_shell_task(
    nodes = {} if nodes is None else nodes
    keys = list(nodes.keys())
    for key in keys:
-        inputs.append(["General", f"nodes.{key}"])
+        inputs.append({"identifier": "General", "name": f"nodes.{key}"})
        # input is an output of another task, we make a link
        if isinstance(nodes[key], NodeSocket):
            links[f"nodes.{key}"] = nodes[key]
@@ -309,12 +323,20 @@ def build_shell_task(
    for input in inputs:
        if input not in tdata["inputs"]:
            tdata["inputs"].append(input)
-            tdata["kwargs"].append(input[1])
+            tdata["kwargs"].append(input["name"])
    # Extend the outputs
-    tdata["outputs"].extend([["General", "stdout"], ["General", "stderr"]])
+    tdata["outputs"].extend(
+        [
+            {"identifier": "General", "name": "stdout"},
+            {"identifier": "General", "name": "stderr"},
+        ]
+    )
    outputs = [] if outputs is None else outputs
    parser_outputs = [] if parser_outputs is None else parser_outputs
-    outputs = [["General", ShellParser.format_link_label(output)] for output in outputs]
+    outputs = [
+        {"identifier": "General", "name": ShellParser.format_link_label(output)}
+        for output in outputs
+    ]
    outputs.extend(parser_outputs)
    # add user defined outputs
    for output in outputs:
@@ -324,8 +346,8 @@ def build_shell_task(
    tdata["identifier"] = "ShellJob"
    tdata["inputs"].extend(
        [
-            ["General", "command"],
-            ["General", "resolve_command"],
+            {"identifier": "General", "name": "command"},
+            {"identifier": "General", "name": "resolve_command"},
        ]
    )
    tdata["kwargs"].extend(["command", "resolve_command"])
@@ -346,25 +368,29 @@ def build_task_from_workgraph(wg: any) -> Task:
    # add all the inputs/outputs from the tasks in the workgraph
    for task in wg.tasks:
        # inputs
-        inputs.append(["General", f"{task.name}"])
+        inputs.append({"identifier": "General", "name": f"{task.name}"})
        for socket in task.inputs:
            if socket.name == "_wait":
                continue
-            inputs.append(["General", f"{task.name}.{socket.name}"])
+            inputs.append(
+                {"identifier": "General", "name": f"{task.name}.{socket.name}"}
+            )
        # outputs
-        outputs.append(["General", f"{task.name}"])
+        outputs.append({"identifier": "General", "name": f"{task.name}"})
        for socket in task.outputs:
            if socket.name in ["_wait", "_outputs"]:
                continue
-            outputs.append(["General", f"{task.name}.{socket.name}"])
+            outputs.append(
+                {"identifier": "General", "name": f"{task.name}.{socket.name}"}
+            )
            group_outputs.append(
                [f"{task.name}.{socket.name}", f"{task.name}.{socket.name}"]
            )
-    kwargs = [input[1] for input in inputs]
+    kwargs = [input["name"] for input in inputs]
    # add built-in sockets
-    outputs.append(["General", "_outputs"])
-    outputs.append(["General", "_wait"])
-    inputs.append(["General", "_wait", {"link_limit": 1e6}])
+    outputs.append({"identifier": "General", "name": "_outputs"})
+    outputs.append({"identifier": "General", "name": "_wait"})
+    inputs.append({"identifier": "General", "name": "_wait", "link_limit": 1e6})
    tdata["node_class"] = Task
    tdata["kwargs"] = kwargs
    tdata["inputs"] = inputs
@@ -385,77 +411,6 @@ def build_task_from_workgraph(wg: any) -> Task:
    return task


-def get_required_imports(func):
-    """Retrieve type hints and the corresponding module"""
-    from typing import get_type_hints, _SpecialForm
-
-    type_hints = get_type_hints(func)
-    imports = {}
-
-    def add_imports(type_hint):
-        if isinstance(
-            type_hint, _SpecialForm
-        ):  # Handle special forms like Any, Union, Optional
-            module_name = "typing"
-            type_name = type_hint._name or str(type_hint)
-        elif hasattr(
-            type_hint, "__origin__"
-        ):  # This checks for higher-order types like List, Dict
-            module_name = type_hint.__module__
-            type_name = type_hint._name
-            for arg in type_hint.__args__:
-                if arg is type(None):  # noqa: E721
-                    continue
-                add_imports(arg)  # Recursively add imports for each argument
-        elif hasattr(type_hint, "__module__"):
-            module_name = type_hint.__module__
-            type_name = type_hint.__name__
-        else:
-            return  # If no module or origin, we can't import it, e.g., for literals
-
-        if module_name not in imports:
-            imports[module_name] = set()
-        imports[module_name].add(type_name)
-
-    for _, type_hint in type_hints.items():
-        add_imports(type_hint)
-
-    return imports
-
-
-def serialize_function(func: Callable) -> Dict[str, Any]:
-    """Serialize a function for storage or transmission."""
-    import cloudpickle as pickle
-    import inspect
-    import textwrap
-
-    source_code = inspect.getsource(func)
-    source_code_lines = source_code.split("\n")
-    # we need save the source code explicitly, because in the case of jupyter notebook,
-    # the source code is not saved in the pickle file
-    function_source_code = "\n".join(source_code_lines[1:])
-    function_source_code = textwrap.dedent(function_source_code)
-    # we also need to include the necessary imports for the types used in the type hints.
-    try:
-        required_imports = get_required_imports(func)
-    except Exception as e:
-        required_imports = {}
-        print(f"Failed to get required imports for function {func.__name__}: {e}")
-    # Generate import statements
-    import_statements = "\n".join(
-        f"from {module} import {', '.join(types)}"
-        for module, types in required_imports.items()
-    )
-    return {
-        "executor": pickle.dumps(func),
-        "type": "function",
-        "is_pickle": True,
-        "function_name": func.__name__,
-        "function_source_code": function_source_code,
-        "import_statements": import_statements,
-    }
-
-
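Note: both helpers removed here move to `aiida_workgraph.utils`, matching the import change at the top of this diff. For reference, a sketch of the payload `serialize_function` produces, keyed exactly as in the removed body; `greet` is a placeholder:

```python
import cloudpickle
from aiida_workgraph.utils import serialize_function  # new home per the import change above

def greet(name: str) -> str:  # placeholder function (assumed)
    return f"hello {name}"

data = serialize_function(greet)
assert data["type"] == "function" and data["is_pickle"] is True
assert data["function_name"] == "greet"
assert cloudpickle.loads(data["executor"])("world") == "hello world"
```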
def generate_tdata(
    func: Callable,
    identifier: str,
@@ -475,9 +430,9 @@ def generate_tdata(
    )
    task_outputs = outputs
    # add built-in sockets
-    _inputs.append(["General", "_wait", {"link_limit": 1e6}])
-    task_outputs.append(["General", "_wait"])
-    task_outputs.append(["General", "_outputs"])
+    _inputs.append({"identifier": "General", "name": "_wait", "link_limit": 1e6})
+    task_outputs.append({"identifier": "General", "name": "_wait"})
+    task_outputs.append({"identifier": "General", "name": "_outputs"})
    tdata = {
        "node_class": Task,
        "identifier": identifier,
@@ -536,7 +491,7 @@ def decorator(func):
            func,
            identifier,
            inputs or [],
-            outputs or [["General", "result"]],
+            outputs or [{"identifier": "General", "name": "result"}],
            properties or [],
            catalog,
            task_type,
@@ -577,7 +532,9 @@ def decorator(func):
        # use cloudpickle to serialize function
        func.identifier = identifier

-        task_outputs = [["General", output[1]] for output in outputs]
+        task_outputs = [
+            {"identifier": "General", "name": output[1]} for output in outputs
+        ]
        # print(task_inputs, task_outputs)
        #
        task_type = "graph_builder"
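Note: with the dict-based default above, a decorated function that declares no outputs gets a single `result` socket. A hedged end-to-end sketch, assuming the public `task` decorator exported by aiida-workgraph:

```python
from aiida_workgraph import task

@task()  # outputs defaults to [{"identifier": "General", "name": "result"}]
def add(x, y):
    return x + y
```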