# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import constants
import util
import embedding
@tf.custom_gradient
def _clip_gradient(x, clip_norm):
  """Identity op in the forward pass whose gradient is norm-clipped.

  Args:
    x: Tensor passed through unchanged.
    clip_norm: Maximum global norm applied to the incoming gradient of `x`.

  Returns:
    A `(y, grad_fn)` pair as required by `tf.custom_gradient`, where `y == x`.
  """
  def _grad_fn(upstream):
    # `tf.custom_gradient` requires one gradient per input, so a zero
    # gradient is emitted for `clip_norm` alongside the clipped gradient.
    clipped, _ = tf.clip_by_global_norm([upstream], clip_norm)
    return clipped[0], tf.constant(0.)
  return x, _grad_fn
class LSTMCellReparameterizationGradClipped(
    ed.layers.LSTMCellReparameterization):
  """Bayesian LSTM cell that clips gradients at every time step.

  Wraps `ed.layers.LSTMCellReparameterization` so that the gradients
  flowing into the cell inputs and states are clipped by global norm
  before each step's computation.
  """

  def __init__(self, clip_norm, *args, **kwargs):
    """Initializes the cell.

    Args:
      clip_norm: Maximum global norm for the per-time-step gradients.
      *args: Positional args forwarded to `LSTMCellReparameterization`.
      **kwargs: Keyword args forwarded to `LSTMCellReparameterization`.
    """
    self.clip_norm = clip_norm
    super().__init__(*args, **kwargs)

  def call(self, inputs, states, training=None):
    """Runs one step, clipping gradients of inputs and states."""
    with tf.name_scope("clip_gradient"):
      clipped_inputs = _clip_gradient(inputs, self.clip_norm)
      clipped_states = [
          _clip_gradient(state, self.clip_norm) for state in states
      ]
    return super().call(clipped_inputs, clipped_states, training)
class BayesianRNN(tf.keras.Model):
"""Bayesian RNN model."""