Skip to content

Commit 245fbcd

Browse files
Feat/fused matmul tune (#2726)
1 parent b33bd24 commit 245fbcd

File tree

14 files changed

+525
-60
lines changed

14 files changed

+525
-60
lines changed

Cargo.lock

Lines changed: 14 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
153153
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
154154

155155
### For the main burn branch. ###
156-
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
157-
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
156+
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
157+
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
158158
### For local development. ###
159159
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
160160
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

backend-comparison/benches/matmul_fused.rs

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,9 @@
11
use backend_comparison::persistence::save;
2-
use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor};
2+
use burn::tensor::{
3+
activation::{gelu, relu},
4+
backend::Backend,
5+
Distribution, Shape, Tensor,
6+
};
37
use burn_common::benchmark::{run_benchmark, Benchmark};
48
use derive_new::new;
59

@@ -14,7 +18,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
1418
type Args = (Tensor<B, D>, Tensor<B, D>, Tensor<B, 1>);
1519

1620
fn name(&self) -> String {
17-
"matmul_bias_relu".into()
21+
"matmul_relu_bias_gelu".into()
1822
}
1923

2024
fn shapes(&self) -> Vec<Vec<usize>> {
@@ -23,7 +27,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
2327

2428
fn execute(&self, (lhs, rhs, bias): Self::Args) {
2529
let bias = bias.unsqueeze();
26-
relu(lhs.matmul(rhs) + bias);
30+
gelu(relu(lhs.matmul(rhs)) + bias);
2731
}
2832

2933
fn prepare(&self) -> Self::Args {

crates/burn-fusion/src/stream/context.rs

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,84 @@ pub(crate) struct OperationConverter {
5959
scalar_u8: Vec<u8>,
6060
}
6161

62+
/// Fork of a [context](Context) which owns its data.
63+
pub struct ContextOwned<H> {
64+
tensors: HashMap<TensorId, TensorDescription>,
65+
handles: HandleContainer<H>,
66+
scalar_f32: Vec<f32>,
67+
scalar_f16: Vec<f16>,
68+
scalar_bf16: Vec<bf16>,
69+
scalar_i64: Vec<i64>,
70+
scalar_i32: Vec<i32>,
71+
scalar_i16: Vec<i16>,
72+
scalar_i8: Vec<i8>,
73+
scalar_u64: Vec<u64>,
74+
scalar_u32: Vec<u32>,
75+
scalar_u16: Vec<u16>,
76+
scalar_u8: Vec<u8>,
77+
}
78+
79+
impl<H: Clone> ContextOwned<H> {
80+
/// Convert into [context](Context).
81+
pub fn as_context(&mut self) -> Context<'_, H> {
82+
Context {
83+
tensors: &mut self.tensors,
84+
handles: &mut self.handles,
85+
scalar_f32: &self.scalar_f32,
86+
scalar_f16: &self.scalar_f16,
87+
scalar_bf16: &self.scalar_bf16,
88+
scalar_i64: &self.scalar_i64,
89+
scalar_i32: &self.scalar_i32,
90+
scalar_i16: &self.scalar_i16,
91+
scalar_i8: &self.scalar_i8,
92+
scalar_u64: &self.scalar_u64,
93+
scalar_u32: &self.scalar_u32,
94+
scalar_u16: &self.scalar_u16,
95+
scalar_u8: &self.scalar_u8,
96+
}
97+
}
98+
99+
/// Fork the context again.
100+
pub fn fork(&self) -> ContextOwned<H> {
101+
ContextOwned {
102+
tensors: self.tensors.clone(),
103+
handles: self.handles.fork(),
104+
scalar_f32: self.scalar_f32.clone(),
105+
scalar_f16: self.scalar_f16.clone(),
106+
scalar_bf16: self.scalar_bf16.clone(),
107+
scalar_i64: self.scalar_i64.clone(),
108+
scalar_i32: self.scalar_i32.clone(),
109+
scalar_i16: self.scalar_i16.clone(),
110+
scalar_i8: self.scalar_i8.clone(),
111+
scalar_u64: self.scalar_u64.clone(),
112+
scalar_u32: self.scalar_u32.clone(),
113+
scalar_u16: self.scalar_u16.clone(),
114+
scalar_u8: self.scalar_u8.clone(),
115+
}
116+
}
117+
}
118+
119+
impl<H: Clone> Context<'_, H> {
120+
/// Fork the context into an [owned context](ContextOwned).
121+
pub fn fork(&self) -> ContextOwned<H> {
122+
ContextOwned {
123+
tensors: self.tensors.clone(),
124+
handles: self.handles.fork(),
125+
scalar_f32: self.scalar_f32.clone(),
126+
scalar_f16: self.scalar_f16.clone(),
127+
scalar_bf16: self.scalar_bf16.clone(),
128+
scalar_i64: self.scalar_i64.clone(),
129+
scalar_i32: self.scalar_i32.clone(),
130+
scalar_i16: self.scalar_i16.clone(),
131+
scalar_i8: self.scalar_i8.clone(),
132+
scalar_u64: self.scalar_u64.clone(),
133+
scalar_u32: self.scalar_u32.clone(),
134+
scalar_u16: self.scalar_u16.clone(),
135+
scalar_u8: self.scalar_u8.clone(),
136+
}
137+
}
138+
}
139+
62140
pub(crate) trait RelativeOps {
63141
/// Convert (usually an [`OperationDescription`]) to a relative form.
64142
///

crates/burn-jit/src/fusion/base.rs

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -125,20 +125,16 @@ impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
125125
fn optimizations(
126126
device: R::Device,
127127
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
128-
let mut optimizations: Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> =
129-
vec![Box::new(ElementWiseBuilder::<R>::new(
128+
vec![
129+
Box::new(ElementWiseBuilder::<R>::new(
130130
device.clone(),
131131
BT::as_elem_native_unchecked().into(),
132-
))];
133-
134-
if cfg!(feature = "fusion-experimental") {
135-
optimizations.push(Box::new(MatmulBuilder::<R>::new(
132+
)),
133+
Box::new(MatmulBuilder::<R>::new(
136134
device.clone(),
137135
BT::as_elem_native_unchecked().into(),
138-
)));
139-
}
140-
141-
optimizations
136+
)),
137+
]
142138
}
143139
}
144140

crates/burn-jit/src/fusion/matmul/builder.rs

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,13 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for MatmulBuilder<R>
4747
let rhs = self.builder.input_unhandled(&op.rhs);
4848
let out = self.builder.output_unhandled(&op.out);
4949

50-
self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone()));
50+
self.matmul = Some(FusedMatmul::new(
51+
lhs,
52+
rhs,
53+
out,
54+
op.clone(),
55+
Default::default(),
56+
));
5157
} else {
5258
self.builder.close();
5359
}

crates/burn-jit/src/fusion/matmul/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,4 @@ pub(crate) mod args;
22
pub(crate) mod builder;
33
pub(crate) mod optimization;
44
pub(crate) mod spec;
5+
pub(crate) mod tune;

0 commit comments

Comments (0)