// Code taken from the `packed_simd` crate
// Run this code with `cargo test --example dot_product`
+ //use std::iter::zip;
+
#![feature(array_chunks)]
+ #![feature(slice_as_chunks)]
+ // Add these imports to use the stdsimd library
+ #![feature(portable_simd)]
use core_simd::*;

- /// This is your barebones dot product implementation:
- /// Take 2 vectors, multiply them element wise and *then*
- /// add up the result. In the next example we will see if there
- /// is any difference to adding as we go along multiplying.
+ // This is your barebones dot product implementation:
+ // Take 2 vectors, multiply them element wise and *then*
+ // go along the resulting array and add up the result.
+ // In the next example we will see if there
+ // is any difference to adding and multiplying in tandem.
pub fn dot_prod_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

-     a.iter()
-         .zip(b.iter())
-         .map(|a, b| a * b)
-         .sum()
+     a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}

+ // When dealing with SIMD, it is very important to think about the amount
+ // of data movement and when it happens. We're going over simple computation examples here, and yet
+ // it is not trivial to understand what may or may not contribute to performance
+ // changes. Eventually, you will need tools to inspect the generated assembly to confirm your
+ // hypotheses, as well as benchmarks - we will mention them later on.
+ // With the use of `fold`, we're doing a multiplication,
+ // and then adding it to the sum, one element from both vectors at a time.
pub fn dot_prod_1(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
-         .zip(b.iter())
-         .fold(0.0, |a, b| a * b)
+         .zip(b.iter())
+         .fold(0.0, |a, zipped| a + zipped.0 * zipped.1)
}

+ // We now move on to the SIMD implementations: notice the following constructs:
+ // `array_chunks::<4>`: mapping this over the vector will let us construct SIMD vectors
+ // `f32x4::from_array`: construct the SIMD vector from a slice
+ // `(a * b).reduce_sum()`: multiply both f32x4 vectors together, and then reduce them.
+ // This approach essentially uses SIMD to produce a vector of length N/4 of all the products,
+ // and then adds those with `sum()`. This is suboptimal.
+ // TODO: ASCII diagrams
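+ // A rough sketch of the data flow in the meantime, assuming 8-element inputs:
+ //
+ //   a: [a0 a1 a2 a3 | a4 a5 a6 a7]          b: [b0 b1 b2 b3 | b4 b5 b6 b7]
+ //   multiply chunk-wise (f32x4 * f32x4) ->  [a0*b0 a1*b1 a2*b2 a3*b3]  [a4*b4 a5*b5 a6*b6 a7*b7]
+ //   reduce_sum() per chunk              ->   s0                         s1
+ //   sum() over the chunk results        ->   s0 + s1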
pub fn dot_prod_simd_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
-
    // TODO handle remainder when a.len() % 4 != 0
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
-         .map(|(a, b)| (a * b).horizontal_sum())
+         .map(|(a, b)| (a * b).reduce_sum())
        .sum()
}

+ // There are some simple ways to improve the previous code:
+ // 1. Make a `zero` `f32x4` SIMD vector that we will be accumulating into,
+ //    so that there is only one `reduce_sum()` reduction once the last `f32x4` has been processed.
+ // 2. Exploit Fused Multiply-Add, so that the multiplication, addition and sinking into the reduction
+ //    happen in the same step.
+ // If the arrays are large, minimizing the data shuffling will lead to great perf.
+ // If the arrays are small, handling the remainder elements when the length isn't a multiple of 4
+ // can become a problem.
+ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+     // TODO handle remainder when a.len() % 4 != 0
+     a.array_chunks::<4>()
+         .map(|&a| f32x4::from_array(a))
+         .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
+         .fold(f32x4::splat(0.0), |acc, zipped| acc + zipped.0 * zipped.1)
+         .reduce_sum()
+ }
+
+ // A lot of skill in using SIMD comes from knowing which specific instructions are
+ // available - let's try to use `mul_add`, which gives us the fused multiply-add we were looking for.
+ use std_float::StdFloat;
+ pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+     // TODO handle remainder when a.len() % 4 != 0
+     let mut res = f32x4::splat(0.0);
+     a.array_chunks::<4>()
+         .map(|&a| f32x4::from_array(a))
+         .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
+         .for_each(|(a, b)| {
+             res = a.mul_add(b, res);
+         });
+     res.reduce_sum()
+ }
+
+ // Finally, we will write the same operation but handling the loop remainder.
+ const LANES: usize = 4;
+ pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+
+     let (a_extra, a_chunks) = a.as_rchunks();
+     let (b_extra, b_chunks) = b.as_rchunks();
+
+     // These are always true, but for emphasis:
+     assert_eq!(a_chunks.len(), b_chunks.len());
+     assert_eq!(a_extra.len(), b_extra.len());
+
+     let mut sums = [0.0; LANES];
+     for ((x, y), d) in std::iter::zip(a_extra, b_extra).zip(&mut sums) {
+         *d = x * y;
+     }
+
+     let mut sums = f32x4::from_array(sums);
+     std::iter::zip(a_chunks, b_chunks).for_each(|(x, y)| {
+         sums += f32x4::from_array(*x) * f32x4::from_array(*y);
+     });
+
+     sums.reduce_sum()
+ }
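+ // As a concrete example of the remainder handling above, take the length-1003 vectors used in
+ // the test below: `as_rchunks()` splits each slice into a 3-element remainder (its first 3
+ // elements) plus 250 full chunks. The remainder products seed three lanes of `sums`, and the
+ // chunk products are then accumulated on top before the final `reduce_sum()`.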
fn main() {
    // Empty main to make cargo happy
}
@@ -45,10 +119,18 @@ mod tests {
        use super::*;
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b: Vec<f32> = vec![-8.0, -7.0, -6.0, -5.0, 4.0, 3.0, 2.0, 1.0];
+         let x: Vec<f32> = [0.5; 1003].to_vec();
+         let y: Vec<f32> = [2.0; 1003].to_vec();

+         // Basic check
        assert_eq!(0.0, dot_prod_0(&a, &b));
        assert_eq!(0.0, dot_prod_1(&a, &b));
        assert_eq!(0.0, dot_prod_simd_0(&a, &b));
        assert_eq!(0.0, dot_prod_simd_1(&a, &b));
+         assert_eq!(0.0, dot_prod_simd_2(&a, &b));
+         assert_eq!(0.0, dot_prod_simd_3(&a, &b));
+
+         // We can handle vectors whose length is not a multiple of 4
+         assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
    }
}
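+
+ // A rough, self-contained way to get a feel for the relative cost of these implementations.
+ // A proper benchmarking harness is the right tool for real measurements; this illustrative
+ // helper (`rough_timing_sketch` is just a placeholder name) does a single timed run, which is noisy.
+ #[allow(dead_code)]
+ fn rough_timing_sketch() {
+     use std::hint::black_box;
+     use std::time::Instant;
+
+     // Length is a multiple of 4, so the implementations that skip the remainder stay exact.
+     let a = vec![0.5_f32; 1 << 20];
+     let b = vec![2.0_f32; 1 << 20];
+
+     for (name, f) in [
+         ("dot_prod_0", dot_prod_0 as fn(&[f32], &[f32]) -> f32),
+         ("dot_prod_simd_1", dot_prod_simd_1),
+         ("dot_prod_simd_3", dot_prod_simd_3),
+     ] {
+         let start = Instant::now();
+         let result = black_box(f(black_box(&a), black_box(&b)));
+         println!("{name}: {result} in {:?}", start.elapsed());
+     }
+ }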