From 79c7fce6bed30896d9339f4b56d776ee90720615 Mon Sep 17 00:00:00 2001 From: v-zhenchen Date: Tue, 21 Sep 2021 23:36:33 -0700 Subject: [PATCH] update pit loss and spk loss calculation, which can speed it up by dozens of times --- eend/pytorch_backend/infer.py | 2 +- eend/pytorch_backend/models.py | 185 +++++++++++++-------------------- 2 files changed, 73 insertions(+), 114 deletions(-) diff --git a/eend/pytorch_backend/infer.py b/eend/pytorch_backend/infer.py index d7d694a..2af2934 100755 --- a/eend/pytorch_backend/infer.py +++ b/eend/pytorch_backend/infer.py @@ -352,7 +352,7 @@ def save_spkv_lab(args): vec = outputs[i+1][0].cpu().detach().numpy() lab = chunk_data[2][sigma[i]] all_outputs.append(vec) - all_labels.append(lab) + all_labels.append(lab.item()) orgdata_all_n_speakers = data_set.get_allnspk() # Generate spkidx_tbl to convert speaker ID diff --git a/eend/pytorch_backend/models.py b/eend/pytorch_backend/models.py index a89627b..2a97050 100644 --- a/eend/pytorch_backend/models.py +++ b/eend/pytorch_backend/models.py @@ -18,81 +18,39 @@ B: mini-batch size """ - -def pit_loss(pred, label): - """ - Permutation-invariant training (PIT) cross entropy loss function. 
- +def batch_pit_loss(outputs, labels, ilens=None): + """ calculate the batch pit loss parallelly Args: - pred: (T,C)-shaped pre-activation values - label: (T,C)-shaped labels in {0,1} - + outputs (torch.Tensor): B x T x C + labels (torch.Tensor): B x T x C + ilens (torch.Tensor): B Returns: - min_loss: (1,)-shape mean cross entropy - label_perms[min_index]: permutated labels - sigma: permutation + perm (torch.Tensor): permutation for outputs (Batch, num_spk) + loss """ - device = pred.device - T = len(label) - C = label.shape[-1] - label_perms_indices = [ - list(p) for p in permutations(range(C))] - P = len(label_perms_indices) - perm_mat = torch.zeros(P, T, C, C).to(device) - - for i, p in enumerate(label_perms_indices): - perm_mat[i, :, torch.arange(label.shape[-1]), p] = 1 - - x = torch.unsqueeze(torch.unsqueeze(label, 0), -1).to(device) - y = torch.arange(P * T * C).view(P, T, C, 1).to(device) - - broadcast_label = torch.broadcast_tensors(x, y)[0] - allperm_label = torch.matmul( - perm_mat, broadcast_label - ).squeeze(-1) - - x = torch.unsqueeze(pred, 0) - y = torch.arange(P * T).view(P, T, 1) - broadcast_pred = torch.broadcast_tensors(x, y)[0] - - # broadcast_pred: (P, T, C) - # allperm_label: (P, T, C) - losses = F.binary_cross_entropy_with_logits( - broadcast_pred, - allperm_label, - reduction='none') - mean_losses = torch.mean(torch.mean(losses, dim=1), dim=1) - min_loss = torch.min(mean_losses) * len(label) - min_index = torch.argmin(mean_losses) - sigma = list(permutations(range(label.shape[-1])))[min_index] - - return min_loss, allperm_label[min_index], sigma - - -def batch_pit_loss(ys, ts, ilens=None): - """ - PIT loss over mini-batch. 
- - Args: - ys: B-length list of predictions - ts: B-length list of labels - - Returns: - loss: (1,)-shape mean cross entropy over mini-batch - sigmas: B-length list of permutation - """ if ilens is None: - ilens = [t.shape[0] for t in ts] + mask, scale = 1.0, outputs.shape[1] + else: + scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device) + mask = outputs.new_zeros(outputs.size()[:-1]) + for i, chunk_len in enumerate(ilens): + mask[i, :chunk_len] += 1.0 + mask /= scale + + def loss_func(output, label): + return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1) - loss_w_labels_w_sigmas = [pit_loss(y[:ilen, :], t[:ilen, :]) - for (y, t, ilen) in zip(ys, ts, ilens)] - losses, _, sigmas = zip(*loss_w_labels_w_sigmas) - loss = torch.sum(torch.stack(losses)) - n_frames = np.sum([ilen for ilen in ilens]) - loss = loss / n_frames + def pair_loss(outputs, labels, permutation): + return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation) - return loss, sigmas + device = outputs.device + num_spk = outputs.shape[-1] + all_permutations = list(permutations(range(num_spk))) + losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1) + loss, perm = torch.min(losses, dim=1) + perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm) + return torch.mean(loss), perm def fix_state_dict(state_dict): @@ -122,7 +80,6 @@ def forward(self, inputs): def review(self, inputs, outputs): ys = outputs["prediction"] spksvecs = outputs["spksvecs"] - spksvecs = list(zip(*spksvecs)) ts = inputs[1] ss = inputs[2] ns = inputs[3] @@ -130,7 +87,6 @@ def review(self, inputs, outputs): ilens = [ilen.item() for ilen in ilens] pit_loss, sigmas = batch_pit_loss(ys, ts, ilens) - ss = [[i.item() for i in s] for s in ss] if pit_loss.requires_grad: spk_loss = self.batch_spk_loss( spksvecs, ys, ts, ss, sigmas, ns, ilens) @@ -144,49 
+100,52 @@ def review(self, inputs, outputs): scalars={'alpha': alpha}) def batch_spk_loss(self, spksvecs, ys, ts, ss, sigmas, ns, ilens): - spksvecs = [[spkvec[:ilen] for spkvec in spksvec] - for spksvec, ilen in zip(spksvecs, ilens)] - loss = torch.stack( - [self.spk_loss(spksvec, y[:ilen], t[:ilen], s, sigma, n) - for(spksvec, y, t, s, sigma, n, ilen) - in zip(spksvecs, ys, ts, ss, sigmas, ns, ilens)]) - loss = torch.mean(loss) - - return loss - - def spk_loss(self, spksvec, y, t, s, sigma, n): - embeds = self.net.embed(n).squeeze() - z = torch.sigmoid(y.transpose(1, 0)) - - losses = [] - for spkid, spkvec in enumerate(spksvec): - norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1) - # Normalize speaker vectors before weighted average - spkvec = torch.mul( - spkvec.transpose(1, 0), norm_spkvec_inv).transpose(1, 0) - wavg_spkvec = torch.mul( - spkvec.transpose(1, 0), z[spkid]).transpose(1, 0) - sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0) - nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec) - nmz_wavg_spkvec = torch.unsqueeze(nmz_wavg_spkvec, 0) - norm_embeds_inv = 1.0 / torch.norm(embeds, dim=1) - embeds = torch.mul( - embeds.transpose(1, 0), norm_embeds_inv).transpose(1, 0) - dist = torch.cdist(nmz_wavg_spkvec, embeds)[0] - d = torch.add( - torch.clamp( - self.net.alpha, - min=sys.float_info.epsilon) * torch.pow(dist, 2), - self.net.beta) - - round_t = torch.round(t.transpose(1, 0)[sigma[spkid]]) - if torch.sum(round_t) > 0: - loss = -F.log_softmax(-d, 0)[s[sigma[spkid]]] - else: - loss = torch.tensor(0.0).to(y.device) - losses.append(loss) - - return torch.mean(torch.stack(losses)) + ''' + spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...] 
+ ys (torch.Tensor): B x T x 3 + ts (torch.Tensor): B x T x 3 + ss (torch.Tensor): B x 3 + sigmas (torch.Tensor): B x 3 + ns (torch.Tensor): B x total_spk_num x 1 + ilens (List): B + ''' + chunk_spk_num = len(spksvecs) # 3 + + len_mask = ys.new_zeros((ys.size()[:-1])) # B x T + for i, len_val in enumerate(ilens): + len_mask[i,:len_val] += 1.0 + ts = ts * len_mask.unsqueeze(-1) + len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T + + spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim + # Normalize speaker vectors before weighted average + spk_vecs = F.normalize(spk_vecs, dim=-1) + + ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T + ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1 + + weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim + weight_spk_vec *= len_mask.unsqueeze(-1) + sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim + norm_spk_vec = F.normalize(sum_spk_vec, dim=1) + + embeds = F.normalize(self.net.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim + dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num + logits = -1.0 * torch.add(torch.clamp(self.net.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.net.beta) + label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3 + label[label==-1] = 0 + valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B + valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3 + + valid_spk_loss_num = torch.sum(valid_spk_mask).item() + if valid_spk_loss_num > 0: + # uncomment the line below to divide only by the number of samples that contain speakers + # loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num + # with the line below, the loss result is the same as batch_spk_loss + loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0] + return torch.sum(loss) + else: + return torch.tensor(0.0).to(ys.device) 
class TransformerDiarization(nn.Module):