    Examples:
        >>> ilens = [5, 2]
        >>> olens = [8, 5]
        >>> self._source_to_target_mask(ilens, olens)
        tensor([[[1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1]],
[[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
in_masks = make_non_pad_mask(ilens) # (B, T_in)
out_masks = make_non_pad_mask(olens) # (B, T_out)
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
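For intuition, the construction above can be reproduced outside the model. The sketch below is a minimal stand-alone version, assuming a simple arange-based stand-in for make_non_pad_mask rather than ESPnet's actual helper:

import torch

def non_pad_mask(lengths):
    # hypothetical stand-in for make_non_pad_mask: True at positions t < length
    lengths = torch.as_tensor(lengths)
    return torch.arange(int(lengths.max())).unsqueeze(0) < lengths.unsqueeze(1)

ilens, olens = [5, 2], [8, 5]
in_masks = non_pad_mask(ilens)   # (B, T_in)
out_masks = non_pad_mask(olens)  # (B, T_out)
xy_masks = out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)
print(xy_masks.shape)  # torch.Size([2, 8, 5])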
def _source_mask(self, ilens):
    """Make masks for self-attention.

    Args:
        ilens (LongTensor or List): Batch of lengths (B,).

    Returns:
        Tensor: Mask tensor for self-attention.
            dtype=torch.uint8 in PyTorch < 1.2
            dtype=torch.bool in PyTorch >= 1.2

    Examples:
>>> ilens = [5, 3]
>>> self._source_mask(ilens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]],
[[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
def _source_mask(self, ilens):
    """Make masks for self-attention.

    Args:
ilens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for self-attention.
            dtype=torch.uint8 in PyTorch < 1.2
            dtype=torch.bool in PyTorch >= 1.2
Examples:
>>> ilens = [5, 3]
>>> self._source_mask(ilens)
        tensor([[[1, 1, 1, 1, 1]],
                [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2)
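Note that this variant returns a (B, 1, T_in) mask rather than the square (B, T_in, T_in) mask built earlier; inside scaled dot-product attention it broadcasts over the query axis and masks the same key positions for every valid query row. A quick check, reusing the arange-based stand-in idea from above:

import torch

lengths = torch.tensor([5, 3])
x_masks = torch.arange(5).unsqueeze(0) < lengths.unsqueeze(1)  # (B, T)
broadcast_mask = x_masks.unsqueeze(-2)                         # (B, 1, T)
square_mask = x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)    # (B, T, T)
# for query rows inside the valid length, the two masks agree
assert torch.equal(square_mask[1][:3], broadcast_mask[1].expand(5, 5)[:3])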
def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of source sequences (B)
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:return: CTC loss value
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
"""
# 1. forward encoder
xs_pad = xs_pad[:, :max(ilens)] # for data parallel
src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad
# 2. forward decoder
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_mask = target_mask(ys_in_pad, self.ignore_id)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
# 3. compute attention loss
loss_att = self.criterion(pred_pad, ys_out_pad)
self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
ignore_label=self.ignore_id)
# TODO(karita) show predicted text
# TODO(karita) calculate these stats
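For reference, add_sos_eos used above builds the decoder input (with <sos> prepended) and the decoder target (with <eos> appended) from the padded labels. Below is a minimal sketch of that behavior, not ESPnet's actual implementation:

import torch
from torch.nn.utils.rnn import pad_sequence

def add_sos_eos_sketch(ys_pad, sos, eos, ignore_id):
    # drop ignore_id padding, then prepend <sos> / append <eos> per sequence
    ys = [y[y != ignore_id] for y in ys_pad]
    ys_in = [torch.cat([y.new([sos]), y]) for y in ys]
    ys_out = [torch.cat([y, y.new([eos])]) for y in ys]
    return (pad_sequence(ys_in, batch_first=True, padding_value=eos),
            pad_sequence(ys_out, batch_first=True, padding_value=ignore_id))

ys_pad = torch.tensor([[3, 4, 5], [6, 7, -1]])  # -1 marks padding (ignore_id)
ys_in, ys_out = add_sos_eos_sketch(ys_pad, sos=1, eos=2, ignore_id=-1)
print(ys_in)   # tensor([[1, 3, 4, 5], [1, 6, 7, 2]])
print(ys_out)  # tensor([[3, 4, 5, 2], [6, 7, 2, -1]])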
def forward(self, after_outs, before_outs, logits, ys, labels, olens):
    """Calculate forward propagation.

    Args:
        after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
logits (Tensor): Batch of stop logits (B, Lmax).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
olens (LongTensor): Batch of the lengths of each target (B,).
Returns:
Tensor: L1 loss value.
Tensor: Mean square error loss value.
Tensor: Binary cross entropy loss value.
"""
# perform masking for padded values
if self.use_masking:
mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
ys = ys.masked_select(mask)
after_outs = after_outs.masked_select(mask)
before_outs = before_outs.masked_select(mask)
labels = labels.masked_select(mask[:, :, 0])
logits = logits.masked_select(mask[:, :, 0])
# calculate loss
l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
bce_loss = F.binary_cross_entropy_with_logits(
logits, labels, pos_weight=torch.tensor(self.bce_pos_weight, device=ys.device))
return l1_loss, mse_loss, bce_loss
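The masked_select calls above flatten predictions and targets down to only the non-padded elements, so the mean-reduced losses ignore padding entirely. A toy illustration with made-up shapes:

import torch
import torch.nn.functional as F

B, Lmax, odim = 2, 4, 3
olens = torch.tensor([4, 2])
ys, outs = torch.randn(B, Lmax, odim), torch.randn(B, Lmax, odim)
mask = (torch.arange(Lmax).unsqueeze(0) < olens.unsqueeze(1)).unsqueeze(-1)
# averages over (4 + 2) * 3 = 18 valid elements instead of B * Lmax * odim = 24
l1 = F.l1_loss(outs.masked_select(mask), ys.masked_select(mask))
print(l1)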
def forward(self, xs, ilens, ys, olens, spembs=None, *args, **kwargs):
    """Calculate forward propagation.

    Args:
        xs (Tensor): Batch of padded character ids (B, Tmax).
        ilens (LongTensor): Batch of lengths of each input (B,).
        ys (Tensor): Batch of padded target features (B, Lmax, odim).
        olens (LongTensor): Batch of the lengths of each target (B,).
        spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

    Returns:
        Tensor: Loss value.

    """
    # remove unnecessary padded part (for multi-gpus)
xs = xs[:, :max(ilens)]
ys = ys[:, :max(olens)]
# forward propagation
before_outs, after_outs, ds, d_outs = self._forward(xs, ilens, ys, olens, spembs=spembs, is_inference=False)
# modify mod part of ground truth (trim olens to a multiple of the reduction factor)
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_olen = max(olens)
ys = ys[:, :max_olen]
# apply mask to remove padded part
if self.use_masking:
in_masks = make_non_pad_mask(ilens).to(xs.device)
d_outs = d_outs.masked_select(in_masks)
ds = ds.masked_select(in_masks)
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
before_outs = before_outs.masked_select(out_masks)
after_outs = after_outs.masked_select(out_masks)
ys = ys.masked_select(out_masks)
# calculate loss
if self.postnet is None:
l1_loss = F.l1_loss(after_outs, ys) # after_outs is the same as before_outs if postnet is None
else:
l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
duration_loss = self.duration_criterion(d_outs, ds)
loss = l1_loss + duration_loss
report_keys = [
{"l1_loss": l1_loss.item()},
    {"duration_loss": duration_loss.item()},
    {"loss": loss.item()},
]

    Examples:
        >>> ilens = [5, 2]
        >>> olens = [8, 5]
        >>> self._source_to_target_mask(ilens, olens)
        tensor([[[1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1]],
[[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
in_masks = make_non_pad_mask(ilens) # (B, T_in)
out_masks = make_non_pad_mask(olens) # (B, T_out)
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
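When reduction_factor > 1 in the forward above, each target length is shortened to the nearest lower multiple so that model outputs and ground truth stay frame-aligned. The arithmetic in isolation:

import torch

reduction_factor = 3
olens = torch.tensor([10, 7, 6])
olens = olens.new([int(olen) - int(olen) % reduction_factor for olen in olens])
print(olens)  # tensor([9, 6, 6])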
def forward(self, cbhg_outs, spcs, olens):
"""Calculate forward propagation.
Args:
cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
olens (LongTensor): Batch of the lengths of each sequence (B,).
Returns:
Tensor: L1 loss value.
Tensor: Mean square error loss value.
"""
# perform masking for padded values
if self.use_masking:
mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
spcs = spcs.masked_select(mask)
cbhg_outs = cbhg_outs.masked_select(mask)
# calculate loss
cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)
return cbhg_l1_loss, cbhg_mse_loss
def _source_to_target_mask(self, ilens, olens):
    """Make masks for encoder-decoder attention.

    Args:
        ilens (LongTensor or List): Batch of lengths of source sequences (B,).
        olens (LongTensor or List): Batch of lengths of target sequences (B,).

    Returns:
        Tensor: Mask tensor for encoder-decoder attention.
            dtype=torch.uint8 in PyTorch < 1.2
            dtype=torch.bool in PyTorch >= 1.2

    Examples:
        >>> ilens = [4, 2]
        >>> olens = [5, 3]
        >>> self._source_to_target_mask(ilens, olens)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)
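As a sanity check, the docstring example above is reproducible with the same arange-based stand-in for make_non_pad_mask used earlier (a hypothetical helper, not ESPnet's):

import torch

def non_pad_mask(lengths):
    lengths = torch.as_tensor(lengths)
    return torch.arange(int(lengths.max())).unsqueeze(0) < lengths.unsqueeze(1)

x_masks = non_pad_mask([4, 2])                        # (B, T_in)
y_masks = non_pad_mask([5, 3])                        # (B, T_out)
mask = x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)  # (B, T_out, T_in)
print(mask.to(torch.uint8))  # matches the example output above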