logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
    _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

# if is_npu_available:
#     from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
#     from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
#     from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask

#     _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
#     _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
#     _flash_use_top_left_mask = flash_attn_supports_top_left_mask()

_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
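
# A minimal sketch of the signature-probing pattern used above, with a
# hypothetical stand-in `_f` in place of flash_attn_func:
#
#     def _f(q, k, v, window_size=None, deterministic=False): ...
#
#     assert "window_size" in inspect.signature(_f).parameters
#     assert "deterministic" in inspect.signature(_f).parameters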


def get_rope_index(
    processor,
    # ... (remaining parameters and body elided) ...
    return position_ids


def prepare_fa2_from_position_ids(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
    """Flatten packed q/k/v and derive flash-attn varlen metadata from 2D position ids."""
    assert position_ids.ndim == 2  # (batch_size, seq_length)
    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.view(-1)
    # a position id of 0 marks the start of a packed sequence; the final entry is the total length
    cu_seqlens = torch.cat(
        (
            (position_ids == 0).nonzero().view(-1).to(torch.int32),
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )
    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
    return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))
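
# Worked example (hypothetical packing, not from the patch): two sequences of
# lengths 3 and 2 packed into one row give position_ids = [[0, 1, 2, 0, 1]], so
#
#     (position_ids == 0).nonzero() -> [0, 3]            # packed-sequence starts
#     appending the total length 5  -> cu_seqlens = [0, 3, 5]
#     cu_seqlens.diff().max()       -> max_length = 3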


def _custom_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool = True,
    position_ids: Optional[torch.Tensor] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    **kwargs,
):
    """
    Patches flash attention forward to handle 3D position ids (3, batch_size, seq_length) in mrope.
    """
    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    flash_kwargs = {}

    if _flash_supports_deterministic:
        flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled

    if kwargs.get("softcap") is not None:
        flash_kwargs["softcap"] = kwargs.pop("softcap")

    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype=torch.bfloat16
    )

    if position_ids is not None:
        assert position_ids.ndim == 2  # (batch_size, seq_length / sp_size)

    sp_size = get_ulysses_sequence_parallel_world_size()
    if sp_size > 1:
        # qkv: (batch_size, seq_length / sp_size, num_head, head_size)
        validate_ulysses_config(query_states.size(2), sp_size)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
        # all_gather fills position_ids_lst in place; it does not return the gathered tensor
        dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
        position_ids = torch.cat(position_ids_lst, dim=-1)  # (batch_size, seq_length)

    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )
        attn_output = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=kwargs.pop("dropout", 0.0),
            softmax_scale=kwargs.pop("softmax_scale", None),
            causal=is_causal,
            **flash_kwargs,
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=is_causal,
            use_top_left_mask=use_top_left_mask,
            deterministic=deterministic,
            **kwargs,
        )  # do not pass position_ids to the stock _flash_attention_forward

    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

    return attn_output
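
# The varlen path above is taken only when position ids are non-monotonic,
# which is exactly the packed-sequence case. Illustration (hypothetical values):
#
#     packed = torch.tensor([[0, 1, 2, 0, 1]])
#     (torch.diff(packed, dim=-1) >= 0).all()  # False -> flash_attn_varlen_func
#
#     plain = torch.tensor([[0, 1, 2, 3, 4]])
#     (torch.diff(plain, dim=-1) >= 0).all()   # True  -> _flash_attention_forward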


def glm4v_attn_forward(
    self: "Glm4vTextAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> tuple[torch.Tensor, None]:
    from transformers.models.glm4v.modeling_glm4v import apply_multimodal_rotary_pos_emb, repeat_kv

    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Because the input can be padded, the absolute sequence length depends on the max position id.
    cos, sin = position_embeddings
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # Read the query length before undoing the transpose
    q_len = query_states.shape[2]

    # FA2 uses non-transposed inputs
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    attn_output = _custom_flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length=q_len,
        is_causal=getattr(self, "is_causal", True),
        dropout=dropout_rate,
        use_top_left_mask=_flash_use_top_left_mask,
        position_ids=position_ids,  # important: pass position ids
    )  # (batch_size, seq_length / sp_size, num_head, head_size)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None
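
# Shape walk-through for the GQA expansion above, on hypothetical sizes
# (num_heads = 8, num_key_value_heads = 2, head_dim = 64, q_len = 128):
#
#     key_states: (bsz, 2, 128, 64) --repeat_kv(x4)--> (bsz, 8, 128, 64)
#
# after which q/k/v are transposed back to (bsz, q_len, num_heads, head_dim),
# the layout the flash-attention kernels expect.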


def _get_input_embeds(
    model: "Glm4vForConditionalGeneration",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
):
    inputs_embeds = model.get_input_embeddings()(input_ids)
    if pixel_values is not None:
        pixel_values = pixel_values.type(model.visual.dtype)
        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features}"
            )

        mask = input_ids == model.config.image_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        image_mask = mask_expanded.to(inputs_embeds.device)

        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features: {n_video_features}"
            )

        mask = input_ids == model.config.video_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        video_mask = mask_expanded.to(inputs_embeds.device)

        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if model.training and pixel_values is None and pixel_values_videos is None:  # handle mixed text-image data
        # dummy visual forward keeps the vision tower in the autograd graph for text-only batches
        pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
        inputs_embeds += 0.0 * image_embeds.mean()

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    return inputs_embeds, attention_mask
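
# How the masked_scatter above splices vision features into the token stream,
# on toy shapes (hypothetical): with two placeholder tokens at positions 2 and 3,
#
#     emb = torch.zeros(1, 6, 4)
#     mask = torch.tensor([[0, 0, 1, 1, 0, 0]], dtype=torch.bool)
#     feats = torch.ones(2, 4)
#     emb = emb.masked_scatter(mask.unsqueeze(-1).expand_as(emb), feats)
#
# only rows 2 and 3 of emb are overwritten, in order; text rows stay intact.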


def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    if position_ids.ndim != 3 or position_ids.size(0) != 4:
        # the text position ids are concatenated with the 3D vision position ids by default
        # see https://github.com/huggingface/transformers/pull/39447
        raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")

    return position_ids
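
# Expected layout, following the comment above: the text position ids stacked
# with the three (temporal/height/width) mrope ids along dim 0, e.g. for a
# batch of 2 and sequence length 512 (hypothetical sizes):
#
#     position_ids.shape == (4, 2, 512)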


@dataclass
class Glm4vCausalLMOutputForPPO(Glm4vCausalLMOutputWithPast):
    log_probs: Optional[torch.FloatTensor] = None
    entropy: Optional[torch.FloatTensor] = None


def glm4v_base_forward(
    self: "Glm4vForConditionalGeneration",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    **kwargs,
):
    kwargs["inputs_embeds"], kwargs["attention_mask"] = _get_input_embeds(
        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
    )  # pass via kwargs so a LoRA-wrapped module does not receive duplicate keyword arguments
    return self.language_model(
        input_ids=None,
        **kwargs,
    )


def glm4v_forward(
    self: "Glm4vForConditionalGeneration",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    **kwargs,
):
    return self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=process_position_ids(position_ids),
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        **kwargs,
    )


def forward_with_normal_backend(
    self: Glm4vForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> "Glm4vCausalLMOutputWithPast":
    outputs = glm4v_forward(self, input_ids, **kwargs)
    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    return Glm4vCausalLMOutputWithPast(
        logits=logits,
        hidden_states=outputs.hidden_states,
    )


def forward_with_torch_backend(
    self: Glm4vForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Glm4vCausalLMOutputForPPO:
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    outputs = glm4v_forward(self, input_ids, **kwargs)
    hidden_states = outputs[0]

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
    elif input_ids is not None:
        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
    else:
        raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.")

    fused_linear_for_ppo = FusedLinearForPPO()
    log_probs, entropy = fused_linear_for_ppo.forward(
        hidden_states=hidden_states,
        vocab_weights=self.lm_head.weight,
        input_ids=rolled_labels,
        temperature=temperature,
    )
    return Glm4vCausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=outputs.hidden_states,
    )
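
# Why the roll: position t is scored against token t + 1, so shifting the
# targets left by one aligns them with the hidden states. E.g. (hypothetical ids)
#
#     labels        = torch.tensor([5, 7, 9, 2])
#     rolled_labels = torch.roll(labels, shifts=-1, dims=-1)  # [7, 9, 2, 5]
#
# the wrapped-around last position is expected to be masked out downstream.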


def forward_with_triton_backend(
    self: Glm4vForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Glm4vCausalLMOutputForPPO:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    outputs = glm4v_forward(self, input_ids, **kwargs)
    hidden_states = outputs[0]

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
    elif input_ids is not None:
        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
    else:
        raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.")

    log_probs, entropy = linear_cross_entropy(
        hidden_states,
        self.lm_head.weight,
        rolled_labels,
        temperature,
        "none",
    )
    return Glm4vCausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=outputs.hidden_states,
    )
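
# Both fused backends return per-token log-probs and entropy without
# materializing the full logits tensor. A non-fused reference of the quantity
# they compute (a sketch under that assumption, not the fused kernels):
#
#     logits = hidden_states @ self.lm_head.weight.T / temperature
#     logp = torch.log_softmax(logits, dim=-1)
#     log_probs = logp.gather(-1, rolled_labels.unsqueeze(-1)).squeeze(-1)
#     entropy = -(logp.exp() * logp).sum(-1)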