import sockeye.constants as C
import sockeye.encoder
import sockeye.transformer


def test_get_transformer_encoder():
    # The config values below are dummies (e.g. act_type='test_act' and
    # dropout rates > 1.0): the test only exercises the factory function and
    # never builds the actual computation graph.
    conv_config = sockeye.encoder.ConvolutionalEmbeddingConfig(num_embed=6,
                                                               add_positional_encoding=True)
    config = sockeye.transformer.TransformerConfig(model_size=20,
                                                   attention_heads=10,
                                                   feed_forward_num_hidden=30,
                                                   act_type='test_act',
                                                   num_layers=40,
                                                   dropout_attention=1.0,
                                                   dropout_act=2.0,
                                                   dropout_prepost=3.0,
                                                   positional_embedding_type=C.LEARNED_POSITIONAL_EMBEDDING,
                                                   preprocess_sequence='test_pre',
                                                   postprocess_sequence='test_post',
                                                   max_seq_len_source=50,
                                                   max_seq_len_target=60,
                                                   conv_config=conv_config,
                                                   dtype='float16')
    encoder = sockeye.encoder.get_transformer_encoder(config, prefix='test_')
    assert type(encoder) == sockeye.encoder.EncoderSequence
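
# A follow-up check one might add (a sketch only: it assumes EncoderSequence
# delegates get_num_hidden() to its last sub-encoder, so that a transformer
# encoder reports its configured model_size):
def check_transformer_encoder_size(config):
    enc = sockeye.encoder.get_transformer_encoder(config, prefix='test_')
    assert enc.get_num_hidden() == config.model_size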

import argparse
import logging
from abc import ABC
from math import ceil
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Tuple, Union, Type

import mxnet as mx

from . import constants as C
from . import convolution
from . import encoder
from . import layers
from . import rnn
from . import rnn_attention
from . import transformer
from . import utils
from .config import Config

logger = logging.getLogger(__name__)
DecoderConfig = Union['RecurrentDecoderConfig', transformer.TransformerConfig, 'ConvolutionalDecoderConfig']
def get_decoder(config: DecoderConfig, prefix: str = '') -> 'Decoder':
    return Decoder.get_decoder(config, prefix)
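
# Decoder.get_decoder() resolves the concrete decoder class from the type of
# the config object. A minimal sketch of that kind of registry (illustrative
# only, not the exact Sockeye implementation):
_DECODER_REGISTRY = {}  # type: Dict[type, Type['Decoder']]

def register_decoder(config_cls, decoder_cls):
    """Associate a config class with the decoder class it instantiates."""
    _DECODER_REGISTRY[config_cls] = decoder_cls

def lookup_decoder(config, prefix: str = '') -> 'Decoder':
    """Instantiate the decoder registered for type(config)."""
    if type(config) not in _DECODER_REGISTRY:
        raise ValueError("Unsupported decoder configuration %s" % type(config).__name__)
    return _DECODER_REGISTRY[type(config)](config, prefix=prefix)
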
class Decoder(ABC):
"""
Generic decoder interface.
A decoder needs to implement code to decode a target sequence known in advance (decode_sequence),
and code to decode a single word given its decoder state (decode_step).
The latter is typically used for inference graphs in beam search.
For the inference module to be able to keep track of decoder's states
a decoder provides methods to return initial states (init_states), state variables and their shapes.
:param dtype: Data type.
:param encoder_num_hidden: Number of hidden units of the Encoder.
:param max_seq_len_source: Maximum source sequence length.
:param max_seq_len_target: Maximum target sequence length.
:return: The config for the decoder.
"""

def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int,
                          max_seq_len_source: int, max_seq_len_target: int) -> DecoderConfig:
    """
    Create the config for the decoder.

    :param args: Arguments as returned by argparse.
    :param encoder_num_hidden: Number of hidden units of the Encoder.
    :param max_seq_len_source: Maximum source sequence length.
    :param max_seq_len_target: Maximum target sequence length.
    :return: The config for the decoder.
    """
    _, decoder_num_layers = args.num_layers
    _, num_embed_target = args.num_embed
    config_decoder = None  # type: Optional[Config]
    if args.decoder == C.TRANSFORMER_TYPE:
        if args.decoder_only:
            raise NotImplementedError()
        _, decoder_transformer_preprocess = args.transformer_preprocess
        _, decoder_transformer_postprocess = args.transformer_postprocess
        config_decoder = transformer.TransformerConfig(
            model_size=args.transformer_model_size[1],
            attention_heads=args.transformer_attention_heads[1],
            feed_forward_num_hidden=args.transformer_feed_forward_num_hidden[1],
            act_type=args.transformer_activation_type,
            num_layers=decoder_num_layers,
            dropout_attention=args.transformer_dropout_attention,
            dropout_act=args.transformer_dropout_act,
            dropout_prepost=args.transformer_dropout_prepost,
            positional_embedding_type=args.transformer_positional_embedding_type,
            preprocess_sequence=decoder_transformer_preprocess,
            postprocess_sequence=decoder_transformer_postprocess,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            conv_config=None,
            lhuc=args.lhuc is not None and (C.LHUC_DECODER in args.lhuc or C.LHUC_ALL in args.lhuc))
    # (branches for the RNN and CNN decoders are elided in this excerpt)
    return config_decoder

# Fragment from sockeye.encoder (apparently ConvolutionalEmbeddingEncoder);
# it picks up at the end of the encode() method:
        return seg_embedding, encoded_data_length, encoded_seq_len

    def get_num_hidden(self) -> int:
        """
        Return the representation size of this encoder.
        """
        return self.output_dim

    def get_encoded_seq_len(self, seq_len: int) -> int:
        """
        Returns the size of the encoded sequence.
        """
        return int(ceil(seq_len / self.pool_stride))
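
# Worked example for get_encoded_seq_len(): with pool_stride=3, a source
# sequence of 100 tokens is pooled down to ceil(100 / 3) = 34 encoded
# positions (Python 3 true division feeding math.ceil):
from math import ceil
assert int(ceil(100 / 3)) == 34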

EncoderConfig = Union[RecurrentEncoderConfig, transformer.TransformerConfig, ConvolutionalEncoderConfig,
                      EmptyEncoderConfig]
if ImageEncoderConfig is not None:
    EncoderConfig = Union[EncoderConfig, ImageEncoderConfig]  # type: ignore
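
# ImageEncoderConfig can be None because the image-captioning extras are
# optional; a sketch of the import guard that typically binds it (an
# assumption, not the verbatim Sockeye source):
try:
    from .image_captioning.encoder import ImageLoadedCnnEncoderConfig as ImageEncoderConfig
except ImportError:
    ImageEncoderConfig = None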

# Fragment from an encoder-config builder; it starts mid-branch, choosing
# the hidden size that the empty (pass-through) encoder must report:
            encoder_num_hidden = args.cnn_num_hidden
        else:
            encoder_num_hidden = args.rnn_num_hidden
        config_encoder = encoder.EmptyEncoderConfig(num_embed=num_embed_source,
                                                    num_hidden=encoder_num_hidden)
    elif args.encoder in (C.TRANSFORMER_TYPE, C.TRANSFORMER_WITH_CONV_EMBED_TYPE):
        encoder_transformer_preprocess, _ = args.transformer_preprocess
        encoder_transformer_postprocess, _ = args.transformer_postprocess
        encoder_transformer_model_size = args.transformer_model_size[0]
        total_source_factor_size = sum(args.source_factors_num_embed)
        if args.source_factors_combine == C.SOURCE_FACTORS_COMBINE_CONCAT and total_source_factor_size > 0:
            logger.info("Encoder transformer-model-size adjusted to account for source factor embeddings: %d -> %d",
                        encoder_transformer_model_size, num_embed_source + total_source_factor_size)
            encoder_transformer_model_size = num_embed_source + total_source_factor_size
        config_encoder = transformer.TransformerConfig(
            model_size=encoder_transformer_model_size,
            attention_heads=args.transformer_attention_heads[0],
            feed_forward_num_hidden=args.transformer_feed_forward_num_hidden[0],
            act_type=args.transformer_activation_type,
            num_layers=encoder_num_layers,
            dropout_attention=args.transformer_dropout_attention,
            dropout_act=args.transformer_dropout_act,
            dropout_prepost=args.transformer_dropout_prepost,
            positional_embedding_type=args.transformer_positional_embedding_type,
            preprocess_sequence=encoder_transformer_preprocess,
            postprocess_sequence=encoder_transformer_postprocess,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            conv_config=config_conv,
            lhuc=args.lhuc is not None and (C.LHUC_ENCODER in args.lhuc or C.LHUC_ALL in args.lhuc))
        encoder_num_hidden = encoder_transformer_model_size
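
# Worked example of the model-size adjustment above (illustrative numbers):
# concatenating two 8-dimensional source-factor embeddings onto 512-dimensional
# word embeddings pushes the encoder's transformer model size to 512 + 8 + 8 = 528.
num_embed_source_example = 512
source_factors_num_embed_example = [8, 8]
assert num_embed_source_example + sum(source_factors_num_embed_example) == 528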

def get_encoder(config: 'EncoderConfig', prefix: str = '') -> 'Encoder':
    if isinstance(config, RecurrentEncoderConfig):
        return get_recurrent_encoder(config, prefix)
    elif isinstance(config, transformer.TransformerConfig):
        return get_transformer_encoder(config, prefix)
    elif isinstance(config, ConvolutionalEncoderConfig):
        return get_convolutional_encoder(config, prefix)
    elif isinstance(config, EmptyEncoderConfig):
        return EncoderSequence([EmptyEncoder(config)], config.dtype)
    else:
        # Image encoders live in an optional subpackage, so the import is
        # deferred until an unknown config type is seen.
        from .image_captioning.encoder import ImageLoadedCnnEncoderConfig, \
            get_image_cnn_encoder
        if isinstance(config, ImageLoadedCnnEncoderConfig):
            return get_image_cnn_encoder(config)
        else:
            raise ValueError("Unsupported encoder configuration")
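
# Usage sketch: dispatch is purely on the config class, so a single call site
# covers every encoder family (assumes a transformer config_encoder built as
# in the excerpt above, and the get_num_hidden() behavior noted earlier):
encoder_instance = get_encoder(config_encoder, prefix='encoder_')
assert encoder_instance.get_num_hidden() == config_encoder.model_size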