Skip to content

update pit loss and spk loss calculation #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eend/pytorch_backend/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def save_spkv_lab(args):
vec = outputs[i+1][0].cpu().detach().numpy()
lab = chunk_data[2][sigma[i]]
all_outputs.append(vec)
all_labels.append(lab)
all_labels.append(lab.item())

orgdata_all_n_speakers = data_set.get_allnspk()
# Generate spkidx_tbl to convert speaker ID
Expand Down
185 changes: 72 additions & 113 deletions eend/pytorch_backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,81 +18,39 @@
B: mini-batch size
"""


def pit_loss(pred, label):
"""
Permutation-invariant training (PIT) cross entropy loss function.

def batch_pit_loss(outputs, labels, ilens=None):
""" calculate the batch pit loss parallelly
Args:
pred: (T,C)-shaped pre-activation values
label: (T,C)-shaped labels in {0,1}

outputs (torch.Tensor): B x T x C
labels (torch.Tensor): B x T x C
ilens (torch.Tensor): B
Returns:
min_loss: (1,)-shape mean cross entropy
label_perms[min_index]: permutated labels
sigma: permutation
perm (torch.Tensor): permutation for outputs (Batch, num_spk)
loss
"""

device = pred.device
T = len(label)
C = label.shape[-1]
label_perms_indices = [
list(p) for p in permutations(range(C))]
P = len(label_perms_indices)
perm_mat = torch.zeros(P, T, C, C).to(device)

for i, p in enumerate(label_perms_indices):
perm_mat[i, :, torch.arange(label.shape[-1]), p] = 1

x = torch.unsqueeze(torch.unsqueeze(label, 0), -1).to(device)
y = torch.arange(P * T * C).view(P, T, C, 1).to(device)

broadcast_label = torch.broadcast_tensors(x, y)[0]
allperm_label = torch.matmul(
perm_mat, broadcast_label
).squeeze(-1)

x = torch.unsqueeze(pred, 0)
y = torch.arange(P * T).view(P, T, 1)
broadcast_pred = torch.broadcast_tensors(x, y)[0]

# broadcast_pred: (P, T, C)
# allperm_label: (P, T, C)
losses = F.binary_cross_entropy_with_logits(
broadcast_pred,
allperm_label,
reduction='none')
mean_losses = torch.mean(torch.mean(losses, dim=1), dim=1)
min_loss = torch.min(mean_losses) * len(label)
min_index = torch.argmin(mean_losses)
sigma = list(permutations(range(label.shape[-1])))[min_index]

return min_loss, allperm_label[min_index], sigma


def batch_pit_loss(ys, ts, ilens=None):
"""
PIT loss over mini-batch.

Args:
ys: B-length list of predictions
ts: B-length list of labels

Returns:
loss: (1,)-shape mean cross entropy over mini-batch
sigmas: B-length list of permutation
"""
if ilens is None:
ilens = [t.shape[0] for t in ts]
mask, scale = 1.0, outputs.shape[1]
else:
scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
mask = outputs.new_zeros(outputs.size()[:-1])
for i, chunk_len in enumerate(ilens):
mask[i, :chunk_len] += 1.0
mask /= scale

def loss_func(output, label):
return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1)

loss_w_labels_w_sigmas = [pit_loss(y[:ilen, :], t[:ilen, :])
for (y, t, ilen) in zip(ys, ts, ilens)]
losses, _, sigmas = zip(*loss_w_labels_w_sigmas)
loss = torch.sum(torch.stack(losses))
n_frames = np.sum([ilen for ilen in ilens])
loss = loss / n_frames
def pair_loss(outputs, labels, permutation):
return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation)

return loss, sigmas
device = outputs.device
num_spk = outputs.shape[-1]
all_permutations = list(permutations(range(num_spk)))
losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
loss, perm = torch.min(losses, dim=1)
perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
return torch.mean(loss), perm


def fix_state_dict(state_dict):
Expand Down Expand Up @@ -122,15 +80,13 @@ def forward(self, inputs):
def review(self, inputs, outputs):
ys = outputs["prediction"]
spksvecs = outputs["spksvecs"]
spksvecs = list(zip(*spksvecs))
ts = inputs[1]
ss = inputs[2]
ns = inputs[3]
ilens = inputs[4]
ilens = [ilen.item() for ilen in ilens]

pit_loss, sigmas = batch_pit_loss(ys, ts, ilens)
ss = [[i.item() for i in s] for s in ss]
if pit_loss.requires_grad:
spk_loss = self.batch_spk_loss(
spksvecs, ys, ts, ss, sigmas, ns, ilens)
Expand All @@ -144,49 +100,52 @@ def review(self, inputs, outputs):
scalars={'alpha': alpha})

def batch_spk_loss(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
spksvecs = [[spkvec[:ilen] for spkvec in spksvec]
for spksvec, ilen in zip(spksvecs, ilens)]
loss = torch.stack(
[self.spk_loss(spksvec, y[:ilen], t[:ilen], s, sigma, n)
for(spksvec, y, t, s, sigma, n, ilen)
in zip(spksvecs, ys, ts, ss, sigmas, ns, ilens)])
loss = torch.mean(loss)

return loss

def spk_loss(self, spksvec, y, t, s, sigma, n):
embeds = self.net.embed(n).squeeze()
z = torch.sigmoid(y.transpose(1, 0))

losses = []
for spkid, spkvec in enumerate(spksvec):
norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
# Normalize speaker vectors before weighted average
spkvec = torch.mul(
spkvec.transpose(1, 0), norm_spkvec_inv).transpose(1, 0)
wavg_spkvec = torch.mul(
spkvec.transpose(1, 0), z[spkid]).transpose(1, 0)
sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
nmz_wavg_spkvec = torch.unsqueeze(nmz_wavg_spkvec, 0)
norm_embeds_inv = 1.0 / torch.norm(embeds, dim=1)
embeds = torch.mul(
embeds.transpose(1, 0), norm_embeds_inv).transpose(1, 0)
dist = torch.cdist(nmz_wavg_spkvec, embeds)[0]
d = torch.add(
torch.clamp(
self.net.alpha,
min=sys.float_info.epsilon) * torch.pow(dist, 2),
self.net.beta)

round_t = torch.round(t.transpose(1, 0)[sigma[spkid]])
if torch.sum(round_t) > 0:
loss = -F.log_softmax(-d, 0)[s[sigma[spkid]]]
else:
loss = torch.tensor(0.0).to(y.device)
losses.append(loss)

return torch.mean(torch.stack(losses))
'''
spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
ys (torch.Tensor): B x T x 3
ts (torch.Tensor): B x T x 3
ss (torch.Tensor): B x 3
sigmas (torch.Tensor): B x 3
ns (torch.Tensor): B x total_spk_num x 1
ilens (List): B
'''
chunk_spk_num = len(spksvecs) # 3

len_mask = ys.new_zeros((ys.size()[:-1])) # B x T
for i, len_val in enumerate(ilens):
len_mask[i,:len_val] += 1.0
ts = ts * len_mask.unsqueeze(-1)
len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T

spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim
# Normalize speaker vectors before weighted average
spk_vecs = F.normalize(spk_vecs, dim=-1)

ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T
ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1

weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim
weight_spk_vec *= len_mask.unsqueeze(-1)
sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim
norm_spk_vec = F.normalize(sum_spk_vec, dim=1)

embeds = F.normalize(self.net.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim
dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num
logits = -1.0 * torch.add(torch.clamp(self.net.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.net.beta)
label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3
label[label==-1] = 0
valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B
valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3

valid_spk_loss_num = torch.sum(valid_spk_mask).item()
if valid_spk_loss_num > 0:
# uncomment the line below, nly divide the number of samples contain speakers
# loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
# uncomment the line below, the loss result is same as batch_spk_loss
loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0]
return torch.sum(loss)
else:
return torch.tensor(0.0).to(ys.device)


class TransformerDiarization(nn.Module):
Expand Down