@@ -144,7 +144,12 @@ def standardization(data):
144144 """standardization numpy array"""
145145 mu = np .mean (data , axis = 0 )
146146 sigma = np .std (data , axis = 0 )
147- sigma = 1e-13 if sigma == 0. else sigma
147+ if isinstance (sigma , list ) or isinstance (sigma , np .ndarray ):
148+ for idx , sig in enumerate (sigma ):
149+ if sig == 0. :
150+ sigma [idx ] = 1e-13
151+ else :
152+ sigma = 1e-13 if sigma == 0. else sigma
148153 return (data - mu ) / sigma
149154
150155
@@ -241,18 +246,15 @@ def eval_quant_model():
241246 if have_invalid_num (out_float ) or have_invalid_num (out_quant ):
242247 continue
243248
244- try :
245- out_float = standardization (out_float )
246- out_quant = standardization (out_quant )
247- except :
248- continue
249- out_float_list .append (out_float )
250- out_quant_list .append (out_quant )
249+ out_float_list .append (list (out_float ))
250+ out_quant_list .append (list (out_quant ))
251251 valid_data_num += 1
252252
253253 if valid_data_num >= max_eval_data_num :
254254 break
255255
256+ out_float_list = standardization (out_float_list )
257+ out_quant_list = standardization (out_quant_list )
256258 emd_sum = cal_emd_lose (out_float_list , out_quant_list ,
257259 out_len_sum / float (valid_data_num ))
258260 _logger .info ("output diff: {}" .format (emd_sum ))
0 commit comments