
Commit fe3e6e6

1 parent d328f42 commit fe3e6e6

1 file changed: 0 additions, 368 deletions


verl/models/transformers/glm4v.py

Lines changed: 0 additions & 368 deletions
@@ -42,24 +42,6 @@
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-
-    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
-    _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
-    _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-# if is_npu_available:
-#     from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-#     from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
-#     from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
-
-#     _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
-#     _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
-#     _flash_use_top_left_mask = flash_attn_supports_top_left_mask()
-
-_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
-
 
 def get_rope_index(
     processor,
@@ -181,353 +163,3 @@ def get_rope_index(
     return position_ids
 
 
-def prepare_fa2_from_position_ids(
-    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
-):
-    assert position_ids.ndim == 2  # (batch_size, seq_length)
-    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
-    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
-    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
-    position_ids = position_ids.view(-1)
-    cu_seqlens = torch.cat(
-        (
-            (position_ids == 0).nonzero().view(-1).to(torch.int32),
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
-        )
-    )
-    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
-    return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))
-
-
-def _custom_flash_attention_forward(
-    query_states: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    query_length: int,
-    is_causal: bool = True,
-    position_ids: Optional[torch.Tensor] = None,
-    use_top_left_mask: bool = False,
-    deterministic: Optional[bool] = None,
-    **kwargs,
-):
-    """
-    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
-    """
-    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
-    flash_kwargs = {}
-
-    if _flash_supports_deterministic:
-        flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
-
-    if kwargs.get("softcap") is not None:
-        flash_kwargs["softcap"] = kwargs.pop("softcap")
-
-    query_states, key_states, value_states = fa_peft_integration_check(
-        query_states, key_states, value_states, target_dtype=torch.bfloat16
-    )
-
-    if position_ids is not None:
-        assert position_ids.ndim == 2  # (batch_size, seq_length / sp_size)
-
-    sp_size = get_ulysses_sequence_parallel_world_size()
-    if sp_size > 1:
-        # qkv: (batch_size, seq_length / sp_size, num_head, head_size)
-        validate_ulysses_config(query_states.size(2), sp_size)
-        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
-        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
-        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
-        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
-        position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
-        position_ids = torch.cat(position_ids_lst, dim=-1)  # (batch_size, seq_length)
-
-    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
-        batch_size = query_states.size(0)
-        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(
-            query_states, key_states, value_states, position_ids
-        )
-        attn_output = flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_q,
-            max_seqlen_k=max_seqlen_k,
-            dropout_p=kwargs.pop("dropout", 0.0),
-            softmax_scale=kwargs.pop("softmax_scale", None),
-            causal=is_causal,
-            **flash_kwargs,
-        )
-        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
-    else:
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            query_length,
-            is_causal=is_causal,
-            use_top_left_mask=use_top_left_mask,
-            deterministic=deterministic,
-            **kwargs,
-        )  # do not pass position_ids to old flash_attention_forward
-
-    if sp_size > 1:
-        # (batch_size, seq_length, num_head, head_size)
-        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
-
-    return attn_output
-
-
-def glm4v_attn_forward(
-    self: "Glm4vTextAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-    **kwargs,
-) -> tuple[torch.Tensor, None, None]:
-    from transformers.models.glm4v.modeling_glm4v import apply_multimodal_rotary_pos_emb, repeat_kv
-
-    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
-    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-    # Because the input can be padded, the absolute sequence length depends on the max position id.
-    cos, sin = position_embeddings
-    query_states, key_states = apply_multimodal_rotary_pos_emb(
-        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
-    )
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-    dropout_rate = 0.0 if not self.training else self.attention_dropout
-
-    # This is before the transpose
-    q_len = query_states.shape[2]
-
-    # FA2 uses non-transposed inputs
-    query_states = query_states.transpose(1, 2)
-    key_states = key_states.transpose(1, 2)
-    value_states = value_states.transpose(1, 2)
-
-    attn_output = _custom_flash_attention_forward(
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length=q_len,
-        is_causal=getattr(self, "is_causal", True),
-        dropout=dropout_rate,
-        use_top_left_mask=_flash_use_top_left_mask,
-        position_ids=position_ids,  # important: pass position ids
-    )  # (batch_size, seq_length / sp_size, num_head, head_size)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-    attn_output = self.o_proj(attn_output)
-    return attn_output, None
-
-
-def _get_input_embeds(
-    model: "Glm4vForConditionalGeneration",
-    input_ids: torch.LongTensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    pixel_values: Optional[torch.FloatTensor] = None,
-    pixel_values_videos: Optional[torch.FloatTensor] = None,
-    image_grid_thw: Optional[torch.LongTensor] = None,
-    video_grid_thw: Optional[torch.LongTensor] = None,
-):
-    inputs_embeds = model.get_input_embeddings()(input_ids)
-    if pixel_values is not None:
-        pixel_values = pixel_values.type(model.visual.dtype)
-        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
-        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
-        n_image_features = image_embeds.shape[0]
-        if n_image_tokens != n_image_features:
-            raise ValueError(
-                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-            )
-
-        mask = input_ids == model.config.image_token_id
-        mask_unsqueezed = mask.unsqueeze(-1)
-        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-        image_mask = mask_expanded.to(inputs_embeds.device)
-
-        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-    if pixel_values_videos is not None:
-        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
-        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
-        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
-        n_video_features = video_embeds.shape[0]
-        if n_video_tokens != n_video_features:
-            raise ValueError(
-                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-            )
-
-        mask = input_ids == model.config.video_token_id
-        mask_unsqueezed = mask.unsqueeze(-1)
-        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-        video_mask = mask_expanded.to(inputs_embeds.device)
-
-        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-    if model.training and pixel_values is None and pixel_values_videos is None:  # handle mixed text-image data
-        pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
-        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
-        inputs_embeds += 0.0 * image_embeds.mean()
-
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(inputs_embeds.device)
-
-    return inputs_embeds, attention_mask
-
-
-def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
-    if position_ids.ndim != 3 or position_ids.size(0) != 4:
-        # we concat the text position ids with the 3D vision position ids by default
-        # see https://github.com/huggingface/transformers/pull/39447
-        raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")
-
-    return position_ids
-
-
-@dataclass
-class Glm4vCausalLMOutputForPPO(Glm4vCausalLMOutputWithPast):
-    log_probs: Optional[torch.FloatTensor] = None
-    entropy: Optional[torch.FloatTensor] = None
-
-
-def glm4v_base_forward(
-    self: "Glm4vForConditionalGeneration",
-    input_ids: torch.LongTensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    pixel_values: Optional[torch.FloatTensor] = None,
-    pixel_values_videos: Optional[torch.FloatTensor] = None,
-    image_grid_thw: Optional[torch.LongTensor] = None,
-    video_grid_thw: Optional[torch.LongTensor] = None,
-    **kwargs,
-):
-    kwargs["inputs_embeds"], kwargs["attention_mask"] = _get_input_embeds(
-        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
-    )  # avoid lora module having multiple keyword arguments
-    return self.language_model(
-        input_ids=None,
-        **kwargs,
-    )
-
-
-def glm4v_forward(
-    self: "Glm4vForConditionalGeneration",
-    input_ids: torch.LongTensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    pixel_values: Optional[torch.FloatTensor] = None,
-    pixel_values_videos: Optional[torch.FloatTensor] = None,
-    image_grid_thw: Optional[torch.LongTensor] = None,
-    video_grid_thw: Optional[torch.LongTensor] = None,
-    **kwargs,
-):
-    return self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=process_position_ids(position_ids),
-        pixel_values=pixel_values,
-        pixel_values_videos=pixel_values_videos,
-        image_grid_thw=image_grid_thw,
-        video_grid_thw=video_grid_thw,
-        **kwargs,
-    )
-
-
-def forward_with_normal_backend(
-    self: Glm4vForConditionalGeneration,
-    input_ids: torch.LongTensor = None,
-    labels: Optional[torch.LongTensor] = None,
-    temperature: float = 1.0,
-    **kwargs,
-) -> "Glm4vCausalLMOutputWithPast":
-    outputs = glm4v_forward(self, input_ids, **kwargs)
-    hidden_states = outputs[0]
-    logits = self.lm_head(hidden_states)
-
-    return Glm4vCausalLMOutputWithPast(
-        logits=logits,
-        hidden_states=outputs.hidden_states,
-    )
-
-
-def forward_with_torch_backend(
-    self: Glm4vForConditionalGeneration,
-    input_ids: torch.LongTensor = None,
-    labels: Optional[torch.LongTensor] = None,
-    temperature: float = 1.0,
-    **kwargs,
-) -> tuple | Glm4vCausalLMOutputForPPO:
-    from verl.utils.experimental.torch_functional import FusedLinearForPPO
-
-    outputs = glm4v_forward(self, input_ids, **kwargs)
-    hidden_states = outputs[0]
-
-    # Loss calculations
-    if labels is not None:
-        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
-    elif input_ids is not None:
-        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
-    else:
-        raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.")
-
-    fused_linear_for_ppo = FusedLinearForPPO()
-    log_probs, entropy = fused_linear_for_ppo.forward(
-        hidden_states=hidden_states,
-        vocab_weights=self.lm_head.weight,
-        input_ids=rolled_labels,
-        temperature=temperature,
-    )
-    return Glm4vCausalLMOutputForPPO(
-        log_probs=log_probs,
-        entropy=entropy,
-        hidden_states=outputs.hidden_states,
-    )
-
-
-def forward_with_triton_backend(
-    self: Glm4vForConditionalGeneration,
-    input_ids: torch.LongTensor = None,
-    labels: Optional[torch.LongTensor] = None,
-    temperature: float = 1.0,
-    **kwargs,
-) -> tuple | Glm4vCausalLMOutputForPPO:
-    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
-
-    outputs = glm4v_forward(self, input_ids, **kwargs)
-    hidden_states = outputs[0]
-
-    # Loss calculations
-    if labels is not None:
-        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
-    elif input_ids is not None:
-        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
-    else:
-        raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.")
-
-    log_probs, entropy = linear_cross_entropy(
-        hidden_states,
-        self.lm_head.weight,
-        rolled_labels,
-        temperature,
-        "none",
-    )
-    return Glm4vCausalLMOutputForPPO(
-        log_probs=log_probs,
-        entropy=entropy,
-        hidden_states=outputs.hidden_states,
-    )
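Note: the module-level block removed in the first hunk probed the installed flash_attn kernel for optional keyword arguments before forwarding them, rather than assuming a particular flash_attn version. Below is a minimal, self-contained sketch of that probing pattern; fake_flash_attn_func is a stand-in defined only for this illustration, since the real flash_attn_func import is gated behind is_flash_attn_2_available().

import inspect


def fake_flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), deterministic=False):
    """Stand-in with a flash_attn-like signature, used only to illustrate the probe."""


# Only forward optional kwargs that the installed kernel actually accepts.
_supports_window_size = "window_size" in inspect.signature(fake_flash_attn_func).parameters
_supports_deterministic = "deterministic" in inspect.signature(fake_flash_attn_func).parameters

flash_kwargs = {}
if _supports_deterministic:
    flash_kwargs["deterministic"] = True  # mirrors FLASH_ATTENTION_DETERMINISTIC=1 in the removed block

print(_supports_window_size, _supports_deterministic, flash_kwargs)  # True True {'deterministic': True}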

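Note: the removed prepare_fa2_from_position_ids recovered per-sequence boundaries for flash_attn_varlen_func from the position ids alone, treating every position id equal to 0 as the start of a new packed sequence. A small worked example of that cu_seqlens construction, using made-up position ids for two packed sequences of lengths 3 and 4, is sketched below.

import torch

# Flattened position ids for two packed sequences of lengths 3 and 4 (illustrative values).
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])

# Sequence starts are wherever the position id resets to 0; the total length closes the last span.
cu_seqlens = torch.cat(
    (
        (position_ids == 0).nonzero().view(-1).to(torch.int32),
        torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    )
)
max_length = cu_seqlens.diff().max()

print(cu_seqlens.tolist())  # [0, 3, 7] -> boundaries of the two packed sequences
print(int(max_length))      # 4 -> longest packed sequence, used as max_seqlen for the varlen kernel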
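Note: the removed _get_input_embeds spliced vision-tower outputs into the token embeddings with masked_scatter, expanding the placeholder-token mask over the hidden dimension. The sketch below illustrates that scatter on made-up tensors (tiny hidden size, two image placeholder positions); all shapes and values here are illustrative only.

import torch

# Token embeddings for a 1 x 5 sequence with hidden size 2; positions 1 and 3 are image placeholder tokens.
inputs_embeds = torch.zeros(1, 5, 2)
image_token_mask = torch.tensor([[False, True, False, True, False]])

# Two "image patch embeddings", one per placeholder token.
image_embeds = torch.tensor([[1.0, 1.0], [2.0, 2.0]])

# Expand the mask to the hidden dimension, then scatter the image embeddings into the placeholder slots.
expanded_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, image_embeds)

print(inputs_embeds[0].tolist())  # rows 1 and 3 now hold [1.0, 1.0] and [2.0, 2.0]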

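Note: both removed PPO backends (forward_with_torch_backend and forward_with_triton_backend) built their targets by rolling the labels, or the input ids, one step to the left, so the logits at position t are scored against the token at position t+1. A minimal sketch of that shift with made-up token ids follows; how the wrapped-around final position is masked out of the loss afterwards is left to the caller and not shown here.

import torch

# Made-up token ids for a single sequence.
input_ids = torch.tensor([[11, 12, 13, 14]])

# Shift left by one: position t now holds the next-token target for the logits at position t.
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)

print(rolled_labels.tolist())  # [[12, 13, 14, 11]] -- the final entry wraps around and carries no real target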