@@ -60,6 +60,34 @@ def test_fusion(spec, opt_fn):
60
60
)
61
61
62
62
63
+ @pytest .mark .parametrize (
64
+ "opt_fn" , [None , simple_optimize_dag , multiple_inputs_optimize_dag ]
65
+ )
66
+ def test_fusion_compute_multiple (spec , opt_fn ):
67
+ a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
68
+ b = xp .negative (a )
69
+ c = xp .astype (b , np .float32 )
70
+ d = xp .negative (c )
71
+
72
+ # if we compute c and d then both have to be materialized
73
+ num_created_arrays = 2 # c, d
74
+ task_counter = TaskCounter ()
75
+ cubed .visualize (c , d , optimize_function = opt_fn )
76
+ c_result , d_result = cubed .compute (
77
+ c , d , optimize_function = opt_fn , callbacks = [task_counter ]
78
+ )
79
+ assert task_counter .value == num_created_arrays + 8
80
+
81
+ assert_array_equal (
82
+ c_result ,
83
+ np .array ([[- 1 , - 2 , - 3 ], [- 4 , - 5 , - 6 ], [- 7 , - 8 , - 9 ]]).astype (np .float32 ),
84
+ )
85
+ assert_array_equal (
86
+ d_result ,
87
+ np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]).astype (np .float32 ),
88
+ )
89
+
90
+
63
91
@pytest .mark .parametrize (
64
92
"opt_fn" , [None , simple_optimize_dag , multiple_inputs_optimize_dag ]
65
93
)
0 commit comments