Commit ed8cbfe

Let GraniteMoeAttention use YaRN (#21174)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 45badd0 commit ed8cbfe
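
For context, the YaRN parameters that this commit now respects come from the model's Hugging Face config. Below is a minimal sketch of such a rope_scaling entry, written as the Python dict that getattr(config, "rope_scaling", None) would return; the key names follow the usual Hugging Face convention and the numeric values are made up for illustration, not taken from any Granite checkpoint.

# Hypothetical YaRN rope_scaling entry (illustrative values only), as it
# might appear under "rope_scaling" in a GraniteMoe config.json and be
# returned by getattr(config, "rope_scaling", None).
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 8192,
}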

File tree

2 files changed: +7 -1 lines changed


vllm/model_executor/models/granitemoe.py

Lines changed: 5 additions & 1 deletion
@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only GraniteMoe model."""
 from collections.abc import Iterable
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from torch import nn

@@ -113,6 +113,7 @@ def __init__(
         num_kv_heads: int,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         attention_multiplier: Optional[float] = None,

@@ -163,6 +164,7 @@ def __init__(
             max_position=max_position,
             base=int(self.rope_theta),
             is_neox_style=True,
+            rope_scaling=rope_scaling,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,

@@ -198,12 +200,14 @@ def __init__(
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
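
As a rough usage sketch of the effect of this change: if a GraniteMoe checkpoint ships a YaRN rope_scaling entry in its config.json, vLLM's offline API now forwards it into the rotary embedding rather than dropping it. The model path below is a placeholder, not a real checkpoint.

# Usage sketch: the model path is a placeholder for a GraniteMoe checkpoint
# whose config.json contains a YaRN rope_scaling entry.
from vllm import LLM, SamplingParams

llm = LLM(model="path/to/granitemoe-with-yarn", max_model_len=32768)
outputs = llm.generate(["A long-context prompt goes here."],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)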

vllm/model_executor/models/granitemoeshared.py

Lines changed: 2 additions & 0 deletions
@@ -81,12 +81,14 @@ def __init__(
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
