7
7
from impedance .models .circuits .elements import get_element_from_name
8
8
from impedance .models .circuits .fitting import check_and_eval , rmse
9
9
10
+ import networkx as nx
11
+ import re
12
+
10
13
# Note: a lot of codes are directly adopted from impedance.py.,
11
14
# which is designed to be enable a easy integration in the future,
12
15
# but now we are keep them separated to ensure the stable performance
@@ -110,7 +113,6 @@ def seq_fit_param(input_dic, target_arr, output_arr):
110
113
111
114
112
115
def set_default_bounds (circuit , constants = {}):
113
-
114
116
"""
115
117
Set default bounds for optimization.
116
118
@@ -170,10 +172,12 @@ def set_default_bounds(circuit, constants={}):
170
172
elif raw_element in ['RCn' ] and i == 2 :
171
173
upper_bounds .append (0.5 )
172
174
lower_bounds .append (- 0.5 )
173
- elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' ] and (i == 5 ):
175
+ elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' , 'RCSQn' ,
176
+ 'RCDQn' ] and (i == 5 ):
174
177
upper_bounds .append (np .inf )
175
178
lower_bounds .append (- np .inf )
176
- elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' ] and i == 6 :
179
+ elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' , 'RCSQn' ,
180
+ 'RCDQn' ] and i == 6 :
177
181
upper_bounds .append (0.5 )
178
182
lower_bounds .append (- 0.5 )
179
183
elif raw_element in ['RCDn' , 'RCSn' ] and (i == 4 ):
@@ -191,6 +195,10 @@ def set_default_bounds(circuit, constants={}):
191
195
elif raw_element in ['TLMSn' , 'TLMDn' ] and (i == 8 ):
192
196
upper_bounds .append (np .inf )
193
197
lower_bounds .append (- np .inf )
198
+ elif raw_element in ['RCSQ' , 'RCSQn' ,
199
+ 'RCDQ' , 'RCDQn' ] and (i == 2 ):
200
+ upper_bounds .append (1 )
201
+ lower_bounds .append (0 )
194
202
else :
195
203
upper_bounds .append (np .inf )
196
204
lower_bounds .append (0 )
@@ -204,6 +212,7 @@ def set_default_bounds(circuit, constants={}):
204
212
205
213
def circuit_fit (frequencies , impedances , circuit , initial_guess , constants = {},
206
214
bounds = None , weight_by_modulus = False , global_opt = False ,
215
+ graph = False ,
207
216
** kwargs ):
208
217
""" Main function for fitting an equivalent circuit to data.
209
218
@@ -248,6 +257,10 @@ def circuit_fit(frequencies, impedances, circuit, initial_guess, constants={},
248
257
If global optimization should be used (uses the basinhopping
249
258
algorithm). Defaults to False
250
259
260
+ graph : bool, optional
261
+ Whether to use execution graph to process the circuit.
262
+ Defaults to False, which uses eval based code
263
+
251
264
kwargs :
252
265
Keyword arguments passed to scipy.optimize.curve_fit or
253
266
scipy.optimize.basinhopping
@@ -274,6 +287,8 @@ def circuit_fit(frequencies, impedances, circuit, initial_guess, constants={},
274
287
if bounds is None :
275
288
bounds = set_default_bounds (circuit , constants = constants )
276
289
290
+ cg = CircuitGraph (circuit , constants )
291
+
277
292
if not global_opt :
278
293
if 'maxfev' not in kwargs :
279
294
kwargs ['maxfev' ] = int (1e5 )
@@ -284,10 +299,17 @@ def circuit_fit(frequencies, impedances, circuit, initial_guess, constants={},
284
299
if weight_by_modulus :
285
300
abs_Z = np .abs (Z )
286
301
kwargs ['sigma' ] = np .hstack ([abs_Z , abs_Z ])
287
-
288
- popt , pcov = curve_fit (wrapCircuit (circuit , constants ), f ,
289
- np .hstack ([Z .real , Z .imag ]),
290
- p0 = initial_guess , bounds = bounds , ** kwargs )
302
+ if graph :
303
+ popt , pcov = curve_fit (cg .compute_long , f ,
304
+ np .hstack ([Z .real , Z .imag ]),
305
+ p0 = initial_guess ,
306
+ bounds = bounds ,
307
+ ** kwargs ,
308
+ )
309
+ else :
310
+ popt , pcov = curve_fit (wrapCircuit (circuit , constants ), f ,
311
+ np .hstack ([Z .real , Z .imag ]),
312
+ p0 = initial_guess , bounds = bounds , ** kwargs )
291
313
292
314
# Calculate one standard deviation error estimates for fit parameters,
293
315
# defined as the square root of the diagonal of the covariance matrix.
@@ -315,6 +337,9 @@ def opt_function(x):
315
337
return rmse (wrapCircuit (circuit , constants )(f , * x ),
316
338
np .hstack ([Z .real , Z .imag ]))
317
339
340
+ def opt_function_graph (x ):
341
+ return rmse (cg .compute_long (f , * x ), np .hstack ([Z .real , Z .imag ]))
342
+
318
343
class BasinhoppingBounds (object ):
319
344
""" Adapted from the basinhopping documetation
320
345
https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html
@@ -332,8 +357,12 @@ def __call__(self, **kwargs):
332
357
333
358
basinhopping_bounds = BasinhoppingBounds (xmin = bounds [0 ],
334
359
xmax = bounds [1 ])
335
- results = basinhopping (opt_function , x0 = initial_guess ,
336
- accept_test = basinhopping_bounds , ** kwargs )
360
+ if graph :
361
+ results = basinhopping (opt_function_graph , x0 = initial_guess ,
362
+ accept_test = basinhopping_bounds , ** kwargs )
363
+ else :
364
+ results = basinhopping (opt_function , x0 = initial_guess ,
365
+ accept_test = basinhopping_bounds , ** kwargs )
337
366
popt = results .x
338
367
339
368
# Calculate perror
@@ -578,3 +607,186 @@ def extract_circuit_elements(circuit):
578
607
current_element .append (char )
579
608
extracted_elements .append ('' .join (current_element ))
580
609
return extracted_elements
610
+
611
+ # Circuit Graph for computation optimization
612
+ # Special Thanks to Jake Anderson for the original code
613
+
614
+
615
+ class CircuitGraph :
616
+ '''
617
+ A class to represent a circuit as a directed graph.
618
+ '''
619
+ # regular expression to find parallel and difference blocks
620
+ _parallel_difference_block_expression = re .compile (r'(?:p|d)\([^()]*\)' )
621
+
622
+ # regular expression to remove whitespace
623
+ _whitespce = re .compile (r"\s+" )
624
+
625
+ def __init__ (self , circuit , constants = None ):
626
+ '''
627
+ Initialize the CircuitGraph object.'''
628
+ # remove all whitespace from the circuit string
629
+ self .circuit = self ._whitespce .sub ("" , circuit )
630
+ # parse the circuit string and initialize the graph
631
+ self .parse_circuit ()
632
+ # compute the execution order of the graph
633
+ self .execution_order = list (nx .topological_sort (self .graph ))
634
+ # initialize the constants dictionary
635
+ self .constants = constants if constants is not None else dict ()
636
+
637
+ def parse_circuit (self ):
638
+ '''
639
+ Parse the circuit string and initialize the graph.
640
+ '''
641
+ # initialize the node counters for each type of block
642
+ self .snum = 1
643
+ self .pnum = 1
644
+ self .dnum = 1
645
+ # initialize the circuit string to be parsed
646
+ parsing_circuit = self .circuit
647
+
648
+ # determine all of the base elements, their functions
649
+ # and add them to the graph
650
+ element_name = extract_circuit_elements (parsing_circuit )
651
+ element_func = [
652
+ circuit_elements [get_element_from_name (e )] for e in element_name
653
+ ]
654
+ # graph initialization
655
+ self .graph = nx .DiGraph ()
656
+ # add nodes to the graph
657
+ for e , f in zip (element_name , element_func ):
658
+ self .graph .add_node (e , Z = f )
659
+
660
+ # find unnested parallel and difference blocks
661
+ pd_blocks = self ._parallel_difference_block_expression .findall (
662
+ parsing_circuit )
663
+
664
+ while len (pd_blocks ) > 0 :
665
+ # add parallel or difference blocks to the graph
666
+ # unnesting each time around the loop
667
+ for pd in pd_blocks :
668
+ operator = pd [0 ]
669
+ pd_elem = pd [2 :- 1 ].split ("," )
670
+
671
+ if operator == "p" :
672
+ nnum = self .pnum
673
+ self .pnum += 1
674
+ elif operator == "d" :
675
+ nnum = self .dnum
676
+ self .dnum += 1
677
+
678
+ node = f"{ operator } { nnum } "
679
+ self .graph .add_node (node , Z = circuit_elements [operator ])
680
+ for elem in pd_elem :
681
+ elem = self .add_series_elements (elem )
682
+ self .graph .add_edge (elem , node )
683
+ parsing_circuit = parsing_circuit .replace (pd , node )
684
+
685
+ pd_blocks = self ._parallel_difference_block_expression .findall (
686
+ parsing_circuit )
687
+
688
+ # pick up any top line series connections
689
+ self .add_series_elements (parsing_circuit )
690
+
691
+ # assign layers to the nodes
692
+ for layer , nodes in enumerate (nx .topological_generations (self .graph )):
693
+ for n in nodes :
694
+ self .graph .nodes [n ]["layer" ] = layer
695
+ # function to add series elements to the graph
696
+
697
+ def add_series_elements (self , elem ):
698
+ '''
699
+ Add series elements to the graph.
700
+ '''
701
+ selem = elem .split ("-" )
702
+ if len (selem ) > 1 :
703
+ node = f"s{ self .snum } "
704
+ self .snum += 1
705
+ self .graph .add_node (node , Z = circuit_elements ["s" ])
706
+ for n in selem :
707
+ self .graph .add_edge (n , node )
708
+ return node
709
+
710
+ # if there isn't a series connection in elem just return it unchanged
711
+ return selem [0 ]
712
+
713
+ # function to visualize the graph
714
+ def visualize_graph (self , ** kwargs ):
715
+ '''
716
+ Visualize the graph.'''
717
+ pos = nx .multipartite_layout (self .graph , subset_key = "layer" )
718
+ nx .draw_networkx (self .graph , pos = pos , ** kwargs )
719
+
720
+ # function to compute the impedance of the circuit
721
+ def compute (self , f , * parameters ):
722
+ '''
723
+ Compute the impedance of the circuit at the given frequencies.
724
+ '''
725
+ node_results = {}
726
+ pindex = 0
727
+ for node in self .execution_order :
728
+ Zfunc = self .graph .nodes [node ]["Z" ]
729
+ plist = [
730
+ node_results [pred ] for pred in self .graph .predecessors (node )
731
+ ]
732
+
733
+ if len (plist ) < 1 :
734
+ n_params = Zfunc .num_params
735
+ for j in range (n_params ):
736
+ p_name = format_parameter_name (node , j , n_params )
737
+ if p_name in self .constants :
738
+ plist .append (self .constants [p_name ])
739
+ else :
740
+ plist .append (parameters [pindex ])
741
+ pindex += 1
742
+ node_results [node ] = Zfunc (plist , f )
743
+ else :
744
+ node_results [node ] = Zfunc (plist )
745
+
746
+ return np .squeeze (node_results [node ])
747
+
748
+ # To enable comparision
749
+
750
+ def __eq__ (self , other ):
751
+ '''
752
+ Compare two CircuitGraph objects for equality.
753
+ '''
754
+ if not isinstance (other , CircuitGraph ):
755
+ return False
756
+ # Compare the internal graph attributes
757
+ return (self .graph .nodes == other .graph .nodes
758
+ and self .graph .edges == other .graph .edges )
759
+
760
+ # To enable direct calling
761
+
762
+ def __call__ (self , f , * parameters ):
763
+ '''
764
+ Compute the impedance of the circuit at the given frequencies.
765
+ '''
766
+ Z = self .compute (f , * parameters )
767
+ return Z
768
+
769
+ def compute_long (self , f , * parameters ):
770
+ '''
771
+ Compute the impedance of the circuit at the given frequencies.
772
+ And convert it to a long array for curve_fit.
773
+ '''
774
+ Z = self .compute (f , * parameters )
775
+ return np .hstack ([Z .real , Z .imag ])
776
+
777
+ def calculate_circuit_length (self ):
778
+ '''
779
+ calculate the number of parameters in the circuit
780
+ '''
781
+ n_params = [
782
+ getattr (Zfunc , "num_params" , 0 )
783
+ for node , Zfunc in self .graph .nodes (data = "Z" )
784
+ ]
785
+ return np .sum (n_params )
786
+
787
+
788
+ def format_parameter_name (name , j , n_params ):
789
+ '''
790
+ Format the parameter name for the given element.
791
+ '''
792
+ return f"{ name } _{ j } " if n_params > 1 else f"{ name } "
0 commit comments