From 7ed8b2895d191fa4069ec8476ba9cdcd1551ebed Mon Sep 17 00:00:00 2001 From: Abduragim Date: Tue, 2 Jul 2024 18:56:47 +0300 Subject: [PATCH 1/2] Test data for Attention Layer import fix --- .../input_torch_attention_single_head.npy | Bin 0 -> 368 bytes testdata/dnn/onnx/data/input_unflatten.npy | Bin 0 -> 1928 bytes .../output_torch_attention_single_head.npy | Bin 0 -> 368 bytes testdata/dnn/onnx/data/output_unflatten.npy | Bin 0 -> 1928 bytes testdata/dnn/onnx/generate_onnx_models.py | 41 ++++++++++++++++++ 5 files changed, 41 insertions(+) create mode 100644 testdata/dnn/onnx/data/input_torch_attention_single_head.npy create mode 100644 testdata/dnn/onnx/data/input_unflatten.npy create mode 100644 testdata/dnn/onnx/data/output_torch_attention_single_head.npy create mode 100644 testdata/dnn/onnx/data/output_unflatten.npy diff --git a/testdata/dnn/onnx/data/input_torch_attention_single_head.npy b/testdata/dnn/onnx/data/input_torch_attention_single_head.npy new file mode 100644 index 0000000000000000000000000000000000000000..e59a39adba61384d390783c0aff522fbd20aa983 GIT binary patch literal 368 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3TB!*3bhL40j{t2qwQ-7D(yC$%&{{lzGJJWc-ZdPvqyH9bvW(! z9oM(7n{v+XxP^!PseeiKYgX~wTNS^xb9lVRcHa9Rc1I)S>>I*8Y_?rrZRe0*VjJ+| ztL>e#8hh2dHTFAFj@vz2nPES5pQL?CeVl#Pt7mo$?2GIc%N>-ruG6W3hiG?#@h3GB-`H!m|}Obrp`Wt<)CfDf*Uq3YK-j` QG~Kh=)7oKoK`+`40O$*b00000 literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/input_unflatten.npy b/testdata/dnn/onnx/data/input_unflatten.npy new file mode 100644 index 0000000000000000000000000000000000000000..f01242c263f4f7f105d89966eb8b9fffa64e4851 GIT binary patch literal 1928 zcmbV{>t79c0*4PpY1AY|y1Nz*qFkoV@B2-3N=7B*)O1mulUu`4XQvWr+%-OKwIsK; zCSxdC5)pGUBBp6T zFN<=vm)hA!t&h>fgs%%*AEQ+Lugk-uV^yDe?5c3J>XSQ6ca+*WN$sZ3oG7*1Z1?|Q zfJ_&*T3^BdYZFR`$uaJ*IowBt^ZUcec&MMmoyk@_Q+rojwXk7tW;VQ^eNKnMkUn15{jS!BL$poHO~FDE--)l1*z|k{lMWzGyo_ zl|#6Frjn)ilc<6yI3$3_dLx7R_j*@#l-I>daKM{QIb90NeN_W?YgxQX@G z!YQ}!N3iF37@zuznH9QX#SdX3I4OzEvxd{g;|g-eEJA0{K<4aP&VPmsVABOd*cIM| zp65vV9eIOl%a1s4aIL7x9>uI-`DlvIf}4}KNKr=cHdf)s78po9==*)vMuNab)V0>_&#CNs{9O~VL^fm9rwJa0L2dBg4iyyG> z@eJ;tWy|Vu%V<@X4~0vSXk1wYlldMzuCzsY&LPBW&WrfdCqzZfH6))Ym5sDDWLZHQ z>^2;eEtf|y*=MZK_u4Nqx7Nv|sS9~4q6AY?hf=9hv-| zbi-$gDRoErps{K=OXDIaudT(3XSf5IlgOp=Hz<2!!?+#mxbgiHsQ>&{Z11y#;&&hJ z+;bj&Vm5NMEeQWIiYF&!z+i6&tacYcS!IY(+NHGnrW=YhC%EMMv-0IVRQ~l55_Pk1 z86(esvyaJp=i2?|53@O=7z3JM?_{1m?XbMSfN^PM17`%gIi} zAAByp-m{qR^(|Sv-Uc>0bx8it25P0#*mHRuGh7m>d}RrJsVVYg@6GO*Qp;!~W9}lJ0%u}hjuV*wTwp-Bm^$^OW zMl8&p%hcXj)_>;3r52-MY-7dw-<3g`7Ra)QT-?!MXVCS|1IL>=3wguR;$7&4*W?Vqepng%4H=ZVa0W8%dvbRx(QnLy43fd_& zd+n(9wTI2${n*eq8-*X*kREA_tIZO$w!5PFzBL_Ycp&B7Ov*CnQWJLrr>>7=v)qT$ zOPi<@ww{v5pWfQ6L|4@f%T09Z-N?(g^(k@dah*{Z#@bm^k$a>A zzLleS@reg}`eWhN*9+(B@8KJ`np%U|bgtB6L(o%ve3Oh96V&&-dGwmMOPoAx zNcDkpc=TcrANcQv%3~B~?U_nR%2x6ImL3)1HR8x%4ReDddGVMrn~Vg;Y*|Rjimorx zb`)aHaBB=JUrdF2CUi2KS@0y1PS+FZcF{iP`D0#p-b31J%Dc8GGzB^%o#s>02&XShW~yIvwjWZ&F-RNjGtEpXsTnW)8ogNg*CWJDkx9IZf?TzRyF2a8SzXcD3 literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_torch_attention_single_head.npy b/testdata/dnn/onnx/data/output_torch_attention_single_head.npy new file mode 100644 index 0000000000000000000000000000000000000000..d36fda7c3e22a4665a9e69028ecdb9d60bcc6a6f GIT binary patch literal 368 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3TB!*3bhL40j?#>zuG=1cHf)4PkrC2>vwH+O`G;M3x3$Ua`_+I z2q)RSR^saW=FNLyJ9~ch-V`;4ePx|rY#Zu=_WC?m-^b&B+ji^6=Dqe+ulH_p`fmHj z!f?Q+eAz2Tuh_Le()+9kBz z*yF$~x9@;8gPpw!`@V{gd-q;l<7KDt_2wQHRoQ*#O&IM?#&Yb-+p}|T+7d^*2?sCj zVLdIi@4|w=ww;|!`*JgO?foL^U{_&#Y0q{!$$h1)|7;BynD!ZqZQHwUk(*tF&5b?# Uj?3*UU-QpaO^J2iLCHgV0X^oE`v3p{ literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_unflatten.npy b/testdata/dnn/onnx/data/output_unflatten.npy new file mode 100644 index 0000000000000000000000000000000000000000..f844fd211f9d68bd1ef6826bada7def9d4607c78 GIT binary patch literal 1928 zcmb7<`#%+U0)|f!GBFY*_pVf@Qgj*4_kE=t(u7nxj7vC2H&Z9hhLlYAOrLISNjJ-A znxbq;DK#`HHO926ZCsKXiAadakapVr5BB{$zdX-xZ;JQQg}wm;BpW0Nc1l%DM6{i= zt(`2=$xdo(7r8E46CJiTbX~Mk_1`WJi;7Ww>M<+B)T&QzKg~gE>-axrOpw}cw*CL7 zj|^wFT3x~bD`QHB$}#q^8Qg}4@w>xGc&L}iok^BFQ+rojHMeGOMi#uDeNOv=WhiJ# zM$rZ}9w|-u*l-~4WXCcs+=YX>LaC$s02LQna8zpxXHU8&N`G>qWYcQrMEm)yFWin0 z174tNawCs9`1C6_o@-mWPx1`Lk z>q1;gJ?!;5#euKyi{m@`aj;piW1~6K+Uua$Y)Gk79il(XWpQLI9Q;Ao_x?Peu!;58 z!YH@vN07%j7@hix85P=M`S+nBC^3=EGl$XI{R(nMFGOeHKxXe+#(xG6VABNy*cRM{ zuEz-a9(jXmi;p;PaE++R8p+I|d1#8ugsY>MNLGgPpBphvFMw-)NZ=5|l_;C@ zjZpmHgo)L`^t$;oG!83}-!iW@i|?`T z@pSH=X~XKVOKDk`2ZeK?Xk1YU<9Y5puCzgU_94V+&WpIyCqzZfH6)!Wm5s15U|D_} zY}X%?Et7{c$$N~@^V}~ow${m{DGPWjyaba|hES!-Gc+@;%UI66Uim=Hz<2!&Db4lx$*rIsQ>szZ11yx;x}*Z z+;bkjVit0WTM+hTBu`FEhyLCUSne)_vdRD>ishd9oi(hr>X7uG^(>Z7WzXfcOm|M8@|6Yjq}oh+v`##K@HZyqg^Q}m z^EtX?JX7ZnrYhHyQQ6&iF{lVts|K=sbu=ADhG2Enems;~F;}JHzMfH>&~8qj*MljO z8nPg34pVw#SpS(Pmza-&k+mh~eOm@)Y5>c^b8tsj#?F#3>c^jh(PAlzJ9>oKys7AY zIE&Ux=P@L!1v|Ip!(r}QvCZF(J(jC6F#Q602K9@Y+;KGS^Jl3hfxV4dl$wsGXW&ku z*=tL+j~%T4?#qV0St$6>hO`JHTy2)1wcQ2H_pN9@-5tsAW>A(fhnmk((_DJa-6c*Q zHlX^zIXrqXh!6aBL*+h_GxtoPBzdcNe@mB&uo`hhN5h<;2wps9#3n<5(OVW!vb^hy z)ExzwJ!^~eod9!;AvjYY)K{=f| zS{B^}NZUBkl#^Z4j&G%xS)qqfdv+;)77)+jp`2bl=B#Pe fI`U5cGu-j*grc=N4^Vy}c2=MTPhaBeD;Y literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 07dcc0a34..a9ce31706 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -1540,6 +1540,47 @@ def forward(self, x): save_data_and_model("einsum_transpose", mat, einsum, export_params=True) +class TorchAttentionLayer(nn.Module): + def __init__(self, embed_dim=6, num_heads=1): + super(TorchAttentionLayer, self).__init__() + self.attention = nn.MultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + batch_first=True) + def forward(self, x): + return self.attention(x, x, x)[0] + +num_heads = 1 +batch_size = 2 +num_tokens = 5 +emb_dim = 6 +model = TorchAttentionLayer(embed_dim=emb_dim, num_heads=num_heads).eval() + +x = torch.rand(batch_size, num_tokens, emb_dim) +with torch.no_grad(): + output = model(x) + +save_data_and_model("torch_attention_single_head", x, model, export_params=True) +class Unflatten(torch.nn.Module): + def __init__(self, E, times): + super(Unflatten, self).__init__() + self.E = E + self.times = times + + def forward(self, x): + return x.unflatten(-1, (self.times, self.E)) + +unflatten_dim = 5 +times = 3 +model = Unflatten(unflatten_dim, times).eval() + +x = torch.rand(10, 3, unflatten_dim * times) +with torch.no_grad(): + output = model(x) + +save_data_and_model("unflatten", x, model, export_params=True) + def _extract_value_info(x, name, type_proto=None): # type: (Union[List[Any], np.ndarray, None], Text, Optional[TypeProto]) -> onnx.ValueInfoProto if type_proto is None: if x is None: From 9a9d238a4ea0b3fd1afdc14840a96233019f3034 Mon Sep 17 00:00:00 2001 From: Abduragim Date: Wed, 3 Jul 2024 11:28:51 +0300 Subject: [PATCH 2/2] add missing onnx models --- .../onnx/models/torch_attention_single_head.onnx | Bin 0 -> 6954 bytes testdata/dnn/onnx/models/unflatten.onnx | Bin 0 -> 1390 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 testdata/dnn/onnx/models/torch_attention_single_head.onnx create mode 100644 testdata/dnn/onnx/models/unflatten.onnx diff --git a/testdata/dnn/onnx/models/torch_attention_single_head.onnx b/testdata/dnn/onnx/models/torch_attention_single_head.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5be90217ede9d4fb7e8ad8b400af57fa1636f82f GIT binary patch literal 6954 zcmbtZ3v5$W817xyt{mqmmjPYXjWk9~S;jiH@o@K8kpx^Ms7wLT%sW=-bZcwdK|z%H zfIyruL6QZHP6CQL6%?aUdPIDo;v3ZX%1~6|0~aMal?3lyd++&AIp!_&;RMkwsw;s)x&z+@I3_Nw&;SpHzqz$>p3!n*&t!7_MBXJgeKLD{n=k zqLD~y`t+u_8cD>WiBNO3d>+t8zc)Q8<%xYG%$5F;s4y0acLb%NEm+Ve6^UG^$>Tze z(1ikkPZ(L-i`*Fx=cvg!o$bwm>KfS#Jji}+dN=Acna`P7gIds7Gs0DEk>*%DdSi1- zo0>ot3gvM?;U%rD)aMdb;InH|6yisFGJT92?a2TtpFyRut^$orZhICiCjjkw{x-FF zwy6c%c#UmJT$0XN7!!f|x;*OZEYurb=&X$(Zc0$=`Et5F6O=hoK~OY-&#wJKbC4*A zav2EXys#PzQ5W60BA;V;1rnfBPaoyRxZVOCn9~CFQ}U>vVxb<^qXYv61doq8-B&&b zAaYeG!NMh;SD|hRiIL{-OW{OT`?3*iZ(A&ipf1uPsK1FKI5iK%Q~x8xH=QBG zbYI3Bz1n~_0r@yA%2XTyU8d#nW!iuAWd##B)+B!5=*~C_oKhe7-u8xVbYF zS_$eC^w|{qMwlN_im|}_-UOUU5OaWoJcAH`1#sLOogMe=+Q3(!jeLdr$R|RVdhdm* zwP0uFe8h|dgAdIY05(iWumI-zX-7K|>~t0vzz&ZRX9D|12|;2T zd!UIYO>ZU`WN$DRKV0S%-80%{Kz0xIT0P-g`JpLt5A*8%ye5D4_1&ps?0UtuGQR{8GlbOY!l>z|heoW;yE?20N@V?8Cnl~vC zCWj5N`7qmMH~sHK@f&*F?)umH@qx%WsY+YDh7+$0IW zqYf|&z)ogBxz&i~5w{xETUDby{K1>5L}m0IPixr27kO37Lv2gLNi>cl((m|W4>&5$ z3Uzdt_IpR>NPLC5T)PH!sO_<^I>u(}y0}l0{3WQcC6uIppN*VJ{n~YgXPBg-K0;6T zf}p``g)`|RcS}gwoeNNwODWpt;X_@@lpEXsjKj|dq5J9f=q9oo}w+Z%V;U8@w*{oCs#cBo02 zT{({&*uD@STfaU%{_7`|c>GO#$2BjK1Cf(B<=vv}d~2IhIy4oZKWT^zyKFQ?OC z`ZKwrauTWCzMR~(={@54>rds48%~h^BWsEO(Z@+RdRX~+;4^&Tk!kq8#xlHZ-XElT z@;B+ez6#R4c^p~w={loUH>X*JQo%a<1-PWNc0u@D27ZAV*0^#Xq@z!aJ+hTlc zM%fGYL7CE2@v&QAcuux7sxbS zeCUkj@Q80R^|=okuq*Y}F4Z-;+%Ycr8Z4MVrrW;QxZXh+MGmCqs0DJXg>d`x);BM+r`2=@@xYYD zg{)+}grWu|zR6GOJ?`XsAlW)1Yx+YWOH?)#48M~#RM4-Sv{Xr$pSXk7;4(jI@Bz}Y z(QMM^7BlX7@=1u<)82TO^`QzUtnXpvC#seXb0Rm^A~$&h5_Z4ue0+6A&J2OVo822U pUq<$@+t}e#gn~g7x&bNy_C(o}dfo$uNVAY(9s_x}S36UU^gmXYkl6qL literal 0 HcmV?d00001