Skip to content

Commit 23b71e8

Browse files
committed
Feat: start adding cublas gemm
1 parent 1709318 commit 23b71e8

File tree

7 files changed

+369
-3
lines changed

7 files changed

+369
-3
lines changed

crates/blastoff/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ bitflags = "1.3.2"
1010
cublas_sys = { version = "0.1", path = "../cublas_sys" }
1111
cust = { version = "0.3", path = "../cust", features = ["impl_num_complex"] }
1212
num-complex = "0.4.0"
13+
half = { version = "1.8.0", optional = true }
1314

1415
[package.metadata.docs.rs]
15-
rustdoc-args = ["--html-in-header", "katex-header.html"]
16+
rustdoc-args = ["--html-in-header", "katex-header.html", "--cfg", "docsrs"]

crates/blastoff/src/context.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ bitflags::bitflags! {
6363
/// - [Construct the modified givens rotation matrix that zeros the second entry of a vector<span style="float:right;">`rotmg`</span>](CublasContext::rotmg)
6464
/// - [Scale a vector by a scalar <span style="float:right;">`scal`</span>](CublasContext::scal)
6565
/// - [Swap two vectors <span style="float:right;">`swap`</span>](CublasContext::swap)
66+
///
67+
/// ## Level 3 Methods (Matrix-based operations)
68+
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
6669
#[derive(Debug)]
6770
pub struct CublasContext {
6871
pub(crate) raw: sys::v2::cublasHandle_t,

crates/blastoff/src/level3.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
use crate::{
2+
context::CublasContext,
3+
error::{Error, ToResult},
4+
raw::GemmOps,
5+
GemmDatatype, MatrixOp,
6+
};
7+
use cust::memory::{GpuBox, GpuBuffer};
8+
use cust::stream::Stream;
9+
10+
type Result<T = (), E = Error> = std::result::Result<T, E>;
11+
12+
#[track_caller]
13+
fn check_gemm<T: GemmDatatype + GemmOps>(
14+
m: usize,
15+
n: usize,
16+
k: usize,
17+
a: &impl GpuBuffer<T>,
18+
lda: usize,
19+
op_a: MatrixOp,
20+
b: &impl GpuBuffer<T>,
21+
ldb: usize,
22+
op_b: MatrixOp,
23+
c: &mut impl GpuBuffer<T>,
24+
ldc: usize,
25+
) {
26+
assert!(m > 0 && n > 0 && k > 0, "m, n, and k must be at least 1");
27+
28+
if op_a == MatrixOp::None {
29+
assert!(lda >= m, "lda must be at least m if op_a is None");
30+
31+
assert!(
32+
a.len() >= lda * k,
33+
"matrix A's length must be at least lda * k"
34+
);
35+
} else {
36+
assert!(lda >= k, "lda must be at least k if op_a is None");
37+
38+
assert!(
39+
a.len() >= lda * m,
40+
"matrix A's length must be at least lda * m"
41+
);
42+
}
43+
44+
if op_b == MatrixOp::None {
45+
assert!(ldb >= k, "ldb must be at least k if op_b is None");
46+
47+
assert!(
48+
b.len() >= ldb * n,
49+
"matrix B's length must be at least ldb * n"
50+
);
51+
} else {
52+
assert!(ldb >= n, "ldb must be at least n if op_b is None");
53+
54+
assert!(
55+
a.len() >= ldb * k,
56+
"matrix B's length must be at least ldb * k"
57+
);
58+
}
59+
60+
assert!(ldc >= m, "ldc must be at least m");
61+
62+
assert!(
63+
c.len() >= ldc * n,
64+
"matrix C's length must be at least ldc * n"
65+
);
66+
}
67+
68+
impl CublasContext {
69+
/// Generic Matrix Multiplication.
70+
///
71+
/// # Panics
72+
///
73+
/// Panics if any of the following conditions are not met:
74+
/// - `m > 0 && n > 0 && k > 0`
75+
/// - `lda >= m` if `op_a == MatrixOp::None`
76+
/// - `a.len() >= lda * k` if `op_a == MatrixOp::None`
77+
/// - `lda >= k` if `op_a == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose`
78+
/// - `a.len() >= lda * m` if `op_a == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose`
79+
/// - `ldb >= k` if `op_b == MatrixOp::None`
80+
/// - `b.len() >= ldb * n` if `op_b == MatrixOp::None`
81+
/// - `ldb >= n` if `op_b == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose`
82+
/// - `b.len() >= ldb * k` if `op_b == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose`
83+
/// - `ldc >= m`
84+
/// - `c.len() >= ldc * n`
85+
///
86+
/// # Errors
87+
///
88+
/// Returns an error if the kernel execution failed or the selected precision is `half` and the device does not support half precision.
89+
#[track_caller]
90+
pub fn gemm<T: GemmDatatype + GemmOps>(
91+
&mut self,
92+
stream: &Stream,
93+
m: usize,
94+
n: usize,
95+
k: usize,
96+
alpha: &impl GpuBox<T>,
97+
a: &impl GpuBuffer<T>,
98+
lda: usize,
99+
op_a: MatrixOp,
100+
beta: &impl GpuBox<T>,
101+
b: &impl GpuBuffer<T>,
102+
ldb: usize,
103+
op_b: MatrixOp,
104+
c: &mut impl GpuBuffer<T>,
105+
ldc: usize,
106+
) -> Result {
107+
check_gemm(m, n, k, a, lda, op_a, b, ldb, op_b, c, ldc);
108+
109+
let transa = op_a.to_raw();
110+
let transb = op_b.to_raw();
111+
112+
self.with_stream(stream, |ctx| unsafe {
113+
Ok(T::gemm(
114+
ctx.raw,
115+
transa,
116+
transb,
117+
m as i32,
118+
n as i32,
119+
k as i32,
120+
alpha.as_device_ptr().as_ptr(),
121+
a.as_device_ptr().as_ptr(),
122+
lda as i32,
123+
b.as_device_ptr().as_ptr(),
124+
ldb as i32,
125+
beta.as_device_ptr().as_ptr(),
126+
c.as_device_ptr().as_mut_ptr(),
127+
ldc as i32,
128+
)
129+
.to_result()?)
130+
})
131+
}
132+
}

crates/blastoff/src/lib.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//! [`amin`](crate::context::CublasContext::amin) returns a 1-based index.**
99
1010
#![allow(clippy::too_many_arguments)]
11+
#![cfg_attr(docsrs, feature(doc_cfg))]
1112

1213
pub use cublas_sys as sys;
1314
use num_complex::{Complex32, Complex64};
@@ -17,8 +18,23 @@ pub use context::*;
1718
mod context;
1819
pub mod error;
1920
mod level1;
21+
mod level3;
2022
pub mod raw;
2123

24+
/// A possible datatype for a generic matrix mul operation. This is just [`BlasDatatype`] except optionally
25+
/// containing `f16` with the `half` feature.
26+
pub trait GemmDatatype: private::Sealed + cust::memory::DeviceCopy {}
27+
28+
#[cfg(feature = "half")]
29+
impl private::Sealed for half::f16 {}
30+
#[cfg_attr(docsrs, doc(cfg(feature = "half")))]
31+
#[cfg(feature = "half")]
32+
impl GemmDatatype for half::f16 {}
33+
impl GemmDatatype for f32 {}
34+
impl GemmDatatype for f64 {}
35+
impl GemmDatatype for Complex32 {}
36+
impl GemmDatatype for Complex64 {}
37+
2238
pub trait BlasDatatype: private::Sealed + cust::memory::DeviceCopy {
2339
/// The corresponding float type. For complex numbers this means their backing
2440
/// precision, and for floats it is just themselves.
@@ -74,3 +90,32 @@ pub(crate) mod private {
7490
impl Sealed for Complex32 {}
7591
impl Sealed for Complex64 {}
7692
}
93+
94+
/// An optional operation to apply to a matrix before a matrix operation. This includes
95+
/// no operation, transpose, or conjugate transpose.
96+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
97+
pub enum MatrixOp {
98+
/// No operation, leave the matrix as is. This is the default.
99+
None,
100+
/// Transpose the matrix in place.
101+
Transpose,
102+
/// Conjugate transpose the matrix in place.
103+
ConjugateTranspose,
104+
}
105+
106+
impl Default for MatrixOp {
107+
fn default() -> Self {
108+
MatrixOp::None
109+
}
110+
}
111+
112+
impl MatrixOp {
113+
/// Returns the corresponding `cublasOperation_t` for this operation.
114+
pub fn to_raw(self) -> sys::v2::cublasOperation_t {
115+
match self {
116+
MatrixOp::None => sys::v2::cublasOperation_t::CUBLAS_OP_N,
117+
MatrixOp::Transpose => sys::v2::cublasOperation_t::CUBLAS_OP_T,
118+
MatrixOp::ConjugateTranspose => sys::v2::cublasOperation_t::CUBLAS_OP_C,
119+
}
120+
}
121+
}

crates/blastoff/src/raw/level1.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
use std::os::raw::c_int;
2-
31
use crate::{sys::v2::*, BlasDatatype};
42
use num_complex::{Complex32, Complex64};
3+
use std::os::raw::c_int;
54

65
pub trait Level1: BlasDatatype {
76
unsafe fn amax(

0 commit comments

Comments
 (0)