Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
prior_mean: Mean of the prior.
prior_stddev: Standard deviation of the prior.
l2: Amount of L2 regularization to apply to the deterministic weights.
clip_norm: Gradient clipping norm value for per-time step clipping within
the LSTM cell, and for clipping of all aggregated gradients.
return_sequences: Whether or not to return outputs at each time step from
the LSTM, rather than just the final time step.
"""
super().__init__()
self.hidden_layer_dim = hidden_layer_dim
# 1. RNN layer.
cells = []
for _ in range(num_rnn_layers):
# TODO(dusenberrymw): Determine if a grad-clipped version is needed.
lstm_cell = rank1_bnn_layers.LSTMCellRank1(
rnn_dim,
alpha_initializer=rank1_utils.make_initializer(
alpha_initializer, random_sign_init, dropout_rate),
gamma_initializer=rank1_utils.make_initializer(
gamma_initializer, random_sign_init, dropout_rate),
recurrent_alpha_initializer=rank1_utils.make_initializer(
alpha_initializer, random_sign_init, dropout_rate),
recurrent_gamma_initializer=rank1_utils.make_initializer(
gamma_initializer, random_sign_init, dropout_rate),
alpha_regularizer=rank1_utils.make_regularizer(
alpha_regularizer, prior_mean, prior_stddev),
gamma_regularizer=rank1_utils.make_regularizer(
gamma_regularizer, prior_mean, prior_stddev),
recurrent_alpha_regularizer=rank1_utils.make_regularizer(
alpha_regularizer, prior_mean, prior_stddev),
recurrent_gamma_regularizer=rank1_utils.make_regularizer(