
Commit 370adb0
fix: change some naming
1 parent e2bd9b1 commit 370adb0

3 files changed, +11 -11 lines changed

README.md

Lines changed: 6 additions & 6 deletions
```diff
@@ -19,7 +19,7 @@
 
 ### 1. Operator: SiLU function (10 points)
 
-Please implement the SiLU operator in `src/operators.rs`; its formula is:
+Please implement the SwiGLU operator in `src/operators.rs`; its formula is:
 
 $$
 y=silu(x) × y
```
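For reference while implementing, the `silu` in this formula is the standard sigmoid-weighted linear unit; spelled out (general background, not part of the diff):

$$
\mathrm{silu}(x) = x \cdot \sigma(x) = \frac{x}{1+e^{-x}}, \qquad y \leftarrow \mathrm{silu}(x) \times y
$$

so SwiGLU scales `y` (the up-projection) element-wise by `silu` of the gate `x`.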
````diff
@@ -85,8 +85,8 @@ $$
 hidden = rms_norm(residual)
 gate = hidden @ gate_weight.T
 up = hidden @ up_weight.T
-itermediate = gate * sigmoid(gate) * up ## silu
-output = itermediate @ down_weight.T
+act = gate * sigmoid(gate) * up ## SwiGLU
+output = act @ down_weight.T
 residual = output + residual
 ```
 
````
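To make the data flow above concrete, here is a minimal, self-contained Rust sketch of the same FFN block on flat row-major slices. The names `matmul_t` and `ffn` and the `[d_out, d_in]` weight layout are illustrative assumptions, not the repository's `Tensor` API; `rms_norm` is assumed already applied by the caller.

```rust
// Naive x @ w.T where w is stored [d_out, d_in] row-major.
fn matmul_t(x: &[f32], w: &[f32], seq: usize, d_in: usize, d_out: usize) -> Vec<f32> {
    let mut out = vec![0.0; seq * d_out];
    for i in 0..seq {
        for j in 0..d_out {
            let mut acc = 0.0;
            for k in 0..d_in {
                acc += x[i * d_in + k] * w[j * d_in + k];
            }
            out[i * d_out + j] = acc;
        }
    }
    out
}

// hidden: [seq, d] (already rms-normed); returns output: [seq, d].
fn ffn(hidden: &[f32], gate_w: &[f32], up_w: &[f32], down_w: &[f32],
       seq: usize, d: usize, di: usize) -> Vec<f32> {
    let gate = matmul_t(hidden, gate_w, seq, d, di); // gate = hidden @ gate_weight.T
    let up = matmul_t(hidden, up_w, seq, d, di);     // up   = hidden @ up_weight.T
    // act = gate * sigmoid(gate) * up  (SwiGLU, element-wise)
    let act: Vec<f32> = gate
        .iter()
        .zip(&up)
        .map(|(&g, &u)| g * (1.0 / (1.0 + (-g).exp())) * u)
        .collect();
    matmul_t(&act, down_w, seq, di, d)               // output = act @ down_weight.T
}
```

The caller then adds the result back into `residual`, as in the last line of the pseudocode.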

````diff
@@ -149,9 +149,9 @@ V = cat(V_cache, V)
 ### The following is the part you need to implement
 score = Q @ K.T / sqrt(dim)
 attn = softmax(score)
-x = attn @ V
-x = x @ O_weight.T
-residual = x + residual
+attn_V = attn @ V
+out = attn_V @ O_weight.T
+residual = out + residual
 ```
 
 Debugging Self-Attention is difficult. We recommend using PyTorch to assist: load and run the model with the transformers library (using the llama model code), and check the intermediate tensor results layer by layer.
````
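As a reference point for those steps, here is a minimal single-head Rust sketch on flat buffers. It deliberately omits causal masking, multi-head splitting, and grouped KV heads, all of which the real layer needs, and every name in it is illustrative.

```rust
// Minimal single-head sketch of the attention steps to implement.
// q: [seq, dim], k/v: [total, dim] (cache plus current), all row-major.
fn attention(q: &[f32], k: &[f32], v: &[f32],
             seq: usize, total: usize, dim: usize) -> Vec<f32> {
    let scale = 1.0 / (dim as f32).sqrt();
    let mut out = vec![0.0; seq * dim];
    for i in 0..seq {
        // score = Q @ K.T / sqrt(dim), one query row at a time
        let mut score = vec![0.0f32; total];
        for j in 0..total {
            let mut acc = 0.0;
            for d in 0..dim {
                acc += q[i * dim + d] * k[j * dim + d];
            }
            score[j] = acc * scale;
        }
        // attn = softmax(score), max subtracted for numerical stability
        let max = score.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for s in score.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        // attn_V = attn @ V, accumulated directly into the output row
        for j in 0..total {
            let w = score[j] / sum;
            for d in 0..dim {
                out[i * dim + d] += w * v[j * dim + d];
            }
        }
    }
    out // the caller applies `out @ O_weight.T` and adds it to residual
}
```

The output projection and the residual add then reuse the same kind of matmul as in the FFN sketch above.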

src/main.rs

Lines changed: 2 additions & 2 deletions
```diff
@@ -20,8 +20,8 @@ fn main() {
     let output_ids = llama.generate(
         input_ids,
         500,
-        0.9,
-        4,
+        0.8,
+        30,
         1.,
     );
     println!("{}", tokenizer.decode(&output_ids, true).unwrap());
```
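Read against the `random_sample(x: &Tensor<f32>, top_p: f32, top_k: u32, temperature: f32)` signature shown in `src/operators.rs` below, these positional arguments are top_p, top_k, and temperature, so the commit moves sampling from `top_p = 0.9, top_k = 4` to `top_p = 0.8, top_k = 30` at `temperature = 1.0`. As a reminder of what those knobs do, here is an illustrative, self-contained top-k/top-p sampler over raw logits; the function name, the fixed pseudo-random draw, and the tie-breaking are assumptions, not the repository's implementation.

```rust
// Illustrative top-k / top-p (nucleus) sampling over raw logits.
fn sample(logits: &[f32], top_p: f32, top_k: u32, temperature: f32) -> usize {
    // Softmax with temperature (max subtracted for numerical stability).
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<(usize, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &l)| (i, ((l - max) / temperature).exp()))
        .collect();
    let sum: f32 = probs.iter().map(|&(_, p)| p).sum();
    for p in probs.iter_mut() {
        p.1 /= sum;
    }

    // Keep only the top_k most probable tokens...
    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    probs.truncate(top_k.max(1) as usize);

    // ...then cut the tail once cumulative probability reaches top_p.
    let mut cum = 0.0;
    let mut cutoff = probs.len();
    for (i, &(_, p)) in probs.iter().enumerate() {
        cum += p;
        if cum >= top_p {
            cutoff = i + 1;
            break;
        }
    }
    probs.truncate(cutoff);

    // Draw from the renormalized remainder (a fixed draw keeps this
    // dependency-free; substitute rng.gen::<f32>() in real code).
    let total: f32 = probs.iter().map(|&(_, p)| p).sum();
    let mut r = 0.42_f32 * total;
    for &(idx, p) in &probs {
        if r < p {
            return idx;
        }
        r -= p;
    }
    probs.last().map(|&(idx, _)| idx).unwrap_or(0)
}
```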

src/operators.rs

Lines changed: 3 additions & 3 deletions
```diff
@@ -74,9 +74,9 @@ pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon:
     todo!("implement rms_norm; doing some necessary checks before computing will help you debug later")
 }
 
-// y = sigmoid(x) * x * y
+// y = silu(x) * y
 // hint: this is an element-wise operation
-pub fn silu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
+pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
     // let len = y.size();
     // assert!(len == x.size());
```
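Since the hint says the operator is element-wise, here is one possible body, sketched on plain slices so it does not guess at the project's `Tensor` accessors (`swiglu_slice` is an illustrative name; adapt the data access to the real API).

```rust
// y = silu(x) * y, element-wise, where silu(x) = x * sigmoid(x).
fn swiglu_slice(y: &mut [f32], x: &[f32]) {
    assert_eq!(y.len(), x.len(), "swiglu: x and y must be the same length");
    for (yi, &xi) in y.iter_mut().zip(x) {
        let sigmoid = 1.0 / (1.0 + (-xi).exp());
        *yi *= xi * sigmoid; // y[i] = silu(x[i]) * y[i]
    }
}
```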
```diff
@@ -176,7 +176,7 @@ pub fn random_sample(x: &Tensor<f32>, top_p: f32, top_k: u32, temperature: f32)
 fn test_silu() {
     let mut y = Tensor::<f32>::new(vec![2., 3., 4.], &vec![1, 3]);
     let x = Tensor::<f32>::new(vec![1., 2., 3.], &vec![1, 3]);
-    silu(&mut y, &x);
+    swiglu(&mut y, &x);
     assert!(y.close_to(
         &Tensor::<f32>::new(vec![1.4621172, 5.2847824, 11.43089], &vec![1, 3]),
         1e-3
```
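Sanity-checking the asserted values: with x = [1, 2, 3] and y = [2, 3, 4], silu(1)·2 = 2σ(1) ≈ 1.4621172, silu(2)·3 = 6σ(2) ≈ 5.2847824, and silu(3)·4 = 12σ(3) ≈ 11.43089, so the rename leaves the expected SwiGLU results unchanged.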
