
Commit d8ecc7b

cyyever authored and zucchini-nlp committed

Remove extra tensor clone in PyTorch code (huggingface#36748)

* Use detach().clone()
* Eliminate contiguous()
* Merge clone and other calls with to

1 parent 5faec28 commit d8ecc7b
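
The first bullet swaps the order of clone() and detach(). Below is a minimal sketch with a toy tensor (not code from this commit) of why the order matters; the other two bullets are illustrated after the generation/utils.py and convert_mixtral_weights_to_hf.py diffs further down.

import torch

w = torch.randn(3, requires_grad=True)

old = w.clone().detach()   # the clone op is briefly recorded in the autograd graph, then detached
new = w.detach().clone()   # detach first, so the clone is never tracked by autograd

print(old.requires_grad, new.requires_grad)  # False False
print(torch.equal(old, new))                 # True: both are data-equal copies of w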

File tree

6 files changed, +37 -37 lines changed

src/transformers/generation/utils.py

Lines changed: 20 additions & 20 deletions
@@ -2697,7 +2697,7 @@ def _dola_decoding(
         )

         # .float() is needed to retain precision for later logits manipulations
-        final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float()
+        final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(copy=True, dtype=torch.float32)
         final_logits = outputs.logits[:, -1, :].float()
         candidate_premature_logits = {}
         for candidate_premature_layer in candidate_premature_layers:
@@ -2885,11 +2885,12 @@ def _contrastive_search(
         last_hidden_states = outputs.hidden_states[-1]

         # next logit for contrastive search to select top-k candidate tokens
-        # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
+        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
         # (the clone itself is always small)
-        # .float() is needed to retain precision for later logits manipulations
-        logit_for_next_step = outputs.logits[:, -1, :].clone().float()
-        logit_for_next_step = logit_for_next_step.to(input_ids.device)
+        # torch.float32 is needed to retain precision for later logits manipulations
+        logit_for_next_step = outputs.logits[:, -1, :].to(
+            copy=True, dtype=torch.float32, device=input_ids.device
+        )

         model_kwargs = self._update_model_kwargs_for_generation(
             outputs,
@@ -3297,10 +3298,9 @@ def _sample(
             if synced_gpus and this_peer_finished:
                 continue

-            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
-            next_token_logits = outputs.logits[:, -1, :].clone().float()
-            next_token_logits = next_token_logits.to(input_ids.device)
+            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -3768,8 +3768,8 @@ def _beam_search(
             if synced_gpus and this_peer_finished:
                 continue

-            logits = model_outputs.logits[:, -1, :].clone().float()  # Clone is needed to avoid keeping a hanging ref
-            logits = logits.to(input_ids.device)
+            # Copy is needed to avoid keeping a hanging ref
+            logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

             # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
             # `temperature`, ...), and add new logprobs to existing running logprobs scores.
@@ -4045,10 +4045,9 @@ def _group_beam_search(
             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
             if output_logits:
-                # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+                # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
                 # (the clone itself is always small)
-                raw_logit_score = outputs.logits[:, -1, :].clone()
-                raw_logit_score = raw_logit_score.to(input_ids.device)
+                raw_logit_score = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)

             for beam_group_idx in range(num_beam_groups):
                 group_start_idx = beam_group_idx * num_sub_beams
@@ -4067,8 +4066,9 @@ def _group_beam_search(
                 # select outputs of beams of current group only
                 # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
                 # .float() is needed to retain precision for later logits manipulations
-                next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
-                next_token_logits = next_token_logits.to(input_ids.device)
+                next_token_logits = outputs.logits[batch_group_indices, -1, :].to(
+                    dtype=torch.float32, device=input_ids.device
+                )

                 next_token_scores = nn.functional.log_softmax(
                     next_token_logits, dim=-1
@@ -4322,11 +4322,10 @@ def _constrained_beam_search(
                 cur_len = cur_len + 1
                 continue

-            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
             # .float() is needed to retain precision for later logits manipulations
-            next_token_logits = outputs.logits[:, -1, :].clone().float()
-            next_token_logits = next_token_logits.to(input_ids.device)
+            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -4574,8 +4573,9 @@ def _assisted_decoding(

             # 2.3. Process the new logits
             # .float() is needed to retain precision for later logits manipulations
-            new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
-            new_logits = new_logits.to(input_ids.device)
+            new_logits = outputs.logits[:, -candidate_length - 1 :].to(
+                dtype=torch.float32, device=input_ids.device
+            )  # excludes the input prompt if present
             next_token_logits = new_logits.clone()
             if len(logits_processor) > 0:
                 for i in range(candidate_length + 1):
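
The comments in these hunks note that a copy is needed to avoid keeping a hanging ref to outputs.logits. The old code got that copy from .clone(); the new code folds it into a single .to(...) call, where copy=True guarantees a fresh tensor even when dtype and device already match. A minimal sketch with a toy tensor standing in for outputs.logits[:, -1, :] (not the actual generation code):

import torch

last = torch.randn(1, 32000)                      # stand-in for the last-position logits, float32 on CPU

same = last.to(dtype=torch.float32)               # dtype and device already match: returns `last` itself
fresh = last.to(copy=True, dtype=torch.float32)   # copy=True: always allocates a new, independent tensor

print(same is last)    # True  -> holding this would keep the original tensor alive
print(fresh is last)   # False -> safe to hold without retaining the large logits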

src/transformers/integrations/higgs.py

Lines changed: 1 addition & 1 deletion
@@ -446,7 +446,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256

    device = weight.device
    dtype = weight.dtype
-   weight = weight.clone().float()
+   weight = weight.to(copy=True, dtype=torch.float32)
    # Pad to Hadamard transform size
    weight = pad_to_block(weight, [1], hadamard_size)

src/transformers/models/deprecated/jukebox/modeling_jukebox.py

Lines changed: 4 additions & 4 deletions
@@ -2205,12 +2205,12 @@ def forward_tokens(
        loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims

        metrics = {
-            "bpd": next_token_prediction_loss.clone().detach(),
-            "encoder_loss": encoder_loss.clone().detach(),
-            "next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
+            "bpd": next_token_prediction_loss.detach().clone(),
+            "encoder_loss": encoder_loss.detach().clone(),
+            "next_token_prediction_loss": next_token_prediction_loss.detach().clone(),
        }
        if get_preds:
-            metrics["preds"] = preds.clone().detach()
+            metrics["preds"] = preds.detach().clone()
        if get_attn_weights:
            saved_attn_weights = self.prior.transformer.saved_attn_weights
            self.prior.transformer.set_record_attn(False)

src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py

Lines changed: 4 additions & 4 deletions
@@ -148,7 +148,7 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        w3 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w3"]

        experts_w1 = [
-            w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+            w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
            for expert_idx in range(num_local_experts)
        ]
@@ -157,16 +157,16 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
            state_dict[expert_key + ".weight"] = expert_block.clone()

        experts_w2 = [
-            w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+            w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
            for expert_idx in range(num_local_experts)
        ]

        for idx, expert_block in enumerate(experts_w2):
            expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w2"
-            state_dict[expert_key + ".weight"] = expert_block.T.clone().contiguous()
+            state_dict[expert_key + ".weight"] = expert_block.T.clone(memory_format=torch.contiguous_format)

        experts_w3 = [
-            w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+            w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].clone(memory_format=torch.contiguous_format)
            for expert_idx in range(num_local_experts)
        ]
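
The conversion script previously called .contiguous().clone(), which can allocate twice when the input is non-contiguous (once for the contiguous copy, once for the clone). clone(memory_format=torch.contiguous_format) produces the same contiguous copy in one step. A small sketch with an assumed toy tensor, not the Mixtral weights themselves:

import torch

w = torch.randn(8, 4).t()                             # transposed view: not contiguous
a = w.contiguous().clone()                            # old pattern: up to two allocations
b = w.clone(memory_format=torch.contiguous_format)    # new pattern: one contiguous copy

print(w.is_contiguous(), a.is_contiguous(), b.is_contiguous())  # False True True
print(torch.equal(a, b))                                         # True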

src/transformers/models/videomae/modeling_videomae.py

Lines changed: 2 additions & 2 deletions
@@ -131,7 +131,7 @@ def forward(self, pixel_values, bool_masked_pos):
        embeddings = self.patch_embeddings(pixel_values)

        # add position embeddings
-        embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
+        embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).detach().clone()
        # only keep visible patches
        # ~bool_masked_pos means visible
        if bool_masked_pos is not None:
@@ -856,7 +856,7 @@ def forward(
        if bool_masked_pos is None:
            raise ValueError("One must provided a boolean mask ")
        expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
-        expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
+        expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).detach().clone()
        pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
        pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

src/transformers/pytorch_utils.py

Lines changed: 6 additions & 6 deletions
@@ -73,12 +73,12 @@ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0)
        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
    """
    index = index.to(layer.weight.device)
-    W = layer.weight.index_select(dim, index).clone().detach()
+    W = layer.weight.index_select(dim, index).detach().clone()
    if layer.bias is not None:
        if dim == 1:
-            b = layer.bias.clone().detach()
+            b = layer.bias.detach().clone()
        else:
-            b = layer.bias[index].clone().detach()
+            b = layer.bias[index].detach().clone()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
@@ -137,11 +137,11 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
    """
    index = index.to(layer.weight.device)
-    W = layer.weight.index_select(dim, index).clone().detach()
+    W = layer.weight.index_select(dim, index).detach().clone()
    if dim == 0:
-        b = layer.bias.clone().detach()
+        b = layer.bias.detach().clone()
    else:
-        b = layer.bias[index].clone().detach()
+        b = layer.bias[index].detach().clone()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
