Skip to content

Commit d1c956d

Browse files
Gemma3n (Text-only) (#20134)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Signed-off-by: Roger Wang <hey@rogerw.me> Co-authored-by: Roger Wang <hey@rogerw.me>
1 parent dec197e commit d1c956d

File tree

5 files changed

+870
-0
lines changed

5 files changed

+870
-0
lines changed

docs/models/supported_models.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ Specified using `--task generate`.
336336
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
337337
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
338338
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
339+
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
339340
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
340341
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
341342
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
@@ -392,6 +393,9 @@ Specified using `--task generate`.
392393
!!! note
393394
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
394395

396+
!!! note
397+
    Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0 or later.
398+
395399
### Pooling Models
396400

397401
See [this page](./pooling_models.md) for more information on how to use pooling models.

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def check_available_online(
164164
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
165165
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
166166
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
167+
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
168+
min_transformers_version="4.53"),
167169
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
168170
"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
169171
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",

vllm/model_executor/layers/activation.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,57 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
135135
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
136136

137137

138+
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """Sparse GELU-and-mul activation, used by Gemma3n's gated MLP.

    Equivalent to the following reference computation:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj)  # sparsity
        activations = self.act_fn(gate_proj)  # gelu
        down_proj = self.down_proj(activations * up_proj)

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # GELU approximation mode ("none" or "tanh").
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # With zero sparsity the cutoff never removes anything, so the
        # cheaper GeluAndMul op should be used instead.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        sparsity = torch.tensor(activation_sparsity, dtype=torch.float32)
        # std_multiplier is the z-score at which `activation_sparsity`
        # of a standard normal distribution lies below the cutoff.
        self.std_multiplier = torch.distributions.normal.Normal(
            0, 1).icdf(sparsity)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        row_mean = x.mean(dim=-1, keepdim=True)
        row_std = x.std(dim=-1, keepdim=True, unbiased=False)
        cutoff = row_mean + row_std * self.std_multiplier
        # Values below the per-row cutoff are clamped to zero.
        return nn.functional.relu(x - cutoff)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate = self._gaussian_topk(x[..., :half])
        activated = F.gelu(gate, approximate=self.approximate)
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated CUDA kernel; the native path runs on GPU tensors.
        return self.forward_native(x)
188+
138189
@CustomOp.register("gelu_and_mul")
139190
class GeluAndMul(CustomOp):
140191
"""An activation function for GeGLU.

0 commit comments

Comments
 (0)