Skip to content

Commit c50912e

Browse files
authored
C code tests & avx512f f16 implement (#183)
* test: add tests for c code Signed-off-by: usamoi <usamoi@outlook.com> * fix: relax EPSILON for tests Signed-off-by: usamoi <usamoi@outlook.com> --------- Signed-off-by: usamoi <usamoi@outlook.com>
1 parent 2869fbd commit c50912e

File tree

13 files changed

+276
-31
lines changed

13 files changed

+276
-31
lines changed

.github/workflows/check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ jobs:
101101
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
102102
- name: Test
103103
run: |
104-
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
104+
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture
105105
- name: Install release
106106
run: ./scripts/ci_install.sh
107107
- name: Sqllogictest

Cargo.lock

Lines changed: 23 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/c/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ name = "c"
33
version.workspace = true
44
edition.workspace = true
55

6-
[dependencies]
7-
half = { version = "~2.3", features = ["use-intrinsics"] }
6+
[dev-dependencies]
7+
half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] }
8+
detect = { path = "../detect" }
9+
rand = "0.8.5"
810

911
[build-dependencies]
1012
cc = "1.0"

crates/c/src/c.c

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
2929
xx = _mm512_fmadd_ph(x, x, xx);
3030
yy = _mm512_fmadd_ph(y, y, yy);
3131
}
32-
return (float)(_mm512_reduce_add_ph(xy) /
33-
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
32+
{
33+
float rxy = _mm512_reduce_add_ph(xy);
34+
float rxx = _mm512_reduce_add_ph(xx);
35+
float ryy = _mm512_reduce_add_ph(yy);
36+
return rxy / sqrt(rxx * ryy);
37+
}
3438
}
3539

3640
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
@@ -74,6 +78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
7478
return (float)_mm512_reduce_add_ph(dd);
7579
}
7680

81+
__attribute__((target("arch=x86-64-v4"))) extern float
82+
v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) {
83+
__m512 xy = _mm512_set1_ps(0);
84+
__m512 xx = _mm512_set1_ps(0);
85+
__m512 yy = _mm512_set1_ps(0);
86+
87+
while (n >= 16) {
88+
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
89+
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
90+
a += 16, b += 16, n -= 16;
91+
xy = _mm512_fmadd_ps(x, y, xy);
92+
xx = _mm512_fmadd_ps(x, x, xx);
93+
yy = _mm512_fmadd_ps(y, y, yy);
94+
}
95+
if (n > 0) {
96+
__mmask16 mask = _bzhi_u32(0xFFFF, n);
97+
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
98+
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
99+
xy = _mm512_fmadd_ps(x, y, xy);
100+
xx = _mm512_fmadd_ps(x, x, xx);
101+
yy = _mm512_fmadd_ps(y, y, yy);
102+
}
103+
{
104+
float rxy = _mm512_reduce_add_ps(xy);
105+
float rxx = _mm512_reduce_add_ps(xx);
106+
float ryy = _mm512_reduce_add_ps(yy);
107+
return rxy / sqrt(rxx * ryy);
108+
}
109+
}
110+
111+
__attribute__((target("arch=x86-64-v4"))) extern float
112+
v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) {
113+
__m512 xy = _mm512_set1_ps(0);
114+
115+
while (n >= 16) {
116+
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
117+
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
118+
a += 16, b += 16, n -= 16;
119+
xy = _mm512_fmadd_ps(x, y, xy);
120+
}
121+
if (n > 0) {
122+
__mmask16 mask = _bzhi_u32(0xFFFF, n);
123+
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
124+
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
125+
xy = _mm512_fmadd_ps(x, y, xy);
126+
}
127+
return _mm512_reduce_add_ps(xy);
128+
}
129+
130+
__attribute__((target("arch=x86-64-v4"))) extern float
131+
v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) {
132+
__m512 dd = _mm512_set1_ps(0);
133+
134+
while (n >= 16) {
135+
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
136+
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
137+
a += 16, b += 16, n -= 16;
138+
__m512 d = _mm512_sub_ps(x, y);
139+
dd = _mm512_fmadd_ps(d, d, dd);
140+
}
141+
if (n > 0) {
142+
__mmask16 mask = _bzhi_u32(0xFFFF, n);
143+
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
144+
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
145+
__m512 d = _mm512_sub_ps(x, y);
146+
dd = _mm512_fmadd_ps(d, d, dd);
147+
}
148+
return _mm512_reduce_add_ps(dd);
149+
}
150+
77151
__attribute__((target("arch=x86-64-v3"))) extern float
78152
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
79153
float xy = 0;

crates/c/src/c.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
77
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
88
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
9+
extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
10+
extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n);
11+
extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n);
912
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
1013
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
1114
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);

crates/c/src/c.rs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,10 @@ extern "C" {
44
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
55
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
66
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
7+
pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32;
8+
pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32;
9+
pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32;
710
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
811
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
912
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
1013
}
11-
12-
// `compiler_builtin` defines `__extendhfsf2` with integer calling convention.
13-
// However C compilers links `__extendhfsf2` with floating calling convention.
14-
// The code should be removed once Rust offically supports `f16`.
15-
16-
#[cfg(target_arch = "x86_64")]
17-
#[no_mangle]
18-
#[linkage = "external"]
19-
extern "C" fn __extendhfsf2(f: f64) -> f32 {
20-
unsafe {
21-
let f: half::f16 = std::mem::transmute_copy(&f);
22-
f.to_f32()
23-
}
24-
}

crates/c/tests/x86_64.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#![cfg(target_arch = "x86_64")]
2+
3+
#[test]
4+
fn test_v_f16_cosine() {
5+
const EPSILON: f32 = f16::EPSILON.to_f32_const();
6+
use half::f16;
7+
unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 {
8+
let mut xy = 0.0f32;
9+
let mut xx = 0.0f32;
10+
let mut yy = 0.0f32;
11+
for i in 0..n {
12+
let x = a.add(i).cast::<f16>().read().to_f32();
13+
let y = b.add(i).cast::<f16>().read().to_f32();
14+
xy += x * y;
15+
xx += x * x;
16+
yy += y * y;
17+
}
18+
xy / (xx * yy).sqrt()
19+
}
20+
let n = 4000;
21+
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
22+
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
23+
let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) };
24+
if detect::x86_64::detect_avx512fp16() {
25+
println!("detected avx512fp16");
26+
let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
27+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
28+
} else {
29+
println!("detected no avx512fp16, skipped");
30+
}
31+
if detect::x86_64::detect_v4() {
32+
println!("detected v4");
33+
let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
34+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
35+
} else {
36+
println!("detected no v4, skipped");
37+
}
38+
if detect::x86_64::detect_v3() {
39+
println!("detected v3");
40+
let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
41+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
42+
} else {
43+
println!("detected no v3, skipped");
44+
}
45+
}
46+
47+
#[test]
48+
fn test_v_f16_dot() {
49+
const EPSILON: f32 = 1.0f32;
50+
use half::f16;
51+
unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 {
52+
let mut xy = 0.0f32;
53+
for i in 0..n {
54+
let x = a.add(i).cast::<f16>().read().to_f32();
55+
let y = b.add(i).cast::<f16>().read().to_f32();
56+
xy += x * y;
57+
}
58+
xy
59+
}
60+
let n = 4000;
61+
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
62+
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
63+
let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) };
64+
if detect::x86_64::detect_avx512fp16() {
65+
println!("detected avx512fp16");
66+
let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
67+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
68+
} else {
69+
println!("detected no avx512fp16, skipped");
70+
}
71+
if detect::x86_64::detect_v4() {
72+
println!("detected v4");
73+
let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
74+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
75+
} else {
76+
println!("detected no v4, skipped");
77+
}
78+
if detect::x86_64::detect_v3() {
79+
println!("detected v3");
80+
let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
81+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
82+
} else {
83+
println!("detected no v3, skipped");
84+
}
85+
}
86+
87+
#[test]
88+
fn test_v_f16_sl2() {
89+
const EPSILON: f32 = 1.0f32;
90+
use half::f16;
91+
unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 {
92+
let mut dd = 0.0f32;
93+
for i in 0..n {
94+
let x = a.add(i).cast::<f16>().read().to_f32();
95+
let y = b.add(i).cast::<f16>().read().to_f32();
96+
let d = x - y;
97+
dd += d * d;
98+
}
99+
dd
100+
}
101+
let n = 4000;
102+
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
103+
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
104+
let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) };
105+
if detect::x86_64::detect_avx512fp16() {
106+
println!("detected avx512fp16");
107+
let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
108+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
109+
} else {
110+
println!("detected no avx512fp16, skipped");
111+
}
112+
if detect::x86_64::detect_v4() {
113+
println!("detected v4");
114+
let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
115+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
116+
} else {
117+
println!("detected no v4, skipped");
118+
}
119+
if detect::x86_64::detect_v3() {
120+
println!("detected v3");
121+
let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
122+
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
123+
} else {
124+
println!("detected no v3, skipped");
125+
}
126+
}

crates/detect/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[package]
2+
name = "detect"
3+
version.workspace = true
4+
edition.workspace = true
5+
6+
[dependencies]
7+
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
8+
ctor = "0.2.6"
File renamed without changes.

crates/service/src/utils/detect/x86_64.rs renamed to crates/detect/src/x86_64.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ fn ctor_v4() {
3333
ATOMIC_V4.store(test_v4(), Ordering::Relaxed);
3434
}
3535

36-
pub fn _detect_v4() -> bool {
36+
pub fn detect_v4() -> bool {
3737
ATOMIC_V4.load(Ordering::Relaxed)
3838
}
3939

0 commit comments

Comments
 (0)