@base.layer()
def FastGelu(x, **unused_kwargs):
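  # Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
  # where sqrt(2/pi) ~= 0.7978845608.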
return 0.5 * x * (1 + np.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trax normalization layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as onp

from trax import backend
from trax.backend import numpy as np
from trax.layers import base
from trax.layers import combinators as cb
from trax.layers import core
class BatchNorm(base.Layer):
"""Batch normalization."""
def __init__(self, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
momentum=0.999, mode='train'):
super(BatchNorm, self).__init__()
self._axis = axis
self._epsilon = epsilon
self._center = center
self._scale = scale
self._momentum = momentum
self._mode = mode
def new_weights_and_state(self, input_signature):
"""Helper to initialize batch norm weights."""
axis = self._axis
axis = (axis,) if np.isscalar(axis) else axis
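    # A sketch of how the initialization typically continues (assumed, not the
    # verbatim implementation): drop the normalized axes from the input shape
    # to get the per-feature shape, then build the shift/scale weights and the
    # running statistics kept in the state, e.g.
    #   shape = tuple(d for i, d in enumerate(input_signature.shape)
    #                 if i not in axis)
    #   beta = np.zeros(shape) if self._center else ()
    #   gamma = np.ones(shape) if self._scale else ()
    # and return (weights, state) with the running mean/variance (and a batch
    # counter) carried in the state.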
@base.layer()
def SumPool(x, weights, pool_size=(2, 2), strides=None, padding='VALID', **kw):
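  # Weightless layer: sums over pooling windows of size `pool_size`, moved by
  # `strides`, using the given padding scheme.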
del weights, kw
return backend.sum_pool(x, pool_size=pool_size, strides=strides,
padding=padding)
@base.layer(n_in=2, n_out=1)
def CrossEntropy(x, axis=-1, **kw):
del kw
prediction, target = x
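  # Dotting with a one-hot target picks out prediction[..., target]; with
  # log-probability inputs this is the per-example log-likelihood.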
return np.sum(prediction * core.one_hot(target, prediction.shape[-1]),
axis=axis)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
state=base.EMPTY_STATE, rng=None, **kwargs):
del weights, kwargs
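    # Train/eval modes leave `state` untouched; 'predict' mode delegates to a
    # stateful path that supports incremental (step-by-step) decoding.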
if self._mode in ('train', 'eval'):
output = self._forward_train_eval(inputs, rng)
return (output, state)
else:
assert self._mode == 'predict'
return self._forward_predict(inputs, state, rng)
memory_transform_fn: Optional transformation on the memory before gating.
gate_nonlinearity: Function to use as gate activation. Allows trying
alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after the candidate branch.
      Allows trying alternatives to the traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before the sigmoid gates. Starting off with
      a positive bias generally works best.
Returns:
A model representing a GRU cell with specified transforms.
"""
gate_block = [ # u_t
candidate_transform(),
base.Fn(lambda x: x + sigmoid_bias),
gate_nonlinearity(),
]
reset_block = [ # r_t
candidate_transform(),
base.Fn(lambda x: x + sigmoid_bias), # Want bias to start positive.
gate_nonlinearity(),
]
candidate_block = [
cb.Dup(),
reset_block,
cb.Multiply(), # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
candidate_transform(), # Final projection + tanh to get Ct
candidate_nonlinearity(), # Candidate gate
# Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c),
  ]
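  # These blocks typically get combined (a hedged sketch, not verbatim library
  # code) by branching the incoming state S_{t-1} into the gate path (u_t),
  # the candidate path (C_t) and an optional memory transform, then forming
  # the gated convex combination S_t = u_t * S_{t-1} + (1 - u_t) * C_t with a
  # gating combinator such as cb.Gate().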
def new_weights_and_state(self, input_signature):
d_feature = input_signature.shape[-1]
pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
position = onp.arange(0, self._max_len)[:, onp.newaxis]
div_term = onp.exp(
onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
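    # div_term holds inverse frequencies 1 / 10000^(2i / d_feature), giving
    # geometrically spaced wavelengths across channel pairs.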
pe[:, 0::2] = onp.sin(position * div_term)
pe[:, 1::2] = onp.cos(position * div_term)
pe = pe[onp.newaxis, :, :] # [1, self._max_len, d_feature]
weights = np.array(pe) # These are trainable parameters, initialized above.
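    # In 'predict' mode the integer state tracks the current decoding
    # position; other modes need no state.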
state = 0 if self._mode == 'predict' else base.EMPTY_STATE
return weights, state
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
state=base.EMPTY_STATE, rng=None, **kwargs):
del weights
q, k, v = inputs
if self._mode in ('train', 'eval'):
mask_size = q.shape[-2]
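      # Build a lower-triangular (causal) mask so position i can attend only
      # to positions j <= i.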
# Not all backends define np.tril. However, using onp.tril is inefficient
# in that it creates a large global constant. TODO(kitaev): try to find an
# alternative that works across all backends.
if backend.get_name() == 'jax':
mask = np.tril(
np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
else:
mask = onp.tril(
onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
else:
assert self._mode == 'predict'
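# `_layer_norm_weights` is referenced by the decorator below but not shown in
# this snippet. A minimal sketch of such a helper (an assumption, not the
# verbatim original) creates a per-feature scale initialized to ones and a
# per-feature bias initialized to zeros:
def _layer_norm_weights(input_signature):
  """Helper: create (scale, bias) weights for LayerNorm."""
  features = input_signature.shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)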
@base.layer(new_weights_fn=_layer_norm_weights)
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs): # pylint: disable=invalid-name
(scale, bias) = weights
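  # Normalize over the last (feature) axis, then apply the learned per-feature
  # scale and bias.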
mean = np.mean(x, axis=-1, keepdims=True)
variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
return norm_inputs * scale + bias