77from impedance .models .circuits .elements import get_element_from_name
88from impedance .models .circuits .fitting import check_and_eval , rmse
99
10+ import networkx as nx
11+ import re
12+
1013# Note: a lot of codes are directly adopted from impedance.py.,
1114# which is designed to be enable a easy integration in the future,
1215# but now we are keep them separated to ensure the stable performance
@@ -110,7 +113,6 @@ def seq_fit_param(input_dic, target_arr, output_arr):
110113
111114
112115def set_default_bounds (circuit , constants = {}):
113-
114116 """
115117 Set default bounds for optimization.
116118
@@ -170,10 +172,12 @@ def set_default_bounds(circuit, constants={}):
170172 elif raw_element in ['RCn' ] and i == 2 :
171173 upper_bounds .append (0.5 )
172174 lower_bounds .append (- 0.5 )
173- elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' ] and (i == 5 ):
175+ elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' , 'RCSQn' ,
176+ 'RCDQn' ] and (i == 5 ):
174177 upper_bounds .append (np .inf )
175178 lower_bounds .append (- np .inf )
176- elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' ] and i == 6 :
179+ elif raw_element in ['TDSn' , 'TDPn' , 'TDCn' , 'RCSQn' ,
180+ 'RCDQn' ] and i == 6 :
177181 upper_bounds .append (0.5 )
178182 lower_bounds .append (- 0.5 )
179183 elif raw_element in ['RCDn' , 'RCSn' ] and (i == 4 ):
@@ -191,6 +195,10 @@ def set_default_bounds(circuit, constants={}):
191195 elif raw_element in ['TLMSn' , 'TLMDn' ] and (i == 8 ):
192196 upper_bounds .append (np .inf )
193197 lower_bounds .append (- np .inf )
198+ elif raw_element in ['RCSQ' , 'RCSQn' ,
199+ 'RCDQ' , 'RCDQn' ] and (i == 2 ):
200+ upper_bounds .append (1 )
201+ lower_bounds .append (0 )
194202 else :
195203 upper_bounds .append (np .inf )
196204 lower_bounds .append (0 )
@@ -204,6 +212,7 @@ def set_default_bounds(circuit, constants={}):
204212
205213def circuit_fit (frequencies , impedances , circuit , initial_guess , constants = {},
206214 bounds = None , weight_by_modulus = False , global_opt = False ,
215+ graph = False ,
207216 ** kwargs ):
208217 """ Main function for fitting an equivalent circuit to data.
209218
@@ -248,6 +257,10 @@ def circuit_fit(frequencies, impedances, circuit, initial_guess, constants={},
248257 If global optimization should be used (uses the basinhopping
249258 algorithm). Defaults to False
250259
260+ graph : bool, optional
261+ Whether to use execution graph to process the circuit.
262+ Defaults to False, which uses eval based code
263+
251264 kwargs :
252265 Keyword arguments passed to scipy.optimize.curve_fit or
253266 scipy.optimize.basinhopping
@@ -274,6 +287,8 @@ def circuit_fit(frequencies, impedances, circuit, initial_guess, constants={},
274287 if bounds is None :
275288 bounds = set_default_bounds (circuit , constants = constants )
276289
290+ cg = CircuitGraph (circuit , constants )
291+
277292 if not global_opt :
278293 if 'maxfev' not in kwargs :
279294 kwargs ['maxfev' ] = int (1e5 )
@@ -284,10 +299,17 @@ def circuit_fit(frequencies, impedances, circuit, initial_guess, constants={},
284299 if weight_by_modulus :
285300 abs_Z = np .abs (Z )
286301 kwargs ['sigma' ] = np .hstack ([abs_Z , abs_Z ])
287-
288- popt , pcov = curve_fit (wrapCircuit (circuit , constants ), f ,
289- np .hstack ([Z .real , Z .imag ]),
290- p0 = initial_guess , bounds = bounds , ** kwargs )
302+ if graph :
303+ popt , pcov = curve_fit (cg .compute_long , f ,
304+ np .hstack ([Z .real , Z .imag ]),
305+ p0 = initial_guess ,
306+ bounds = bounds ,
307+ ** kwargs ,
308+ )
309+ else :
310+ popt , pcov = curve_fit (wrapCircuit (circuit , constants ), f ,
311+ np .hstack ([Z .real , Z .imag ]),
312+ p0 = initial_guess , bounds = bounds , ** kwargs )
291313
292314 # Calculate one standard deviation error estimates for fit parameters,
293315 # defined as the square root of the diagonal of the covariance matrix.
@@ -315,6 +337,9 @@ def opt_function(x):
315337 return rmse (wrapCircuit (circuit , constants )(f , * x ),
316338 np .hstack ([Z .real , Z .imag ]))
317339
340+ def opt_function_graph (x ):
341+ return rmse (cg .compute_long (f , * x ), np .hstack ([Z .real , Z .imag ]))
342+
318343 class BasinhoppingBounds (object ):
319344 """ Adapted from the basinhopping documetation
320345 https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html
@@ -332,8 +357,12 @@ def __call__(self, **kwargs):
332357
333358 basinhopping_bounds = BasinhoppingBounds (xmin = bounds [0 ],
334359 xmax = bounds [1 ])
335- results = basinhopping (opt_function , x0 = initial_guess ,
336- accept_test = basinhopping_bounds , ** kwargs )
360+ if graph :
361+ results = basinhopping (opt_function_graph , x0 = initial_guess ,
362+ accept_test = basinhopping_bounds , ** kwargs )
363+ else :
364+ results = basinhopping (opt_function , x0 = initial_guess ,
365+ accept_test = basinhopping_bounds , ** kwargs )
337366 popt = results .x
338367
339368 # Calculate perror
@@ -578,3 +607,186 @@ def extract_circuit_elements(circuit):
578607 current_element .append (char )
579608 extracted_elements .append ('' .join (current_element ))
580609 return extracted_elements
610+
611+ # Circuit Graph for computation optimization
612+ # Special Thanks to Jake Anderson for the original code
613+
614+
615+ class CircuitGraph :
616+ '''
617+ A class to represent a circuit as a directed graph.
618+ '''
619+ # regular expression to find parallel and difference blocks
620+ _parallel_difference_block_expression = re .compile (r'(?:p|d)\([^()]*\)' )
621+
622+ # regular expression to remove whitespace
623+ _whitespce = re .compile (r"\s+" )
624+
625+ def __init__ (self , circuit , constants = None ):
626+ '''
627+ Initialize the CircuitGraph object.'''
628+ # remove all whitespace from the circuit string
629+ self .circuit = self ._whitespce .sub ("" , circuit )
630+ # parse the circuit string and initialize the graph
631+ self .parse_circuit ()
632+ # compute the execution order of the graph
633+ self .execution_order = list (nx .topological_sort (self .graph ))
634+ # initialize the constants dictionary
635+ self .constants = constants if constants is not None else dict ()
636+
637+ def parse_circuit (self ):
638+ '''
639+ Parse the circuit string and initialize the graph.
640+ '''
641+ # initialize the node counters for each type of block
642+ self .snum = 1
643+ self .pnum = 1
644+ self .dnum = 1
645+ # initialize the circuit string to be parsed
646+ parsing_circuit = self .circuit
647+
648+ # determine all of the base elements, their functions
649+ # and add them to the graph
650+ element_name = extract_circuit_elements (parsing_circuit )
651+ element_func = [
652+ circuit_elements [get_element_from_name (e )] for e in element_name
653+ ]
654+ # graph initialization
655+ self .graph = nx .DiGraph ()
656+ # add nodes to the graph
657+ for e , f in zip (element_name , element_func ):
658+ self .graph .add_node (e , Z = f )
659+
660+ # find unnested parallel and difference blocks
661+ pd_blocks = self ._parallel_difference_block_expression .findall (
662+ parsing_circuit )
663+
664+ while len (pd_blocks ) > 0 :
665+ # add parallel or difference blocks to the graph
666+ # unnesting each time around the loop
667+ for pd in pd_blocks :
668+ operator = pd [0 ]
669+ pd_elem = pd [2 :- 1 ].split ("," )
670+
671+ if operator == "p" :
672+ nnum = self .pnum
673+ self .pnum += 1
674+ elif operator == "d" :
675+ nnum = self .dnum
676+ self .dnum += 1
677+
678+ node = f"{ operator } { nnum } "
679+ self .graph .add_node (node , Z = circuit_elements [operator ])
680+ for elem in pd_elem :
681+ elem = self .add_series_elements (elem )
682+ self .graph .add_edge (elem , node )
683+ parsing_circuit = parsing_circuit .replace (pd , node )
684+
685+ pd_blocks = self ._parallel_difference_block_expression .findall (
686+ parsing_circuit )
687+
688+ # pick up any top line series connections
689+ self .add_series_elements (parsing_circuit )
690+
691+ # assign layers to the nodes
692+ for layer , nodes in enumerate (nx .topological_generations (self .graph )):
693+ for n in nodes :
694+ self .graph .nodes [n ]["layer" ] = layer
695+ # function to add series elements to the graph
696+
697+ def add_series_elements (self , elem ):
698+ '''
699+ Add series elements to the graph.
700+ '''
701+ selem = elem .split ("-" )
702+ if len (selem ) > 1 :
703+ node = f"s{ self .snum } "
704+ self .snum += 1
705+ self .graph .add_node (node , Z = circuit_elements ["s" ])
706+ for n in selem :
707+ self .graph .add_edge (n , node )
708+ return node
709+
710+ # if there isn't a series connection in elem just return it unchanged
711+ return selem [0 ]
712+
713+ # function to visualize the graph
714+ def visualize_graph (self , ** kwargs ):
715+ '''
716+ Visualize the graph.'''
717+ pos = nx .multipartite_layout (self .graph , subset_key = "layer" )
718+ nx .draw_networkx (self .graph , pos = pos , ** kwargs )
719+
720+ # function to compute the impedance of the circuit
721+ def compute (self , f , * parameters ):
722+ '''
723+ Compute the impedance of the circuit at the given frequencies.
724+ '''
725+ node_results = {}
726+ pindex = 0
727+ for node in self .execution_order :
728+ Zfunc = self .graph .nodes [node ]["Z" ]
729+ plist = [
730+ node_results [pred ] for pred in self .graph .predecessors (node )
731+ ]
732+
733+ if len (plist ) < 1 :
734+ n_params = Zfunc .num_params
735+ for j in range (n_params ):
736+ p_name = format_parameter_name (node , j , n_params )
737+ if p_name in self .constants :
738+ plist .append (self .constants [p_name ])
739+ else :
740+ plist .append (parameters [pindex ])
741+ pindex += 1
742+ node_results [node ] = Zfunc (plist , f )
743+ else :
744+ node_results [node ] = Zfunc (plist )
745+
746+ return np .squeeze (node_results [node ])
747+
748+ # To enable comparision
749+
750+ def __eq__ (self , other ):
751+ '''
752+ Compare two CircuitGraph objects for equality.
753+ '''
754+ if not isinstance (other , CircuitGraph ):
755+ return False
756+ # Compare the internal graph attributes
757+ return (self .graph .nodes == other .graph .nodes
758+ and self .graph .edges == other .graph .edges )
759+
760+ # To enable direct calling
761+
762+ def __call__ (self , f , * parameters ):
763+ '''
764+ Compute the impedance of the circuit at the given frequencies.
765+ '''
766+ Z = self .compute (f , * parameters )
767+ return Z
768+
769+ def compute_long (self , f , * parameters ):
770+ '''
771+ Compute the impedance of the circuit at the given frequencies.
772+ And convert it to a long array for curve_fit.
773+ '''
774+ Z = self .compute (f , * parameters )
775+ return np .hstack ([Z .real , Z .imag ])
776+
777+ def calculate_circuit_length (self ):
778+ '''
779+ calculate the number of parameters in the circuit
780+ '''
781+ n_params = [
782+ getattr (Zfunc , "num_params" , 0 )
783+ for node , Zfunc in self .graph .nodes (data = "Z" )
784+ ]
785+ return np .sum (n_params )
786+
787+
788+ def format_parameter_name (name , j , n_params ):
789+ '''
790+ Format the parameter name for the given element.
791+ '''
792+ return f"{ name } _{ j } " if n_params > 1 else f"{ name } "
0 commit comments