-
Notifications
You must be signed in to change notification settings - Fork 126
Open
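Description

With identical inputs, the CPU and CUDA paths of `RNNTLoss` report the same loss (4.6379), but their gradients differ in exactly one row: `acts_t.grad[0, 0, 0, :]` is `[0.0550, -0.2406, 0.0716, 0.0589, 0.0550]` on the CPU versus `[0.0885, -0.3869, 0.1152, 0.0947, 0.0885]` on the GPU. All other entries match. Note that `blank = 1` while `labels[0]` starts with 1, so the affected batch item has the blank index in its targets. Repro script and output: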
import torch
import numpy as np
from warprnnt_pytorch import RNNTLoss
# Repro: run warp-transducer's RNNTLoss on identical random inputs on the CPU
# and (if available) on CUDA. Print each loss and gradient, then report how far
# the two backends disagree.
#
# NOTE(review): acts is built with shape (2, 2, 3, 5). This is presumably
# (batch, max_T, max_U + 1, vocab): act_length == 2 and label_length == 2 fit
# that reading. Confirm against the binding's documentation.

SEED = 0  # fixed so the printed numbers can be reproduced run-to-run
BLANK = 1
REDUCTION = 'mean'


def find_blank_labels(labels, label_length, blank):
    """Return the batch indices whose unpadded label sequence contains ``blank``.

    Only the first ``label_length[b]`` entries of ``labels[b]`` are inspected,
    so padding values are ignored. In RNN-T the blank symbol conventionally
    means "emit nothing", so a target sequence containing it is suspect input.
    """
    return [
        b
        for b, (row, n) in enumerate(zip(labels, label_length))
        if np.any(np.asarray(row)[: int(n)] == blank)
    ]


def loss_and_grad(loss_fn, acts, labels, act_length, label_length, device='cpu'):
    """Evaluate ``loss_fn`` on fresh tensors on ``device``; return (loss, acts.grad).

    ``acts`` becomes a float32 leaf tensor with ``requires_grad=True``. Labels
    and both length arrays become int32 tensors, the same dtypes the original
    repro used. ``loss.backward()`` is called here, so ``loss_fn`` must return a
    single-element tensor. New tensors are created per call, so separate
    device runs never share autograd state.
    """
    acts_t = torch.tensor(acts, dtype=torch.float, device=device, requires_grad=True)
    labels_t = torch.tensor(labels, dtype=torch.int, device=device)
    act_length_t = torch.tensor(act_length, dtype=torch.int, device=device)
    label_length_t = torch.tensor(label_length, dtype=torch.int, device=device)
    loss = loss_fn(acts_t, labels_t, act_length_t, label_length_t)
    loss.backward()
    return loss, acts_t.grad


def main():
    """Run the CPU (and, if present, CUDA) RNN-T loss and compare the results."""
    import warnings

    rng = np.random.default_rng(SEED)
    acts = rng.random((2, 2, 3, 5))
    labels = np.array([[1, 2], [2, 2]])
    act_length = np.array([2, 2])
    label_length = np.array([2, 2])

    # labels[0] starts with 1 == BLANK. In the output recorded with this
    # issue, batch item 0 is the only one whose gradient differed between
    # backends. Keep the input so the repro exercises the same case, but
    # flag it rather than let it pass silently.
    suspect = find_blank_labels(labels, label_length, BLANK)
    if suspect:
        warnings.warn(
            f"labels for batch item(s) {suspect} contain the blank index {BLANK}; "
            "CPU/GPU results for those items may not be comparable"
        )

    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')
    else:
        print('CUDA not available; skipping the GPU run')

    results = {}
    print('----------------------------------------')
    for device in devices:
        rnnt_loss = RNNTLoss(blank=BLANK, reduction=REDUCTION).to(device)
        loss, grad = loss_and_grad(
            rnnt_loss, acts, labels, act_length, label_length, device
        )
        print(loss)
        print(grad)
        print('-----------------------------------------')
        results[device] = (loss, grad)

    if 'cuda' in results:
        cpu_loss, cpu_grad = results['cpu']
        gpu_loss, gpu_grad = results['cuda']
        gpu_grad = gpu_grad.cpu()
        loss_diff = (cpu_loss.detach() - gpu_loss.detach().cpu()).abs().max().item()
        grad_diff = (cpu_grad - gpu_grad).abs()
        worst = tuple(
            int(i) for i in np.unravel_index(int(grad_diff.argmax()), grad_diff.shape)
        )
        print(f'max |loss cpu - gpu|: {loss_diff:.6g}')
        print(f'max |grad cpu - gpu|: {grad_diff.max().item():.6g} at index {worst}')
        print('grads match:', torch.allclose(cpu_grad, gpu_grad, atol=1e-5))


if __name__ == '__main__':
    main()
----------------------------------------
tensor([4.6379], grad_fn=<_RNNTBackward>)
tensor([[[[ 0.0550, -0.2406, 0.0716, 0.0589, 0.0550],
[ 0.0492, -0.1081, -0.0603, 0.0419, 0.0774],
[ 0.0290, -0.1254, 0.0351, 0.0291, 0.0322]],
[[ 0.0346, -0.1494, 0.0226, 0.0539, 0.0383],
[ 0.0726, 0.0889, -0.2694, 0.0452, 0.0627],
[ 0.0730, -0.4038, 0.1470, 0.0579, 0.1259]]],
[[[ 0.0749, -0.0353, -0.2435, 0.1090, 0.0948],
[ 0.0529, -0.1224, -0.0426, 0.0533, 0.0587],
[ 0.0290, -0.1383, 0.0383, 0.0461, 0.0249]],
[[ 0.0163, 0.0226, -0.0751, 0.0122, 0.0240],
[ 0.0655, 0.0462, -0.2269, 0.0494, 0.0658],
[ 0.0755, -0.3531, 0.1155, 0.0999, 0.0623]]]])
-----------------------------------------
tensor([4.6379], device='cuda:0', grad_fn=<_RNNTBackward>)
tensor([[[[ 0.0885, -0.3869, 0.1152, 0.0947, 0.0885],
[ 0.0492, -0.1081, -0.0603, 0.0419, 0.0774],
[ 0.0290, -0.1254, 0.0351, 0.0291, 0.0322]],
[[ 0.0346, -0.1494, 0.0226, 0.0539, 0.0383],
[ 0.0726, 0.0889, -0.2694, 0.0452, 0.0627],
[ 0.0730, -0.4038, 0.1470, 0.0579, 0.1259]]],
[[[ 0.0749, -0.0353, -0.2435, 0.1090, 0.0948],
[ 0.0529, -0.1224, -0.0426, 0.0533, 0.0587],
[ 0.0290, -0.1383, 0.0383, 0.0461, 0.0249]],
[[ 0.0163, 0.0226, -0.0751, 0.0122, 0.0240],
[ 0.0655, 0.0462, -0.2269, 0.0494, 0.0658],
[ 0.0755, -0.3531, 0.1155, 0.0999, 0.0623]]]], device='cuda:0')
-----------------------------------------
Metadata
Assignees
Labels
No labels