Commit 4bf289d

kozistr authored
Support NomicBert MoE (#596)
Co-authored-by: reach-vb <reach-vb@users.noreply.huggingface.co>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
1 parent 5d632d9 commit 4bf289d

10 files changed: +6609 -43 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ Below are some examples of the currently supported models:
 | 49 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
 | N/A | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
 | N/A | 0.4B | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
+| N/A | 0.3B | NomicBert | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe) |
 | N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
 | N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |

backends/candle/src/layers/linear.rs

Lines changed: 10 additions & 0 deletions

@@ -11,6 +11,16 @@ pub enum HiddenAct {
     Swiglu,
 }
 
+impl HiddenAct {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Gelu => x.gelu(),
+            Self::Relu => x.relu(),
+            Self::Swiglu => candle_nn::ops::swiglu(x),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct Linear {
     weight: Tensor,
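
The new `HiddenAct::forward` gives callers a single place to dispatch on the configured activation instead of matching on the variant themselves. A rough illustration of the new call site (the struct below is hypothetical and not part of this commit; it assumes `HiddenAct` is exported from `crate::layers` alongside `Linear`):

```rust
use candle::{Result, Tensor};

use crate::layers::{HiddenAct, Linear};

// Hypothetical MLP used only to illustrate the new call site; the crate's real
// MLP types live in the model files and differ in naming and details.
struct ToyMlp {
    up_proj: Linear,
    down_proj: Linear,
    activation: HiddenAct,
}

impl ToyMlp {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let x = self.up_proj.forward(hidden_states)?;
        // One call replaces an inline `match` on the activation variant.
        let x = self.activation.forward(&x)?;
        self.down_proj.forward(&x)
    }
}
```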

backends/candle/src/models/flash_nomic.rs

Lines changed: 24 additions & 12 deletions

@@ -1,6 +1,6 @@
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
-use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
+use crate::models::nomic::{NomicBertEmbeddings, NomicMLP};
 use crate::models::{Model, NomicConfig};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::VarBuilder;

@@ -25,16 +25,25 @@ impl NomicAttention {
         let attention_head_size = config.n_embd / config.n_head;
         let hidden_size = config.n_embd;
 
-        let qkv_weight = vb.pp("Wqkv").get(
-            (3 * num_attention_heads * attention_head_size, hidden_size),
-            "weight",
-        )?;
-        let qkv_linear = Linear::new(qkv_weight, None, None);
+        let qkv_dim = 3 * num_attention_heads * attention_head_size;
+
+        let qkv_weight = vb.pp("Wqkv").get((qkv_dim, hidden_size), "weight")?;
+        let qkv_bias = if config.qkv_proj_bias {
+            Some(vb.pp("Wqkv").get((qkv_dim,), "bias")?)
+        } else {
+            None
+        };
+        let qkv_linear = Linear::new(qkv_weight, qkv_bias, None);
 
         let out_proj_weight = vb
             .pp("out_proj")
             .get((hidden_size, hidden_size), "weight")?;
-        let out_proj = Linear::new(out_proj_weight, None, None);
+        let out_proj_bias = if config.qkv_proj_bias {
+            Some(vb.pp("out_proj").get((hidden_size,), "bias")?)
+        } else {
+            None
+        };
+        let out_proj = Linear::new(out_proj_weight, out_proj_bias, None);
 
         let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;
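
Both optional biases above follow the same pattern: read a `(dim,)` tensor from the var builder only when `config.qkv_proj_bias` is set. A standalone sketch of that pattern, illustrative only and not part of this commit:

```rust
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

// Illustrative helper: load an optional bias of shape (dim,) only when the
// config flag asks for it, mirroring how the diff treats Wqkv and out_proj.
fn load_optional_bias(vb: VarBuilder, dim: usize, enabled: bool) -> Result<Option<Tensor>> {
    if enabled {
        Ok(Some(vb.get((dim,), "bias")?))
    } else {
        Ok(None)
    }
}
```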

@@ -93,17 +102,18 @@ impl NomicAttention {
 
 struct NomicBertBlock {
     attention: NomicAttention,
-    mlp: NomicBertGatedMLP,
+    mlp: NomicMLP,
     post_attention_layer_norm: LayerNorm,
     output_layer_norm: LayerNorm,
 
     span: tracing::Span,
 }
 
 impl NomicBertBlock {
-    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
+    pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result<Self> {
         let attention = NomicAttention::load(vb.pp("attn"), config)?;
-        let mlp = NomicBertGatedMLP::load(vb.pp("mlp"), config)?;
+
+        let mlp = NomicMLP::load(vb.pp("mlp"), index, config)?;
 
         let post_attention_layer_norm =
             LayerNorm::load(vb.pp("norm1"), config.n_embd, config.layer_norm_epsilon)?;
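
The `index` now threaded through `NomicBertBlock::load` is what allows the MLP to differ per layer: `NomicMLP` (defined in `models/nomic.rs`, which this commit also modifies but which is not shown in this excerpt) presumably selects, per layer, between the existing gated MLP and a routed mixture-of-experts block. A minimal sketch of what such a per-layer dispatch could look like; every type and field name below is hypothetical, not the crate's actual API:

```rust
use candle::{Result, Tensor};

// Hypothetical stand-ins for the two MLP flavours; the real implementations
// (gated MLP, router + expert MLPs) are elided here.
struct DenseGatedMlp;
struct RoutedExpertMlp;

impl DenseGatedMlp {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        Ok(hidden_states.clone()) // elided
    }
}

impl RoutedExpertMlp {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        Ok(hidden_states.clone()) // elided
    }
}

/// Minimal stand-in for the config fields such a dispatch would need.
struct MoeDispatchConfig {
    /// `Some(n)` means a routed expert MLP is used every n-th layer.
    moe_every_n_layers: Option<usize>,
}

/// Per-layer choice between a dense MLP and an MoE block, keyed on the layer
/// index that the block loader now receives.
enum LayerMlp {
    Dense(DenseGatedMlp),
    Moe(RoutedExpertMlp),
}

impl LayerMlp {
    fn new(index: usize, config: &MoeDispatchConfig) -> Self {
        match config.moe_every_n_layers {
            // Which layers carry experts is a property of the checkpoint;
            // odd-indexed layers for n == 2 is only an example.
            Some(n) if n > 0 && index % n == 1 => Self::Moe(RoutedExpertMlp),
            _ => Self::Dense(DenseGatedMlp),
        }
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        match self {
            Self::Dense(mlp) => mlp.forward(hidden_states),
            Self::Moe(moe) => moe.forward(hidden_states),
        }
    }
}
```

With a dispatch of this shape, the encoder can keep a single code path and simply pass each layer's index down, which is what the encoder hunk below does.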
@@ -132,6 +142,7 @@ impl NomicBertBlock {
         let attn_output = self
             .attention
             .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
+
         let hidden_states = self
             .post_attention_layer_norm
             .forward(&hidden_states, Some(&attn_output))?;

@@ -145,13 +156,14 @@ impl NomicBertBlock {
 
 struct NomicBertEncoder {
     layers: Vec<NomicBertBlock>,
+
     span: tracing::Span,
 }
 
 impl NomicBertEncoder {
     pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
         let layers = (0..config.n_layer)
-            .map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), config))
+            .map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), index, config))
             .collect::<Result<Vec<_>>>()?;
 
         let span = tracing::span!(tracing::Level::TRACE, "encoder");

@@ -170,7 +182,6 @@ impl NomicBertEncoder {
 
         let mut hidden_states = hidden_states.clone();
 
-        // Use a loop rather than a fold as it's easier to modify when adding debug/...
         for layer in self.layers.iter() {
             hidden_states = layer.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?
         }

@@ -419,6 +430,7 @@ impl Model for FlashNomicBertModel {
     fn is_padded(&self) -> bool {
         false
    }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }

0 commit comments