"""Returns random values for initializing weights of the given `shape`."""
fan_in, fan_out = _GetFans(shape, out_dim, in_dim)
gain = scale
if mode == 'fan_in':
gain /= fan_in
elif mode == 'fan_out':
gain /= fan_out
elif mode == 'fan_avg':
gain /= (fan_in + fan_out) / 2
if distribution == 'truncated_normal':
# constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
stddev = np.sqrt(gain) / .87962566103423978
new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
return new_weights.astype('float32')
elif distribution == 'normal':
new_weights = random.normal(rng, shape) * np.sqrt(gain)
return new_weights.astype('float32')
elif distribution == 'uniform':
lim = np.sqrt(3. * gain)
return random.uniform(rng, shape, np.float32, -lim, lim)
else:
raise ValueError('invalid distribution for ScaleInitializer')
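The `_GetFans` helper is not shown above. Below is a minimal sketch of what such a helper plausibly computes: the in/out axis sizes of the weight shape, each scaled by the product of any remaining "receptive field" axes. All names and details here are assumptions, not the library's implementation.

import functools
import operator

def _get_fans_sketch(shape, out_dim=-1, in_dim=-2):
  """Hypothetical stand-in for _GetFans: (fan_in, fan_out) for a shape."""
  if len(shape) < 2:
    return (shape[0], shape[0]) if shape else (1., 1.)
  in_axis, out_axis = in_dim % len(shape), out_dim % len(shape)
  receptive = functools.reduce(
      operator.mul,
      (d for i, d in enumerate(shape) if i not in (in_axis, out_axis)), 1)
  return shape[in_axis] * receptive, shape[out_axis] * receptive

# e.g. _get_fans_sketch((256, 512)) -> (256, 512)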
# Multifactor schedule: `factors`, `constant`, `warmup_steps`, `decay_factor`,
# `steps_per_decay` and `steps_per_cycle` come from the enclosing scope.
def learning_rate(step):  # pylint: disable=invalid-name
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= np.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      ret *= np.sqrt(warmup_steps)
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor ** (step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = np.maximum(
          0.0, (step - warmup_steps) / float(steps_per_cycle))
      ret *= np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  ret = np.asarray(ret, dtype=np.float32)
  return {'learning_rate': ret}
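For intuition, here is a self-contained sketch of the common 'constant * linear_warmup * rsqrt_decay' combination of these factors (the factor string and hyperparameter values are illustrative assumptions, not values from the source):

import numpy as np

constant, warmup_steps = 0.1, 1000

def transformer_lr(step):
  warm = np.minimum(1.0, step / warmup_steps)            # linear_warmup
  decay = 1.0 / np.sqrt(np.maximum(step, warmup_steps))  # rsqrt_decay
  return constant * warm * decay                         # constant factor

print(transformer_lr(500))   # mid-warmup: 0.1 * 0.5 / sqrt(1000)
print(transformer_lr(4000))  # decay phase: 0.1 / sqrt(4000)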
# Each chunk also attends to the preceding chunk: neighbors in the sorted
# order are likely to fall in the same bucket, so this increases the chances
# of attending to relevant items.
# TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
def look_one_back(x):
  """Concatenates each chunk with the one before it (wrapping at the start)."""
  if len(x.shape) == 2:
    x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
  else:
    x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
  return np.concatenate([x, x_extra], axis=1)
bk = look_one_back(bk)
bv = look_one_back(bv)
bkv_t = look_one_back(bkv_t)
bkv_buckets = look_one_back(bkv_buckets)
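A toy check of `look_one_back` with plain numpy (the shape is illustrative): each chunk is concatenated with the previous one, wrapping around at the start.

import numpy as np

chunks = np.array([[0, 1], [2, 3], [4, 5]])  # 3 chunks of 2 items each
print(look_one_back(chunks))
# x_extra rolls the chunks by one: [[4, 5], [0, 1], [2, 3]], so this prints
#   [[0, 1, 4, 5],
#    [2, 3, 0, 1],
#    [4, 5, 2, 3]]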
# Dot-product attention.
dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])
# Causal masking: queries may not attend to positions later than themselves.
mask = jax.lax.convert_element_type(
    jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
    np.float32)
dots = dots - 1e9 * mask
# Mask out attention to self except when no other targets are available.
self_mask = jax.lax.convert_element_type(
    jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
    np.float32)
dots = dots - 1e5 * self_mask
# Mask out attention to other hash buckets.
if not self._attend_across_buckets:
  # The truncated call is completed here following the same pattern as the
  # masks above; bq_buckets is assumed to be defined alongside bkv_buckets.
  bucket_mask = jax.lax.convert_element_type(
      jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
      np.float32)
  dots = dots - 1e7 * bucket_mask
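Why subtracting a large constant implements masking: after softmax, a logit shifted by -1e9 gets essentially zero weight. A minimal numpy sketch:

import numpy as np

logits = np.array([2.0, 1.0]) - np.array([0.0, 1e9])  # mask second position
p = np.exp(logits - logits.max())
p /= p.sum()  # ~[1.0, 0.0]: the masked position is effectively ignored

The self-mask uses the milder 1e5 penalty so a position can still fall back to attending to itself when the harder 1e9 masks rule out every other target.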
from jax.tree_util import tree_flatten

def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves, _ = tree_flatten(tree)  # tree_flatten returns (leaves, treedef)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))
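Usage sketch for l2_norm on a toy parameter pytree (names and values are illustrative):

import jax.numpy as np

params = {'b': np.zeros(3), 'w': np.ones((2, 3))}
print(l2_norm(params))  # sqrt(0 + 6) ~ 2.449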
def _update_diagonal(self, grads, weights, m, v, opt_params):
  """Diagonal adaptive update: precondition by accumulated squared grads."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  v[0] += grads * grads  # AdaGrad-style second-moment accumulator
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_grads = preconditioner * grads
  m = (1 - momentum) * preconditioned_grads + momentum * m
  weights = weights - (learning_rate * m).astype(weights.dtype)
  return weights, (m, v)
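One step of this update worked by hand with plain numpy, assuming zero initial momentum and a zero-initialized accumulator (illustrative values only):

import numpy as np

grads = np.array([0.5, -0.5])
v0 = grads * grads                 # accumulator after one step: [0.25, 0.25]
precond = 1.0 / np.sqrt(v0)        # [2.0, 2.0]
m = (1 - 0.9) * (precond * grads)  # momentum = 0.9, m starts at 0 -> [0.1, -0.1]
# weights -= learning_rate * m: the first step moves each weight by
# learning_rate * 0.1 * sign(grad); the preconditioner normalizes away
# the gradient's magnitude.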