```diff
@@ -24,7 +24,7 @@
 # limitations under the License.
 """Inference-only GraniteMoe model."""
 from collections.abc import Iterable
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from torch import nn
@@ -113,6 +113,7 @@ def __init__(
         num_kv_heads: int,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         attention_multiplier: Optional[float] = None,
@@ -163,6 +164,7 @@ def __init__(
             max_position=max_position,
             base=int(self.rope_theta),
             is_neox_style=True,
+            rope_scaling=rope_scaling,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -198,12 +200,14 @@ def __init__(
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = GraniteMoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
```