
Commit cdfb9b5

wip kaiju
Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com>
1 parent 623e2ed commit cdfb9b5

File tree: 3 files changed, +373 −0 lines


vllm/model_executor/layers/layernorm.py

Lines changed: 58 additions & 0 deletions
@@ -271,3 +271,61 @@ def forward_cuda(
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)


@CustomOp.register("kaiju_rms_norm")
class KaijuRMSNorm(CustomOp):
    """RMS normalization for Kaiju.

    Differences from standard RMSNorm:
    1. No learnable weight parameter.
    2. Clamps the output to the range [-4, 4].
    3. The calculation is done in fp32 and then converted back to orig_dtype.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.variance_epsilon = eps
        self.hidden_size = hidden_size

    @staticmethod
    def forward_static(
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        return x.to(input_dtype).clamp(-4, 4)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.variance_epsilon, x)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        if torch.compiler.is_compiling():
            return self.forward_native(x)

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        # KaijuRMSNorm has no residual input, so only x is passed on.
        return self.forward_native(x)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.hidden_size}"
        s += f", eps={self.variance_epsilon}"
        return s
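
For reference, a minimal standalone sketch (plain PyTorch, not part of the diff) of what KaijuRMSNorm.forward_static computes: the variance is taken in fp32, the input is rescaled by rsqrt(variance + eps), cast back to the input dtype, and finally clamped to [-4, 4]. The function name below is illustrative only.

    import torch

    def kaiju_rms_norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Standalone re-implementation of KaijuRMSNorm.forward_static."""
        input_dtype = x.dtype
        xf = x.to(torch.float32)                       # compute in fp32
        variance = xf.pow(2).mean(dim=-1, keepdim=True)
        xf = xf * torch.rsqrt(variance + eps)          # no learnable weight
        return xf.to(input_dtype).clamp(-4, 4)         # clamp to [-4, 4]

    x = torch.randn(2, 8, dtype=torch.bfloat16) * 10
    y = kaiju_rms_norm_reference(x)
    assert y.abs().max() <= 4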

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 24 additions & 0 deletions
@@ -28,6 +28,7 @@
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
@@ -74,6 +75,29 @@ def _apply_rotary_emb(
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


@CustomOp.register("kaiju_rotary_embedding")
class KaijuRotaryEmbedding(CustomOp):
    def __init__(
        self,
        hf_config: PretrainedConfig,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(hf_config, "rope_scaling") and \
                hf_config.rope_scaling is not None:
            self.rope_type = hf_config.rope_scaling.get(
                "rope_type", hf_config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = hf_config.max_position_embeddings
        self.original_max_seq_len = hf_config.max_position_embeddings

        self.config = hf_config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq


@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
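
This hunk only adds the constructor; no forward that materializes cos/sin is included in the commit. Below is a hypothetical sketch of how the registered inv_freq buffer is typically turned into the (cos, sin) pair an attention layer would consume, following the usual Hugging Face rotary recipe. The function name and shapes are assumptions, not part of this diff.

    import torch

    def cos_sin_from_inv_freq(inv_freq: torch.Tensor,
                              positions: torch.Tensor,
                              attention_scaling: float = 1.0):
        # positions: [num_tokens] -> freqs: [num_tokens, head_dim // 2]
        freqs = positions[:, None].float() * inv_freq[None, :]
        # Duplicate the frequencies so cos/sin cover the full head_dim.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos() * attention_scaling, emb.sin() * attention_scaling

    # Default RoPE frequencies for head_dim = 64, base = 10000.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
    cos, sin = cos_sin_from_inv_freq(inv_freq, torch.arange(16))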

vllm/model_executor/models/kaiju.py

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
from dataclasses import dataclass
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN

from kaiju import KaijuTextConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import KaijuRMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig

# from vllm.model_executor.layers.activation import GeluAndMul
# from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
#                                                QKVParallelLinear,
#                                                RowParallelLinear)
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.rotary_embedding import get_rope
# from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
# from vllm.model_executor.layers.vocab_parallel_embedding import (
#     VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import (
#     default_weight_loader, maybe_remap_kv_scale_name)
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors

# from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

logger = init_logger(__name__)


def clamp(x: torch.Tensor, limit: float) -> torch.Tensor:
    """Symmetric clamp to [-limit, limit], matching KaijuRMSNorm's range."""
    return x.clamp(-limit, limit)

class KaijuMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        rms_norm_eps: float,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.residual_scale = nn.Parameter(torch.zeros(self.hidden_size),
                                           requires_grad=False)
        self.pre_ffn_norm = KaijuRMSNorm(self.hidden_size, eps=rms_norm_eps)

        # TODO: Megatron style TP (MergedColumnParallelLinear then
        # RowParallelLinear)
        self.W_in = nn.Linear(self.hidden_size, self.intermediate_size,
                              bias=False)
        self.W_out = nn.Linear(self.intermediate_size, self.hidden_size,
                               bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        # WARNING: In whippet checkpoints, there is an
        # `args["quantize"]["ffn_clamp_middle_output"]`. It's only used in
        # the backward pass in specific circumstances.
        hidden_states = x
        x = self.pre_ffn_norm(x)
        x = self.W_in(x)
        x = clamp(x, 4)
        x = self.act_fn(x)
        x = self.W_out(x)
        hidden_states *= self.residual_scale
        return x + hidden_states

@dataclass
class KaijuCache:
    key_states: Optional[torch.Tensor] = None
    value_states: Optional[torch.Tensor] = None


class KaijuAttention(nn.Module):
    def __init__(
        self,
        config: KaijuTextConfig,
        max_position_embeddings: int,
        is_context_encoder: bool,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attn_logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.is_context_encoder = is_context_encoder
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = getattr(
            config, "head_dim",
            config.hidden_size // config.num_attention_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # Scale applied to the residual stream before adding the attention
        # output (mirrors KaijuMLP.residual_scale).
        self.residual_scale = nn.Parameter(torch.zeros(self.hidden_size),
                                           requires_grad=False)

        # TODO: Combine into single proj matrix and use QKVParallelLinear
        self.q_proj = nn.Linear(self.hidden_size, self.q_size, bias=False)
        if not self.is_context_encoder:
            self.k_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False)
            self.v_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False)

        # TODO: Use RowParallelLinear
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,
                                self.hidden_size, bias=False)

        self.pre_projection_norm = KaijuRMSNorm(self.config.hidden_size,
                                                eps=config.rms_norm_eps)

        layer_idx = extract_layer_index(prefix)
        self.is_sliding = (
            layer_idx not in self.config.global_attention_layer_schedule)
        if self.is_sliding:
            self.sliding_window = 1024
        else:
            self.sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=self.sliding_window,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        hidden_states: torch.Tensor,
        kv_cache: Optional[KaijuCache] = None,
    ) -> torch.Tensor:
        processed_hidden_states = self.pre_projection_norm(hidden_states)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings
        query_states = self.q_proj(processed_hidden_states).view(hidden_shape)

        if self.is_context_encoder:
            # Context-encoder layers have no K/V projections; they reuse the
            # K/V states produced by the layer that owns this KV group.
            assert kv_cache is not None
            key_states = kv_cache.key_states
            value_states = kv_cache.value_states
        else:
            key_states = self.k_proj(processed_hidden_states).view(hidden_shape)
            value_states = self.v_proj(processed_hidden_states).view(hidden_shape)
            if kv_cache is not None:
                # Store the freshly projected K/V so that layers sharing this
                # KV group can reuse them.
                kv_cache.key_states = key_states
                kv_cache.value_states = value_states

        # We should probably cache the clamped values.
        query_states = clamp(query_states, 4)
        key_states = clamp(key_states, 4)
        value_states = clamp(value_states, 4)

        # Should we cache post rope?
        query_states, key_states = apply_rotary_pos_emb_kaiju(
            query_states, key_states, cos, sin, unsqueeze_dim=2)

        # TODO: attention masking
        attn_output = self.attn(query_states, key_states, value_states)

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        hidden_states *= self.residual_scale
        hidden_states += attn_output

        return hidden_states

class KaijuDecoderLayer(nn.Module):
    def __init__(
        self,
        config: KaijuTextConfig,
        is_context_encoder: bool,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = KaijuAttention(
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            is_context_encoder=is_context_encoder,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=None,
            prefix=f"{prefix}.self_attn",
        )

        self.mlp = KaijuMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            rms_norm_eps=config.rms_norm_eps,
        )

    def forward(
        self,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        kv_cache: Optional[KaijuCache] = None,
    ) -> Tuple[torch.FloatTensor,
               Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self attention; the attention module handles the residual stream
        # update.
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            kv_cache=kv_cache,
        )

        # Fully connected
        hidden_states = self.mlp(hidden_states)

        outputs = (hidden_states,)
        # This isn't necessary for inference; we can consider writing a slow
        # attention implementation for debugging purposes.
        assert not output_attentions, "TODO: Support this"

        return outputs

@support_torch_compile
class KaijuModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Map every layer to the lowest-indexed layer of its KV-sharing group.
        self.layer_to_kv_group = list(range(config.num_hidden_layers))
        for layers in config.share_kv_schedule:
            for layer_idx in layers:
                self.layer_to_kv_group[layer_idx] = min(layers)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: Vocab parallel embedding
        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         self.padding_idx)
        # TODO: Get rid of this scale by "compiling" it into the embedding
        # weights; then, when we convert the lm head/etc., we can just adjust
        # that scale.
        self.embedding_scale = nn.Parameter(torch.FloatTensor([0]),
                                            requires_grad=False)

        # A layer is a context encoder when it is not the first layer of its
        # KV-sharing group, i.e. it reuses another layer's K/V states.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: KaijuDecoderLayer(
                config,
                is_context_encoder=(
                    extract_layer_index(prefix)
                    != self.layer_to_kv_group[extract_layer_index(prefix)]),
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
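
To make the KV-sharing bookkeeping in KaijuModel concrete, here is a small standalone illustration of how share_kv_schedule collapses each group onto its lowest-indexed layer; the schedule values below are made up for the example and are not from the diff.

    # Illustration only: an assumed share_kv_schedule for a 6-layer model.
    num_hidden_layers = 6
    share_kv_schedule = [(0, 1, 2), (3, 4, 5)]

    layer_to_kv_group = list(range(num_hidden_layers))
    for layers in share_kv_schedule:
        for layer_idx in layers:
            layer_to_kv_group[layer_idx] = min(layers)

    print(layer_to_kv_group)  # [0, 0, 0, 3, 3, 3]
    # A layer is built with is_context_encoder=True exactly when it is not the
    # first layer of its group, i.e. it reuses that layer's K/V states instead
    # of projecting its own.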
