2
2
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
3
3
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4
4
# Yasushi Esaki <esakiful@gmail.com>
5
+ # Jun Nishikawa <jun.b.nishikawa@gmail.com>
5
6
import numpy as np
6
7
7
8
_EPSILON = np .sqrt (np .finfo (np .float64 ).eps )
@@ -230,27 +231,8 @@ def onehot_vecs(val,val_name,exception_class):
230
231
return val
231
232
raise (exception_class (val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors." ))
232
233
233
- def dim_consistency (value_dict : dict , exception_class ):
234
- check_value_dict = {}
235
- for key in value_dict :
236
- if value_dict [key ] is not None and value_dict [key ] not in check_value_dict .values ():
237
- check_value_dict [key ] = value_dict [key ]
238
- if len (check_value_dict ) == 0 :
239
- return None
240
- elif len (check_value_dict ) > 1 :
241
- message = f"The following values must be the same: { list (value_dict .keys ())} . "
242
- message += f"The following values are different: { list (check_value_dict .keys ())} . "
243
- # print("===== Error =====")
244
- for key in check_value_dict :
245
- # print(f"{key} = {check_value_dict[key]}")
246
- message += f"\n { key } = { check_value_dict [key ]} "
247
- else :
248
- return list (check_value_dict .values ())[0 ]
249
- raise (exception_class (message ))
250
-
251
- def shape_consistency (val , val_name , correct , correct_name , exception_class ):
252
- if val .shape not in correct :
253
- message = f"{ val_name } .shape must coincide with { correct_name } :"
254
- + f"{ val_name } .shape={ val .shape } , { correct_name } ={ correct } "
255
- raise (exception_class (message ))
256
-
234
+ def shape_consistency (val : int , val_name : str , correct : int , correct_name : str , exception_class ):
235
+ if val != correct :
236
+ message = (f"{ val_name } must coincide with { correct_name } : "
237
+ + f"{ val_name } = { val } , { correct_name } = { correct } " )
238
+ raise (exception_class (message ))
0 commit comments