Skip to content

Commit 6efc0d5

Browse files
cyyeverzucchini-nlp
authored andcommitted
Use torch.expm1 (huggingface#36995)
1 parent cf55409 commit 6efc0d5

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/transformers/models/seamless_m4t/modeling_seamless_m4t.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2578,7 +2578,7 @@ def forward(
25782578
lang = self.language_embedding(lang_id).transpose(1, 2)
25792579

25802580
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
2581-
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
2581+
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
25822582
# B x C x T
25832583
if hidden_states.size(0) == 1:
25842584
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)

src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,7 +2292,7 @@ def forward(
22922292

22932293
# predict duration
22942294
log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask)
2295-
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
2295+
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
22962296
dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0)
22972297

22982298
# upsample char hidden states according to predicted duration
@@ -2854,7 +2854,7 @@ def forward(
28542854
lang = self.language_embedding(lang_id).transpose(1, 2)
28552855

28562856
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
2857-
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
2857+
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
28582858
# B x C x T
28592859
if hidden_states.size(0) == 1:
28602860
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)

0 commit comments

Comments
 (0)