Description
Hi,
I am getting a segmentation fault when I pass a large input tensor to RNNTLoss.
Test code:
import torch
from warprnnt_pytorch import RNNTLoss

rnnt_loss = RNNTLoss()

# acts: (batch=128, time=256, labels+1=129, vocab=1024)
acts = torch.ones(128, 256, 129, 1024, requires_grad=True)
labels = torch.ones(128, 128)
act_length = torch.ones(128) * 256
label_length = torch.ones(128) * 128

# The loss expects float32 activations and int32 labels/lengths
if acts.dtype != torch.float:
    acts = acts.float()
if labels.dtype != torch.int32:
    labels = labels.int()
if act_length.dtype != torch.int32:
    act_length = act_length.int()
if label_length.dtype != torch.int32:
    label_length = label_length.int()

loss = rnnt_loss(acts, labels, act_length, label_length)
print(loss)
loss.backward()
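A quick size check on acts, done without allocating the tensor, may hint at the cause. This is a hypothesis, not something confirmed in the warp-transducer source: if any offset into acts is computed with 32-bit integers, this shape would overflow it. The variable names below are illustrative only.

```python
# Hypothetical size check for the acts tensor in the repro above.
# Assumption (unconfirmed): the loss indexes acts with 32-bit integers.
import math

acts_shape = (128, 256, 129, 1024)  # (batch, time, labels + 1, vocab)
num_elements = math.prod(acts_shape)  # computed without allocating memory

int32_max = 2**31 - 1  # largest signed 32-bit value
uint32_max = 2**32 - 1  # largest unsigned 32-bit value
bytes_per_float32 = 4

print(f"elements in acts: {num_elements:,}")  # → 4,328,521,728
print(f"exceeds int32 max:  {num_elements > int32_max}")  # → True
print(f"exceeds uint32 max: {num_elements > uint32_max}")  # → True
# The gradient buffer for acts needs the same amount of memory again.
print(f"acts in float32: {num_elements * bytes_per_float32 / 2**30:.3f} GiB")  # → 16.125 GiB
```

So acts alone holds more elements than even an unsigned 32-bit index can address, and takes about 16 GiB in float32 before the gradient buffer is counted.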
How can I fix this issue?
Thanks!