11use paste:: paste;
2+ use std:: arch:: asm;
3+ use std:: fmt:: { Debug , Formatter } ;
4+ use std:: ops;
25use std:: ops:: { Index , IndexMut } ;
3- use std:: { mem, ops} ;
46
57#[ derive( Copy , Clone ) ]
68pub struct Matrix ( [ i32 ; 16 ] ) ;
@@ -23,6 +25,16 @@ impl ops::Mul for Matrix {
2325 }
2426}
2527
28+ impl Debug for Matrix {
29+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
30+ let mut debug = f. debug_list ( ) ;
31+ for n in & self . 0 {
32+ debug. entry ( n) ;
33+ }
34+ debug. finish ( )
35+ }
36+ }
37+
2638impl AsRef < [ i32 ; 16 ] > for Matrix {
2739 fn as_ref ( & self ) -> & [ i32 ; 16 ] {
2840 & self . 0
@@ -69,7 +81,7 @@ macro_rules! define_vector {
6981
7082 impl <const SIZE : usize > Default for [ <Vector $t>] <SIZE > {
7183 fn default ( ) -> Self {
72- unsafe { mem :: zeroed ( ) }
84+ [ < Vector $t> ] ( [ $t :: default ( ) ; SIZE ] )
7385 }
7486 }
7587
@@ -109,7 +121,7 @@ macro_rules! define_vector {
109121 self
110122 }
111123 }
112-
124+
113125 impl From <[ <Vector $t>] <3 >> for [ <Vector $t>] <4 > {
114126 fn from( value: [ <Vector $t>] <3 >) -> Self {
115127 let mut ret = Self :: default ( ) ;
@@ -119,13 +131,24 @@ macro_rules! define_vector {
119131 ret
120132 }
121133 }
134+
135+ impl <const SIZE : usize > Debug for [ <Vector $t>] <SIZE > {
136+ fn fmt( & self , f: & mut Formatter <' _>) -> std:: fmt:: Result {
137+ let mut debug = f. debug_list( ) ;
138+ for n in & self . 0 {
139+ debug. entry( n) ;
140+ }
141+ debug. finish( )
142+ }
143+ }
122144 }
123145 } ;
124146}
125147
126148define_vector ! ( u16 ) ;
127149define_vector ! ( i16 ) ;
128150define_vector ! ( i32 ) ;
151+ define_vector ! ( f32 ) ;
129152
130153impl ops:: Mul < Matrix > for Vectori32 < 3 > {
131154 type Output = Self ;
@@ -144,10 +167,33 @@ impl ops::Mul<Matrix> for Vectori32<4> {
144167
145168 fn mul ( self , rhs : Matrix ) -> Self :: Output {
146169 let mut ret = Vectori32 :: default ( ) ;
147- ret[ 0 ] = ( ( self [ 0 ] as i64 * rhs[ 0 ] as i64 + self [ 1 ] as i64 * rhs[ 4 ] as i64 + self [ 2 ] as i64 * rhs[ 8 ] as i64 + self [ 3 ] as i64 * rhs[ 12 ] as i64 ) >> 12 ) as i32 ;
148- ret[ 1 ] = ( ( self [ 0 ] as i64 * rhs[ 1 ] as i64 + self [ 1 ] as i64 * rhs[ 5 ] as i64 + self [ 2 ] as i64 * rhs[ 9 ] as i64 + self [ 3 ] as i64 * rhs[ 13 ] as i64 ) >> 12 ) as i32 ;
149- ret[ 2 ] = ( ( self [ 0 ] as i64 * rhs[ 2 ] as i64 + self [ 1 ] as i64 * rhs[ 6 ] as i64 + self [ 2 ] as i64 * rhs[ 10 ] as i64 + self [ 3 ] as i64 * rhs[ 14 ] as i64 ) >> 12 ) as i32 ;
150- ret[ 3 ] = ( ( self [ 0 ] as i64 * rhs[ 3 ] as i64 + self [ 1 ] as i64 * rhs[ 7 ] as i64 + self [ 2 ] as i64 * rhs[ 11 ] as i64 + self [ 3 ] as i64 * rhs[ 15 ] as i64 ) >> 12 ) as i32 ;
170+ unsafe {
171+ asm ! (
172+ "vld1.s32 {{q0}}, [{v}]" ,
173+ "vld1.s32 {{q1}}, [{m}]!" ,
174+ "vld1.s32 {{q2}}, [{m}]!" ,
175+ "vld1.s32 {{q3}}, [{m}]!" ,
176+ "vld1.s32 {{q4}}, [{m}]" ,
177+ "vmull.s32 q5, d2, d0[0]" ,
178+ "vmlal.s32 q5, d4, d0[1]" ,
179+ "vmlal.s32 q5, d6, d1[0]" ,
180+ "vmlal.s32 q5, d8, d1[1]" ,
181+ "vmull.s32 q6, d3, d0[0]" ,
182+ "vmlal.s32 q6, d5, d0[1]" ,
183+ "vmlal.s32 q6, d7, d1[0]" ,
184+ "vmlal.s32 q6, d9, d1[1]" ,
185+ "vshr.s64 q5, q5, 12" ,
186+ "vshr.s64 q6, q6, 12" ,
187+ "vstr.32 s20, [{ret}]" ,
188+ "vstr.32 s22, [{ret}, 4]" ,
189+ "vstr.32 s24, [{ret}, 8]" ,
190+ "vstr.32 s26, [{ret}, 12]" ,
191+ v = in( reg) self . 0 . as_ptr( ) ,
192+ m = in( reg) rhs. 0 . as_ptr( ) ,
193+ ret = in( reg) ret. 0 . as_mut_ptr( ) ,
194+ options( preserves_flags, nostack) ,
195+ ) ;
196+ }
151197 ret
152198 }
153199}
@@ -164,14 +210,40 @@ impl ops::MulAssign<Matrix> for Vectori32<4> {
164210 }
165211}
166212
167- impl < const SIZE : usize > ops:: Mul for Vectori32 < SIZE > {
213+ impl ops:: Mul for Vectori32 < 3 > {
168214 type Output = i32 ;
169215
170216 fn mul ( self , rhs : Self ) -> Self :: Output {
217+ /* Vectorization of
171218 let mut dot = 0;
172- for i in 0 .. SIZE {
173- dot += self [ i ] as i64 * rhs[ i ] as i64 ;
174- }
219+ dot += self[0] as i64 * rhs[0] as i64;
220+ dot += self[1 ] as i64 * rhs[1 ] as i64;
221+ dot += self[2] as i64 * rhs[2] as i64;
175222 (dot >> 12) as i32
223+ */
224+
225+ let v1 = self . 0 . as_ptr ( ) ;
226+ let v2 = rhs. 0 . as_ptr ( ) ;
227+ let mut dot: i32 ;
228+ unsafe {
229+ asm ! (
230+ "vmov.s32 d1, 0" ,
231+ "vmov.s32 d3, 0" ,
232+ "vld1.s32 {{d0}}, [{v1}]!" ,
233+ "vld1.s32 {{d1[0]}}, [{v1}]" ,
234+ "vld1.s32 {{d2}}, [{v2}]!" ,
235+ "vld1.s32 {{d3[0]}}, [{v2}]" ,
236+ "vmull.s32 q2, d0, d2" ,
237+ "vmlal.s32 q2, d1, d3" ,
238+ "vadd.s64 d4, d4, d5" ,
239+ "vshr.s64 d4, d4, 12" ,
240+ "vmov.s32 {dot}, d4[0]" ,
241+ v1 = in( reg) v1,
242+ v2 = in( reg) v2,
243+ dot = out( reg) dot,
244+ options( pure, readonly, preserves_flags, nostack) ,
245+ ) ;
246+ }
247+ dot
176248 }
177249}
0 commit comments