@@ -20,7 +20,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
+from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._registry import register_model
@@ -564,6 +564,7 @@ def __init__(
 
         self.patch_size = to_2tuple(patch_size)
         self.patch_area = self.patch_size[0] * self.patch_size[1]
+        self.coreml_exportable = is_exportable()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
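Because `coreml_exportable` is captured once in `__init__` rather than read in `forward`, export mode has to be enabled before the model is built. A minimal sketch of how the flag would typically be toggled, assuming timm's `set_exportable` config helper (the counterpart of the `is_exportable` imported above) and an illustrative model name:

```python
import timm
from timm.layers import set_exportable, is_exportable

# Enable timm's global export mode *before* building the model, since the
# block reads is_exportable() in __init__, not in forward().
set_exportable(True)
assert is_exportable()

# Model name is illustrative; any MobileViT variant containing this block
# would pick up the F.unfold / F.pixel_shuffle paths below.
model = timm.create_model('mobilevit_s', pretrained=False).eval()
```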
@@ -580,16 +581,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
         C = x.shape[1]
-        x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
+        if self.coreml_exportable:
+            x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
+        else:
+            x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
         x = x.reshape(B, C, -1, num_patches)
 
         # Global representations
         x = self.transformer(x)
         x = self.norm(x)
 
         # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
-        x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
-        x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
+        if self.coreml_exportable:
+            # adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
+            x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
+            x = F.pixel_shuffle(x, upscale_factor=patch_h)
+        else:
+            x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
+            x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
+
         x = self.conv_proj(x)
         return x
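As a sanity check (not part of the patch), the two unfold/fold paths can be compared element-for-element. The sketch below assumes square patches, since `F.pixel_shuffle` takes a single `upscale_factor`, and shapes divisible by the patch size, which the block guarantees before unfolding:

```python
import torch
import torch.nn.functional as F

# Toy shapes with H, W divisible by the patch size; square patches assumed.
B, C, H, W = 2, 8, 16, 16
patch_h = patch_w = 2
num_patch_h, num_patch_w = H // patch_h, W // patch_w
num_patches = num_patch_h * num_patch_w
x = torch.randn(B, C, H, W)

# Unfold: reshape/permute path vs. the CoreML-friendly F.unfold path.
a = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
a = a.reshape(B, C, -1, num_patches)
b = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
b = b.reshape(B, C, -1, num_patches)
assert torch.equal(a, b)

# Fold: reshape/permute path vs. the F.pixel_shuffle path; both should
# also invert the unfold exactly, recovering x.
f1 = a.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
f1 = f1.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
f2 = a.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
f2 = F.pixel_shuffle(f2, upscale_factor=patch_h)
assert torch.equal(f1, f2) and torch.equal(f1, x)
```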