Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 4615805

Browse files
miguelraz authored and workingjubilee committed
add remainder dot_product and cleanup
cleanup dot_product and README.md
1 parent c08a4d1 commit 4615805

File tree

2 files changed

+95
-19
lines changed

2 files changed

+95
-19
lines changed

crates/core_simd/examples/README.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,4 @@ Run the tests with the command
1010
cargo run --example dot_product
1111
```
1212

13-
and the benchmarks via the command
14-
15-
```
16-
cargo run --example --benchmark ???
17-
```
18-
19-
and measure the timings on your local system.
13+
and verify the code for `dot_product.rs` on your machine.
Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,113 @@
11
// Code taken from the `packed_simd` crate
22
// Run this code with `cargo test --example dot_product`
3+
//use std::iter::zip;
4+
35
#![feature(array_chunks)]
6+
#![feature(slice_as_chunks)]
7+
// Add these imports to use the stdsimd library
8+
#![feature(portable_simd)]
49
use core_simd::*;
510

6-
/// This is your barebones dot product implementation:
7-
/// Take 2 vectors, multiply them element wise and *then*
8-
/// add up the result. In the next example we will see if there
9-
/// is any difference to adding as we go along multiplying.
11+
// This is your barebones dot product implementation:
12+
// Take 2 vectors, multiply them element wise and *then*
13+
// go along the resulting array and add up the result.
14+
// In the next example we will see if there
15+
// is any difference to adding and multiplying in tandem.
1016
pub fn dot_prod_0(a: &[f32], b: &[f32]) -> f32 {
1117
assert_eq!(a.len(), b.len());
1218

13-
a.iter()
14-
.zip(b.iter())
15-
.map(|a, b| a * b)
16-
.sum()
19+
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1720
}
1821

22+
// When dealing with SIMD, it is very important to think about the amount
23+
// of data movement and when it happens. We're going over simple computation examples here, and yet
24+
// it is not trivial to understand what may or may not contribute to performance
25+
// changes. Eventually, you will need tools to inspect the generated assembly and confirm your
26+
// hypothesis and benchmarks - we will mention them later on.
27+
// With the use of `fold`, we're doing a multiplication,
28+
// and then adding it to the sum, one element from both vectors at a time.
1929
pub fn dot_prod_1(a: &[f32], b: &[f32]) -> f32 {
2030
assert_eq!(a.len(), b.len());
2131
a.iter()
22-
.zip(b.iter())
23-
.fold(0.0, |a, b| a * b)
32+
.zip(b.iter())
33+
.fold(0.0, |a, zipped| a + zipped.0 * zipped.1)
2434
}
2535

36+
// We now move on to the SIMD implementations: notice the following constructs:
37+
// `array_chunks::<4>`: mapping this over the vector will let use construct SIMD vectors
38+
// `f32x4::from_array`: construct the SIMD vector from a slice
39+
// `(a * b).reduce_sum()`: Multiply both f32x4 vectors together, and then reduce them.
40+
// This approach essentially uses SIMD to produce a vector of length N/4 of all the products,
41+
// and then add those with `sum()`. This is suboptimal.
42+
// TODO: ASCII diagrams
2643
pub fn dot_prod_simd_0(a: &[f32], b: &[f32]) -> f32 {
2744
assert_eq!(a.len(), b.len());
28-
2945
// TODO handle remainder when a.len() % 4 != 0
3046
a.array_chunks::<4>()
3147
.map(|&a| f32x4::from_array(a))
3248
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
33-
.map(|(a, b)| (a * b).horizontal_sum())
49+
.map(|(a, b)| (a * b).reduce_sum())
3450
.sum()
3551
}
3652

53+
// There's some simple ways to improve the previous code:
54+
// 1. Make a `zero` `f32x4` SIMD vector that we will be accumulating into
55+
// So that there is only one `sum()` reduction when the last `f32x4` has been processed
56+
// 2. Exploit Fused Multiply Add so that the multiplication, addition and sinking into the reduciton
57+
// happen in the same step.
58+
// If the arrays are large, minimizing the data shuffling will lead to great perf.
59+
// If the arrays are small, handling the remainder elements when the length isn't a multiple of 4
60+
// Can become a problem.
61+
pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
62+
assert_eq!(a.len(), b.len());
63+
// TODO handle remainder when a.len() % 4 != 0
64+
a.array_chunks::<4>()
65+
.map(|&a| f32x4::from_array(a))
66+
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
67+
.fold(f32x4::splat(0.0), |acc, zipped| acc + zipped.0 * zipped.1)
68+
.reduce_sum()
69+
}
70+
71+
// A lot of knowledgeable use of SIMD comes from knowing specific instructions that are
72+
// available - let's try to use the `mul_add` instruction, which is the fused-multiply-add we were looking for.
73+
use std_float::StdFloat;
74+
pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
75+
assert_eq!(a.len(), b.len());
76+
// TODO handle remainder when a.len() % 4 != 0
77+
let mut res = f32x4::splat(0.0);
78+
a.array_chunks::<4>()
79+
.map(|&a| f32x4::from_array(a))
80+
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
81+
.for_each(|(a, b)| {
82+
res = a.mul_add(b, res);
83+
});
84+
res.reduce_sum()
85+
}
86+
87+
// Finally, we will write the same operation but handling the loop remainder.
88+
const LANES: usize = 4;
89+
pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
90+
assert_eq!(a.len(), b.len());
91+
92+
let (a_extra, a_chunks) = a.as_rchunks();
93+
let (b_extra, b_chunks) = b.as_rchunks();
94+
95+
// These are always true, but for emphasis:
96+
assert_eq!(a_chunks.len(), b_chunks.len());
97+
assert_eq!(a_extra.len(), b_extra.len());
98+
99+
let mut sums = [0.0; LANES];
100+
for ((x, y), d) in std::iter::zip(a_extra, b_extra).zip(&mut sums) {
101+
*d = x * y;
102+
}
103+
104+
let mut sums = f32x4::from_array(sums);
105+
std::iter::zip(a_chunks, b_chunks).for_each(|(x, y)| {
106+
sums += f32x4::from_array(*x) * f32x4::from_array(*y);
107+
});
108+
109+
sums.reduce_sum()
110+
}
37111
fn main() {
38112
// Empty main to make cargo happy
39113
}
@@ -45,10 +119,18 @@ mod tests {
45119
use super::*;
46120
let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
47121
let b: Vec<f32> = vec![-8.0, -7.0, -6.0, -5.0, 4.0, 3.0, 2.0, 1.0];
122+
let x: Vec<f32> = [0.5; 1003].to_vec();
123+
let y: Vec<f32> = [2.0; 1003].to_vec();
48124

125+
// Basic check
49126
assert_eq!(0.0, dot_prod_0(&a, &b));
50127
assert_eq!(0.0, dot_prod_1(&a, &b));
51128
assert_eq!(0.0, dot_prod_simd_0(&a, &b));
52129
assert_eq!(0.0, dot_prod_simd_1(&a, &b));
130+
assert_eq!(0.0, dot_prod_simd_2(&a, &b));
131+
assert_eq!(0.0, dot_prod_simd_3(&a, &b));
132+
133+
// We can handle vectors that are non-multiples of 4
134+
assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
53135
}
54136
}

0 commit comments

Comments
 (0)