@@ -2697,7 +2697,7 @@ def _dola_decoding(
2697
2697
)
2698
2698
2699
2699
# float32 is needed to retain precision for later logits manipulations
2700
- final_layer_next_token_logits = outputs .logits [:, - 1 , :].detach ().clone (). float ( )
2700
+ final_layer_next_token_logits = outputs .logits [:, - 1 , :].detach ().to ( copy = True , dtype = torch . float32 )
2701
2701
final_logits = outputs .logits [:, - 1 , :].float ()
2702
2702
candidate_premature_logits = {}
2703
2703
for candidate_premature_layer in candidate_premature_layers :
@@ -2885,11 +2885,12 @@ def _contrastive_search(
2885
2885
last_hidden_states = outputs .hidden_states [- 1 ]
2886
2886
2887
2887
# next logit for contrastive search to select top-k candidate tokens
2888
- # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
2888
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
2889
2889
# (the copy itself is always small)
2890
- # .float() is needed to retain precision for later logits manipulations
2891
- logit_for_next_step = outputs .logits [:, - 1 , :].clone ().float ()
2892
- logit_for_next_step = logit_for_next_step .to (input_ids .device )
2890
+ # torch.float32 is needed to retain precision for later logits manipulations
2891
+ logit_for_next_step = outputs .logits [:, - 1 , :].to (
2892
+ copy = True , dtype = torch .float32 , device = input_ids .device
2893
+ )
2893
2894
2894
2895
model_kwargs = self ._update_model_kwargs_for_generation (
2895
2896
outputs ,
@@ -3297,10 +3298,9 @@ def _sample(
3297
3298
if synced_gpus and this_peer_finished :
3298
3299
continue
3299
3300
3300
- # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
3301
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
3301
3302
# (the copy itself is always small)
3302
- next_token_logits = outputs .logits [:, - 1 , :].clone ().float ()
3303
- next_token_logits = next_token_logits .to (input_ids .device )
3303
+ next_token_logits = outputs .logits [:, - 1 , :].to (copy = True , dtype = torch .float32 , device = input_ids .device )
3304
3304
3305
3305
# pre-process distribution
3306
3306
next_token_scores = logits_processor (input_ids , next_token_logits )
@@ -3768,8 +3768,8 @@ def _beam_search(
3768
3768
if synced_gpus and this_peer_finished :
3769
3769
continue
3770
3770
3771
- logits = model_outputs . logits [:, - 1 , :]. clone (). float () # Clone is needed to avoid keeping a hanging ref
3772
- logits = logits .to (input_ids .device )
3771
+ # Copy is needed to avoid keeping a hanging ref
3772
+ logits = model_outputs . logits [:, - 1 , :] .to (copy = True , dtype = torch . float32 , device = input_ids .device )
3773
3773
3774
3774
# b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
3775
3775
# `temperature`, ...), and add new logprobs to existing running logprobs scores.
@@ -4045,10 +4045,9 @@ def _group_beam_search(
4045
4045
if output_scores :
4046
4046
processed_score = torch .zeros_like (outputs .logits [:, - 1 , :])
4047
4047
if output_logits :
4048
- # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
4048
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
4049
4049
# (the copy itself is always small)
4050
- raw_logit_score = outputs .logits [:, - 1 , :].clone ()
4051
- raw_logit_score = raw_logit_score .to (input_ids .device )
4050
+ raw_logit_score = outputs .logits [:, - 1 , :].to (copy = True , device = input_ids .device )
4052
4051
4053
4052
for beam_group_idx in range (num_beam_groups ):
4054
4053
group_start_idx = beam_group_idx * num_sub_beams
@@ -4067,8 +4066,9 @@ def _group_beam_search(
4067
4066
# select outputs of beams of current group only
4068
4067
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
4069
4068
# torch.float32 is needed to retain precision for later logits manipulations
4070
- next_token_logits = outputs .logits [batch_group_indices , - 1 , :].float ()
4071
- next_token_logits = next_token_logits .to (input_ids .device )
4069
+ next_token_logits = outputs .logits [batch_group_indices , - 1 , :].to (
4070
+ dtype = torch .float32 , device = input_ids .device
4071
+ )
4072
4072
4073
4073
next_token_scores = nn .functional .log_softmax (
4074
4074
next_token_logits , dim = - 1
@@ -4322,11 +4322,10 @@ def _constrained_beam_search(
4322
4322
cur_len = cur_len + 1
4323
4323
continue
4324
4324
4325
- # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
4325
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
4326
4326
# (the copy itself is always small)
4327
4327
# torch.float32 is needed to retain precision for later logits manipulations
4328
- next_token_logits = outputs .logits [:, - 1 , :].clone ().float ()
4329
- next_token_logits = next_token_logits .to (input_ids .device )
4328
+ next_token_logits = outputs .logits [:, - 1 , :].to (copy = True , dtype = torch .float32 , device = input_ids .device )
4330
4329
next_token_scores = nn .functional .log_softmax (
4331
4330
next_token_logits , dim = - 1
4332
4331
) # (batch_size * num_beams, vocab_size)
@@ -4574,8 +4573,9 @@ def _assisted_decoding(
4574
4573
4575
4574
# 2.3. Process the new logits
4576
4575
# torch.float32 is needed to retain precision for later logits manipulations
4577
- new_logits = outputs .logits [:, - candidate_length - 1 :].float () # excludes the input prompt if present
4578
- new_logits = new_logits .to (input_ids .device )
4576
+ new_logits = outputs .logits [:, - candidate_length - 1 :].to (
4577
+ dtype = torch .float32 , device = input_ids .device
4578
+ ) # excludes the input prompt if present
4579
4579
next_token_logits = new_logits .clone ()
4580
4580
if len (logits_processor ) > 0 :
4581
4581
for i in range (candidate_length + 1 ):
0 commit comments