@@ -1932,24 +1932,23 @@ def execute_model(
 
         if model_input.inputs_embeds is not None:
             if self.is_driver_worker:
-                sampled = broadcast_tensor_dict(
-                    {"token_ids": output.sampled_token_ids})
-            else:
-                sampled = broadcast_tensor_dict()
-            if sampled["token_ids"] is not None:
-                sampled_token_embeds = self.model.get_input_embeddings(
-                    sampled["token_ids"].squeeze(1))
-                if self.is_driver_worker:
+                sampled_token_ids = []
+                valid_outputs = []
+                for sequence_group_output in output.outputs:
+                    if len(sequence_group_output.samples) == 0:
+                        continue
+                    assert len(sequence_group_output.samples) == 1
+                    valid_outputs.append(sequence_group_output)
+                    sampled_token_ids.append(
+                        sequence_group_output.samples[0].output_token)
+                if len(sampled_token_ids) > 0:
                     self.sampler.include_gpu_probs_tensor = \
                         orig_include_gpu_probs
-
-                    output.sampled_token_embeds = sampled_token_embeds
-
-                    for token_embed, sequence_group_output in zip(
-                            output.sampled_token_embeds, output.outputs):
-                        assert len(sequence_group_output.samples) == 1
-                        sequence_group_output.samples[
-                            0].output_embed = token_embed
+                    sampled_token_embeds = self.model.get_input_embeddings(
+                        torch.tensor(sampled_token_ids, device=self.device))
+                    for i, sequence_group_output in enumerate(valid_outputs):
+                        sequence_group_output.samples[0].output_embed = \
+                            sampled_token_embeds[i]
 
         if not self.is_driver_worker:
             return []
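
Below is a minimal standalone sketch of the batched lookup the new driver-side path performs: gather token ids only from sequence groups that actually produced a sample, embed them in one batched call, and scatter the rows back to their samples. The Mock* classes, vocabulary size, and embedding dimension are illustrative assumptions, not vLLM's actual SequenceGroupOutput types.

# Minimal sketch of the batched embedding lookup in the new code path.
# Mock* classes are illustrative stand-ins, not vLLM's actual types.
from dataclasses import dataclass, field
from typing import List, Optional

import torch


@dataclass
class MockSample:
    output_token: int
    output_embed: Optional[torch.Tensor] = None


@dataclass
class MockSequenceGroupOutput:
    samples: List[MockSample] = field(default_factory=list)


# Stand-in for self.model.get_input_embeddings (assumed sizes).
embedding = torch.nn.Embedding(num_embeddings=128, embedding_dim=16)

outputs = [
    MockSequenceGroupOutput([MockSample(output_token=7)]),
    MockSequenceGroupOutput([]),  # a group with no sample this step
    MockSequenceGroupOutput([MockSample(output_token=42)]),
]

# Gather token ids only from groups that actually sampled a token,
# remembering which group each id came from.
sampled_token_ids = []
valid_outputs = []
for sequence_group_output in outputs:
    if len(sequence_group_output.samples) == 0:
        continue
    assert len(sequence_group_output.samples) == 1
    valid_outputs.append(sequence_group_output)
    sampled_token_ids.append(sequence_group_output.samples[0].output_token)

if len(sampled_token_ids) > 0:
    # One batched embedding lookup, then write each row back to the
    # sample it belongs to; empty groups are simply left untouched.
    sampled_token_embeds = embedding(torch.tensor(sampled_token_ids))
    for i, sequence_group_output in enumerate(valid_outputs):
        sequence_group_output.samples[0].output_embed = sampled_token_embeds[i]

assert outputs[0].samples[0].output_embed is not None
assert outputs[1].samples == []

The apparent motivation: the old path broadcast the full sampled-token-id tensor and then zipped the embeddings positionally against output.outputs while asserting exactly one sample per group, which breaks when a group yields zero samples; the new path pairs ids with their groups explicitly via valid_outputs, so empty groups are skipped safely.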