22
22
import timeit
23
23
import matplotlib .pyplot as plt
24
24
import pennylane as qml
25
- import jax
26
-
27
- jax .config .update ("jax_platform_name" , "cpu" )
28
- jax .config .update ('jax_enable_x64' , True )
25
+ import pennylane .numpy as pnp
29
26
30
27
plt .style .use ("bmh" )
31
28
32
29
n_samples = 5
33
30
34
31
35
32
def get_time (qnode , params ):
36
- globals_dict = {' grad' : jax .grad , ' circuit' : qnode , ' params' : params }
33
+ globals_dict = {" grad" : qml .grad , " circuit" : qnode , " params" : params }
37
34
return timeit .timeit ("grad(circuit)(params)" , globals = globals_dict , number = n_samples )
38
35
39
36
40
37
def wires_scaling (n_wires , n_layers ):
41
- key = jax .random .PRNGKey ( 42 )
38
+ rng = pnp .random .default_rng ( 12345 )
42
39
43
40
t_adjoint = []
44
41
t_ps = []
@@ -58,7 +55,7 @@ def circuit(params, wires):
58
55
59
56
# set up the parameters
60
57
param_shape = qml .StronglyEntanglingLayers .shape (n_wires = i_wires , n_layers = n_layers )
61
- params = jax . random . normal (key , param_shape )
58
+ params = rng . normal (size = pnp . prod ( param_shape ), requires_grad = True ). reshape ( param_shape )
62
59
63
60
t_adjoint .append (get_time (circuit_adjoint , params ))
64
61
t_backprop .append (get_time (circuit_backprop , params ))
@@ -68,10 +65,10 @@ def circuit(params, wires):
68
65
69
66
70
67
def layers_scaling (n_wires , n_layers ):
71
- key = jax .random .PRNGKey ( 42 )
68
+ rng = pnp .random .default_rng ( 12345 )
72
69
73
70
dev = qml .device ("lightning.qubit" , wires = n_wires )
74
- dev_python = qml .device (' default.qubit' , wires = n_wires )
71
+ dev_python = qml .device (" default.qubit" , wires = n_wires )
75
72
76
73
t_adjoint = []
77
74
t_ps = []
@@ -88,7 +85,7 @@ def circuit(params):
88
85
for i_layers in n_layers :
89
86
# set up the parameters
90
87
param_shape = qml .StronglyEntanglingLayers .shape (n_wires = n_wires , n_layers = i_layers )
91
- params = jax . random . normal (key , param_shape )
88
+ params = rng . normal (size = pnp . prod ( param_shape ), requires_grad = True ). reshape ( param_shape )
92
89
93
90
t_adjoint .append (get_time (circuit_adjoint , params ))
94
91
t_backprop .append (get_time (circuit_backprop , params ))
@@ -110,9 +107,9 @@ def circuit(params):
110
107
# Generating the graphic
111
108
fig , (ax1 , ax2 ) = plt .subplots (1 , 2 , figsize = (10 , 4 ))
112
109
113
- ax1 .plot (wires_list , adjoint_wires , '.-' , label = "adjoint" )
114
- ax1 .plot (wires_list , ps_wires , '.-' , label = "parameter-shift" )
115
- ax1 .plot (wires_list , backprop_wires , '.-' , label = "backprop" )
110
+ ax1 .plot (wires_list , adjoint_wires , ".-" , label = "adjoint" )
111
+ ax1 .plot (wires_list , ps_wires , ".-" , label = "parameter-shift" )
112
+ ax1 .plot (wires_list , backprop_wires , ".-" , label = "backprop" )
116
113
117
114
ax1 .legend ()
118
115
@@ -122,16 +119,17 @@ def circuit(params):
122
119
ax1 .set_yscale ("log" )
123
120
ax1 .set_title ("Scaling with wires" )
124
121
125
- ax2 .plot (layers_list , adjoint_layers , '.-' , label = "adjoint" )
126
- ax2 .plot (layers_list , ps_layers , '.-' , label = "parameter-shift" )
127
- ax2 .plot (layers_list , backprop_layers , '.-' , label = "backprop" )
122
+ ax2 .plot (layers_list , adjoint_layers , ".-" , label = "adjoint" )
123
+ ax2 .plot (layers_list , ps_layers , ".-" , label = "parameter-shift" )
124
+ ax2 .plot (layers_list , backprop_layers , ".-" , label = "backprop" )
128
125
129
126
ax2 .legend ()
130
127
131
128
ax2 .set_xlabel ("Number of layers" )
132
129
ax2 .set_xticks (layers_list )
133
- ax2 .set_ylabel ("Time" )
134
- ax2 .set_title ("Scaling with Layers" )
130
+ ax2 .set_ylabel ("Log Time" )
131
+ ax2 .set_yscale ("log" )
132
+ ax2 .set_title ("Scaling with layers" )
135
133
136
134
plt .savefig ("scaling.png" )
137
135
0 commit comments