def __init__(self, n_head, n_feat, dropout_rate):
    super(MultiHeadedAttention, self).__init__()
    assert n_feat % n_head == 0
    # We assume d_v always equals d_k
    self.d_k = n_feat // n_head
    self.h = n_head
    self.linear_q = nn.Linear(n_feat, n_feat)
    self.linear_k = nn.Linear(n_feat, n_feat)
    self.linear_v = nn.Linear(n_feat, n_feat)
    self.linear_out = nn.Linear(n_feat, n_feat)
    self.attn = None
    self.dropout = nn.Dropout(p=dropout_rate)
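# --- Usage sketch (illustrative, not part of the original file) --------------
# Assumes the ESPnet-style forward(query, key, value, mask) signature and the
# (batch, time, n_feat) layout implied by the constructor above; after a call,
# the attention weights are cached on the module as `attn` with shape (B, H, Tq, Tk).
import torch

attn_layer = MultiHeadedAttention(2, 10, 0.0)   # 2 heads, 10-dim features
query = torch.randn(4, 7, 10)                   # (B, Tq, n_feat)
memory = torch.randn(4, 9, 10)                  # (B, Tk, n_feat)
mask = torch.ones(4, 1, 9, dtype=torch.bool)    # True = attendable position
out = attn_layer(query, memory, memory, mask)   # (B, Tq, n_feat)
print(out.shape, attn_layer.attn.shape)         # (4, 7, 10) and (4, 2, 7, 9)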
def __init__(self):
    super().__init__()
    self.att1 = MultiHeadedAttention(2, 10, 0.0)
    self.att2 = AttAdd(10, 20, 15)
    self.desired = defaultdict(list)
"""E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:param torch.Tensor ys_pad_src: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
self.forward(xs_pad, ilens, ys_pad, ys_pad_src)
ret = dict()
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention) and m.attn is not None: # skip MHA for submodules
ret[name] = m.attn.cpu().numpy()
return ret
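# --- Illustrative caller sketch ----------------------------------------------
# The dict returned above maps dotted module names (e.g. "encoder.encoders.0.self_attn")
# to numpy arrays of shape (B, H, Lmax, Tmax). `st_model`, `xs_pad`, `ilens`,
# `ys_pad`, and `ys_pad_src` below are placeholders, not names from the source.
import matplotlib.pyplot as plt

att_ws = st_model.calculate_all_attentions(xs_pad, ilens, ys_pad, ys_pad_src)
for name, weights in att_ws.items():
    plt.imshow(weights[0][0], aspect="auto")       # first utterance, first head
    plt.title(name)
    plt.savefig(name.replace(".", "_") + ".png")
    plt.clf()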
lambda: EncoderLayer(
    attention_dim,
    MultiHeadedAttention(attention_heads, attention_dim, attention_dropout_rate),
    positionwise_layer(*positionwise_layer_args),
    dropout_rate,
    normalize_before,
    concat_after,
)
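# --- Illustrative sketch of how the factory is consumed -----------------------
# Suppose the zero-argument lambda above were bound to a name, e.g.
# `build_layer = lambda: EncoderLayer(...)`. Calling it once per block gives
# every layer its own parameters (ESPnet wraps this pattern in a repeat()
# helper; `build_layer`, `num_blocks`, `xs`, and `masks` are placeholders).
import torch

layers = torch.nn.ModuleList([build_layer() for _ in range(num_blocks)])
for layer in layers:
    xs, masks = layer(xs, masks)  # each EncoderLayer returns (output, mask)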
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
    """E2E attention calculation.
    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    with torch.no_grad():
        self.forward(xs_pad, ilens, ys_pad)
    ret = dict()
    for name, m in self.named_modules():
        if isinstance(m, MultiHeadedAttention):
            ret[name] = m.attn.cpu().numpy()
    return ret
def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None):
    """Calculate all of the attention weights.

    Returns:
        dict: Dict of attention weights and outputs.
    """
    with torch.no_grad():
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]
        ys = ys[:, :max(olens)]
        # forward propagation
        outs = self._forward(xs, ilens, ys, olens, spembs=spembs, is_inference=False)[0]
    att_ws_dict = dict()
    for name, m in self.named_modules():
        if isinstance(m, MultiHeadedAttention):
            attn = m.attn.cpu().numpy()
            if "encoder" in name:
                attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
            elif "decoder" in name:
                if "src" in name:
                    attn = [a[:, :ol, :il] for a, il, ol in zip(attn, ilens.tolist(), olens.tolist())]
                elif "self" in name:
                    attn = [a[:, :l, :l] for a, l in zip(attn, olens.tolist())]
                else:
                    logging.warning("unknown attention module: " + name)
            else:
                logging.warning("unknown attention module: " + name)
            att_ws_dict[name] = attn
    att_ws_dict["predicted_fbank"] = [m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist())]
    return att_ws_dict
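# --- Standalone sketch of the per-utterance trimming used above ---------------
# Padded attention maps of shape (B, H, Lmax, Tmax) are cut down to each
# utterance's true (olen, ilen) size. Function and variable names here are
# illustrative, not from the source.
import numpy as np

def trim_attentions(attn, in_lens, out_lens):
    """attn: (B, H, Lmax, Tmax) array -> list of (H, olen, ilen) arrays."""
    return [a[:, :ol, :il] for a, il, ol in zip(attn, in_lens, out_lens)]

attn = np.random.rand(2, 4, 6, 8)                       # B=2, H=4, Lmax=6, Tmax=8
print([t.shape for t in trim_attentions(attn, [8, 5], [6, 3])])
# [(4, 6, 8), (4, 3, 5)]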
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
    """E2E attention calculation.
    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    with torch.no_grad():
        self.forward(xs_pad, ilens, ys_pad)
    ret = dict()
    for name, m in self.named_modules():
        if isinstance(m, MultiHeadedAttention):
            ret[name] = m.attn.cpu().numpy()
    return ret
from abc import ABC

import torch

from espnet.nets.pytorch_backend.rnn.attentions import (
    AttAdd, AttCov, AttCovLoc, AttDot, AttForward, AttForwardTA, AttLoc,
    AttLoc2D, AttLocRec, AttMultiHeadAdd, AttMultiHeadDot, AttMultiHeadLoc,
    AttMultiHeadMultiResLoc, NoAtt,
)
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,
)


class AbsAttention(torch.nn.Module, ABC):
    """A marker class to represent an "Attention" object.
    See also: calculate_all_attentions()
    """


# TODO(kamo): Using a tricky way such as register() to keep espnet/ as it is.
# Each class should inherit the abs class originally.
AbsAttention.register(MultiHeadedAttention)
AbsAttention.register(NoAtt)
AbsAttention.register(AttDot)
AbsAttention.register(AttAdd)
AbsAttention.register(AttLoc)
AbsAttention.register(AttCov)
AbsAttention.register(AttLoc2D)
AbsAttention.register(AttLocRec)
AbsAttention.register(AttCovLoc)
AbsAttention.register(AttMultiHeadDot)
AbsAttention.register(AttMultiHeadAdd)
AbsAttention.register(AttMultiHeadLoc)
AbsAttention.register(AttMultiHeadMultiResLoc)
AbsAttention.register(AttForward)
AbsAttention.register(AttForwardTA)
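# --- Illustrative check --------------------------------------------------------
# ABC.register() adds the espnet classes as *virtual* subclasses, so
# isinstance()/issubclass() report them as AbsAttention without changing their
# inheritance (argument values below are arbitrary).
mha = MultiHeadedAttention(4, 256, 0.1)
print(isinstance(mha, AbsAttention))      # True
print(issubclass(AttAdd, AbsAttention))   # True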
# update index
idx += 1

# calculate output and stop prob at idx-th step
y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
z, z_cache = self.decoder.forward_one_step(ys, y_masks, hs, cache=z_cache)  # (B, adim)
outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]  # [(r, odim), ...]
probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

# update next inputs
ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1)  # (1, idx + 1, odim)

# get attention weights
att_ws_ = []
for name, m in self.named_modules():
    if isinstance(m, MultiHeadedAttention) and "src" in name:
        att_ws_ += [m.attn[0, :, -1].unsqueeze(1)]  # [(#heads, 1, T), ...]
if idx == 1:
    att_ws = att_ws_
else:
    # [(#heads, l, T), ...]
    att_ws = [torch.cat([att_w, att_w_], dim=1) for att_w, att_w_ in zip(att_ws, att_ws_)]

# check whether to finish generation
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
    # check minimum length
    if idx < minlen:
        continue
    outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
    if self.postnet is not None:
        outs = outs + self.postnet(outs)  # (1, odim, L)
    outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
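# --- Standalone sketch of the causal mask --------------------------------------
# The y_masks built by subsequent_mask(idx) above correspond to a lower-triangular
# causal mask, which can be reproduced with plain torch; `causal_mask` is an
# illustrative name, not an espnet function.
import torch

def causal_mask(size):
    """(size, size) boolean mask where step i may attend to steps <= i."""
    return torch.tril(torch.ones(size, size, dtype=torch.bool))

print(causal_mask(3))
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])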
if not skip_output:
    before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
    if self.postnet is None:
        after_outs = before_outs
    else:
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)

# modify mod part of output lengths due to reduction factor > 1
if self.reduction_factor > 1:
    # the decoder runs on reduced-length sequences, so keep those lengths for trimming below
    olens_in = olens.new([olen // self.reduction_factor for olen in olens])
    olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
else:
    olens_in = olens

# store into dict
att_ws_dict = dict()
if keep_tensor:
    for name, m in self.named_modules():
        if isinstance(m, MultiHeadedAttention):
            att_ws_dict[name] = m.attn
    if not skip_output:
        att_ws_dict["before_postnet_fbank"] = before_outs
        att_ws_dict["after_postnet_fbank"] = after_outs
else:
    for name, m in self.named_modules():
        if isinstance(m, MultiHeadedAttention):
            attn = m.attn.cpu().numpy()
            if "encoder" in name:
                attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
            elif "decoder" in name:
                if "src" in name:
                    attn = [a[:, :ol, :il] for a, il, ol in zip(attn, ilens.tolist(), olens_in.tolist())]
                elif "self" in name:
                    attn = [a[:, :l, :l] for a, l in zip(attn, olens_in.tolist())]
                else:
                    logging.warning("unknown attention module: " + name)
            else:
                logging.warning("unknown attention module: " + name)
            att_ws_dict[name] = attn