Skip to content

Commit 025d8a4

Browse files
committed
Update cvt.py
1 parent df05c0d commit 025d8a4

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

timm/models/cvt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def __init__(
314314
embed_kernel_size: int = 7,
315315
embed_stride: int = 4,
316316
embed_padding: int = 2,
317+
embed_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5),
317318
kernel_size: int = 3,
318319
stride_q: int = 1,
319320
stride_kv: int = 2,
@@ -405,6 +406,7 @@ def __init__(
405406
embed_kernel_size: Tuple[int, ...] = (7, 3, 3),
406407
embed_stride: Tuple[int, ...] = (4, 2, 2),
407408
embed_padding: Tuple[int, ...] = (2, 1, 1),
409+
embed_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5),
408410
kernel_size: int = 3,
409411
stride_q: int = 1,
410412
stride_kv: int = 2,
@@ -452,6 +454,7 @@ def __init__(
452454
embed_kernel_size = embed_kernel_size[stage_idx],
453455
embed_stride = embed_stride[stage_idx],
454456
embed_padding = embed_padding[stage_idx],
457+
embed_norm_layer = embed_norm_layer,
455458
kernel_size = kernel_size,
456459
stride_q = stride_q,
457460
stride_kv = stride_kv,

0 commit comments

Comments
 (0)