How to use the edward2.experimental.rank1_bnns.rank1_bnn_layers.LSTMCellRank1 class in edward2

To help you get started, we’ve selected a few edward2 examples, based on popular ways it is used in public projects.


Excerpt from Google-Health / records-research / model-uncertainty / bayesian_rnn_model.py (GitHub):
      prior_mean: Mean of the prior.
      prior_stddev: Standard deviation of the prior.
      l2: Amount of L2 regularization to apply to the deterministic weights.
      clip_norm: Gradient clipping norm value for per-time step clipping within
        the LSTM cell, and for clipping of all aggregated gradients.
      return_sequences: Whether or not to return outputs at each time step from
        the LSTM, rather than just the final time step.
    """
    super().__init__()
    self.hidden_layer_dim = hidden_layer_dim

    # 1. RNN layer.
    cells = []
    for _ in range(num_rnn_layers):
      # TODO(dusenberrymw): Determine if a grad-clipped version is needed.
      lstm_cell = rank1_bnn_layers.LSTMCellRank1(
          rnn_dim,
          alpha_initializer=rank1_utils.make_initializer(
              alpha_initializer, random_sign_init, dropout_rate),
          gamma_initializer=rank1_utils.make_initializer(
              gamma_initializer, random_sign_init, dropout_rate),
          recurrent_alpha_initializer=rank1_utils.make_initializer(
              alpha_initializer, random_sign_init, dropout_rate),
          recurrent_gamma_initializer=rank1_utils.make_initializer(
              gamma_initializer, random_sign_init, dropout_rate),
          alpha_regularizer=rank1_utils.make_regularizer(
              alpha_regularizer, prior_mean, prior_stddev),
          gamma_regularizer=rank1_utils.make_regularizer(
              gamma_regularizer, prior_mean, prior_stddev),
          recurrent_alpha_regularizer=rank1_utils.make_regularizer(
              alpha_regularizer, prior_mean, prior_stddev),
          recurrent_gamma_regularizer=rank1_utils.make_regularizer(
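Each cell above receives separate alpha/gamma and recurrent_alpha/recurrent_gamma settings. In the rank-1 BNN formulation, a kernel W is perturbed multiplicatively by the outer product of two vectors, W * outer(alpha, gamma), and that product can be applied without building the perturbed matrix as ((x * alpha) @ W) * gamma. The pure-Python sketch below checks that identity and plugs it into one Keras-style LSTM step (gate order: input, forget, candidate, output) for both the input kernel and the recurrent kernel. It is not edward2's implementation: the helper names are hypothetical, alpha and gamma are fixed numbers standing in for one sample from their trainable distributions, and we assume alpha scales each kernel's inputs while gamma scales its 4 * units outputs. Check rank1_bnn_layers.py for exactly where the library applies each factor.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # indexed as [in_dim][out_dim]


def matvec(x: Vector, w: Matrix) -> Vector:
    """Row vector times matrix: returns x @ w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def rank1_dense(x: Vector, w: Matrix, alpha: Vector, gamma: Vector) -> Vector:
    """Hypothetical helper: x @ (w * outer(alpha, gamma)) as ((x * alpha) @ w) * gamma.

    alpha scales the inputs and gamma scales the outputs, so the perturbed
    matrix is never materialized.
    """
    scaled = [xi * ai for xi, ai in zip(x, alpha)]
    return [yj * gj for yj, gj in zip(matvec(scaled, w), gamma)]


def explicit_rank1_dense(x: Vector, w: Matrix, alpha: Vector, gamma: Vector) -> Vector:
    """Reference version that builds w * outer(alpha, gamma) explicitly."""
    w_pert = [[w[i][j] * alpha[i] * gamma[j] for j in range(len(w[0]))]
              for i in range(len(w))]
    return matvec(x, w_pert)


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def rank1_lstm_step(x: Vector, h: Vector, c: Vector,
                    kernel: Matrix, recurrent_kernel: Matrix, bias: Vector,
                    alpha: Vector, gamma: Vector,
                    recurrent_alpha: Vector, recurrent_gamma: Vector
                    ) -> Tuple[Vector, Vector]:
    """One LSTM step with rank-1 perturbed input and recurrent kernels.

    Assumption: alpha has len(x) entries, recurrent_alpha has len(h) entries,
    and both gammas have 4 * units entries (one per gate pre-activation).
    """
    units = len(h)
    z = [a + b + d for a, b, d in zip(
        rank1_dense(x, kernel, alpha, gamma),
        rank1_dense(h, recurrent_kernel, recurrent_alpha, recurrent_gamma),
        bias)]
    # Keras LSTMCell gate order: input, forget, candidate, output.
    i = [sigmoid(v) for v in z[:units]]
    f = [sigmoid(v) for v in z[units:2 * units]]
    g = [math.tanh(v) for v in z[2 * units:3 * units]]
    o = [sigmoid(v) for v in z[3 * units:]]
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


if __name__ == "__main__":
    x = [1.0, 2.0]
    w = [[1.0, 0.5, -1.0], [0.0, 2.0, 1.0]]
    print(rank1_dense(x, w, [0.5, -1.0], [2.0, 1.0, 3.0]))  # [1.0, -3.75, -7.5]
```

The factored form costs one elementwise multiply on each side of an ordinary matrix product, which is why rank-1 perturbations add only O(in_dim + out_dim) parameters per kernel rather than a full second weight matrix.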
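The alpha/gamma regularizers in this snippet are built from prior_mean and prior_stddev via rank1_utils.make_regularizer. Assuming (this is not confirmed by the excerpt) that the resulting regularizer is a KL-divergence penalty between a factorized Normal posterior over each rank-1 vector and a Normal(prior_mean, prior_stddev) prior, each element contributes a closed-form Gaussian KL term. The pure-Python sketch below computes that penalty; normal_kl and rank1_kl_penalty are hypothetical names, not edward2 functions.

```python
import math
from typing import Sequence


def normal_kl(q_mean: float, q_stddev: float,
              p_mean: float, p_stddev: float) -> float:
    """Closed-form KL(N(q_mean, q_stddev**2) || N(p_mean, p_stddev**2))."""
    if q_stddev <= 0 or p_stddev <= 0:
        raise ValueError("standard deviations must be positive")
    return (math.log(p_stddev / q_stddev)
            + (q_stddev ** 2 + (q_mean - p_mean) ** 2) / (2.0 * p_stddev ** 2)
            - 0.5)


def rank1_kl_penalty(q_means: Sequence[float], q_stddevs: Sequence[float],
                     prior_mean: float, prior_stddev: float) -> float:
    """Hypothetical KL penalty for one rank-1 vector such as alpha or gamma.

    Assumption: the posterior is fully factorized, so the total KL is the sum
    of per-element terms against a shared Normal(prior_mean, prior_stddev).
    """
    if len(q_means) != len(q_stddevs):
        raise ValueError("q_means and q_stddevs must have the same length")
    return sum(normal_kl(m, s, prior_mean, prior_stddev)
               for m, s in zip(q_means, q_stddevs))
```

In Keras, whatever a layer's regularizers return is collected in that layer's `losses` and added to the training objective, which is how a KL term like this would enter variational training alongside the data loss.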