Skip to content

Commit 992bf7c

Browse files
committed
chore: Modify the MobileVitV2Block to be coreml exportable
Based on is_exportable(), set a variable that controls the behaviour of the block. CoreMLTools supports im2col from version 6.2 onwards; unfortunately, col2im is still not supported. Tested by exporting to ONNX, TorchScript, CoreML, and TVM.
1 parent 4b8cfa6 commit 992bf7c

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

timm/models/mobilevit.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.nn.functional as F
2121
from torch import nn
2222

23-
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
23+
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
2424
from ._builder import build_model_with_cfg
2525
from ._features_fx import register_notrace_module
2626
from ._registry import register_model
@@ -564,6 +564,7 @@ def __init__(
564564

565565
self.patch_size = to_2tuple(patch_size)
566566
self.patch_area = self.patch_size[0] * self.patch_size[1]
567+
self.coreml_exportable = is_exportable()
567568

568569
def forward(self, x: torch.Tensor) -> torch.Tensor:
569570
B, C, H, W = x.shape
@@ -580,16 +581,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
580581

581582
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
582583
C = x.shape[1]
583-
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
584+
if self.coreml_exportable:
585+
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
586+
else:
587+
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
584588
x = x.reshape(B, C, -1, num_patches)
585589

586590
# Global representations
587591
x = self.transformer(x)
588592
x = self.norm(x)
589593

590594
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
591-
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
592-
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
595+
if self.coreml_exportable:
596+
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
597+
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
598+
x = F.pixel_shuffle(x, upscale_factor=patch_h)
599+
else:
600+
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
601+
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
602+
593603

594604
x = self.conv_proj(x)
595605
return x

0 commit comments

Comments
 (0)