@@ -224,7 +224,7 @@ def __init__(
224
224
qk_norm : bool = False ,
225
225
attn_drop : float = 0. ,
226
226
proj_drop : float = 0. ,
227
- input_norm_layer : nn .Module = partial (LayerNorm2d , eps = 1e-5 ),
227
+ input_norm_layer : nn .Module = partial (LayerNorm , eps = 1e-5 ),
228
228
norm_layer : nn .Module = partial (LayerNorm , eps = 1e-5 ),
229
229
init_values : Optional [float ] = None ,
230
230
drop_path : float = 0. ,
@@ -326,7 +326,7 @@ def __init__(
326
326
qk_norm : bool = False ,
327
327
attn_drop : float = 0. ,
328
328
proj_drop : float = 0. ,
329
- input_norm_layer = LayerNorm2d ,
329
+ input_norm_layer = LayerNorm ,
330
330
norm_layer : nn .Module = LayerNorm ,
331
331
init_values : Optional [float ] = None ,
332
332
drop_path_rates : List [float ] = [0. ],
@@ -417,7 +417,7 @@ def __init__(
417
417
qk_norm : bool = False ,
418
418
attn_drop : float = 0. ,
419
419
proj_drop : float = 0. ,
420
- input_norm_layer = partial (LayerNorm2d , eps = 1e-5 ),
420
+ input_norm_layer = partial (LayerNorm , eps = 1e-5 ),
421
421
norm_layer : nn .Module = partial (LayerNorm , eps = 1e-5 ),
422
422
init_values : Optional [float ] = None ,
423
423
drop_path_rate : float = 0. ,
0 commit comments