# Search for a factorization self.n_buckets = factor * (self.n_buckets // factor)
# in which both factors are even; decrement `factor` until one is found.
while factor > 0 and not (
    self.n_buckets % factor == 0 and
    factor % 2 == 0 and
    (self.n_buckets // factor) % 2 == 0):
  factor -= 1
if factor > 2:  # Factor of 2 does not warrant the effort.
  rot_size = factor + (self.n_buckets // factor)
  factor_list = [factor, self.n_buckets // factor]

rotations_shape = (
    vecs.shape[-1],
    self.n_hashes if self._rehash_each_round else 1,
    rot_size // 2)

# Tie the rng to vecs so the rotations are sampled inside the compiled
# computation rather than being constant-folded away.
rng = jax.lax.tie_in(vecs, rng)
rng, subrng = backend.random.split(rng)
random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

# TODO(lukaszkaiser): the dropout mask will be used for all rounds of
# hashing, so it's shared between them. Check if that's what we want.
dropped_vecs = self.drop_for_hash(vecs, subrng)
rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

if self._rehash_each_round:
  if self._factorize_hash and len(factor_list) > 1:
    # We factorized self.n_buckets as the product of factor_list.
    # Get the buckets for them and combine.
    buckets, cur_sum, cur_product = None, 0, 1
    for factor in factor_list:
      # Take this factor's share of the projections and append their
      # negations, giving one score per bucket of this factor.
      rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
      cur_sum += factor // 2
      rv = np.concatenate([rv, -rv], axis=-1)
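
# A minimal, self-contained sketch of the random-rotation LSH idea used above,
# written with plain jax.numpy rather than the Trax classes, and simplified to
# a single non-factorized hash round (function and variable names here are
# illustrative; n_buckets is assumed even): project vectors onto random
# directions, append the negated projections, and take the argmax to get a
# bucket id in [0, n_buckets).
import jax
import jax.numpy as jnp

def lsh_buckets_sketch(vecs, n_buckets, rng):
  """Assigns each row of `vecs` to one of `n_buckets` hash buckets."""
  d = vecs.shape[-1]
  rotations = jax.random.normal(rng, (d, n_buckets // 2))
  rotated = jnp.einsum('tf,fb->tb', vecs, rotations)       # (seq_len, n_buckets/2)
  rotated = jnp.concatenate([rotated, -rotated], axis=-1)  # (seq_len, n_buckets)
  return jnp.argmax(rotated, axis=-1)                      # bucket id per vector

# Example: 8 vectors of dimension 16 hashed into 4 buckets.
buckets_example = lsh_buckets_sketch(
    jax.random.normal(jax.random.PRNGKey(0), (8, 16)), 4, jax.random.PRNGKey(1))
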
def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
  """Reconstructs this layer stack's input from its output."""
  reconstructed_x = output
  rng = kwargs.pop('rng', None)
  rngs = (None,) * self._n_layers
  if rng is not None:
    rngs = backend.random.split(rng, self._n_layers)
  # Note that self.sublayers aligns exactly with self.reverse_layers in
  # terms of parameter and rng usage, so no re-ordering is required.
  for layer, p, s, ns, rng in zip(
      self.reverse_layers, weights, state, new_state, rngs):
    reconstructed_x = layer(reconstructed_x, weights=p,
                            state=s, new_state=ns, rng=rng, **kwargs)
  return reconstructed_x
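
# A minimal sketch of why a reverse pass can reconstruct inputs without storing
# activations, using a single reversible residual pair (illustrative functions,
# not the Trax reversible-layer classes). Forward: y1 = x1 + f(x2),
# y2 = x2 + g(y1); the reverse simply undoes the two additions in reverse order.
import jax.numpy as jnp

def rev_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def rev_reverse(y1, y2, f, g):
  # Invert the two residual additions in the opposite order.
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2

f = lambda x: jnp.tanh(x)
g = lambda x: 0.5 * x
x1, x2 = jnp.ones(3), 2.0 * jnp.ones(3)
r1, r2 = rev_reverse(*rev_forward(x1, x2, f, g), f, g)
assert jnp.allclose(r1, x1) and jnp.allclose(r2, x2)
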
def mapped_compute_loss(opt_state, batch, state, rng):
  """This is a multi-device version of the loss computation above."""
  # We assume all tensors have the first dimension = n_devices.
  rng, subrng = jax_random.split(rng)
  loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
  return loss_val, state, subrng
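
# A minimal sketch of how a per-device function like the one above is
# typically consumed via jax.pmap (illustrative loss, not the Trax trainer):
# parameters are replicated and the batch is sharded so that the leading
# dimension equals the device count.
import jax
import jax.numpy as jnp

def sketch_loss_fn(params, batch):
  return jnp.mean((batch * params) ** 2)

n_devices = jax.local_device_count()
p_loss = jax.pmap(sketch_loss_fn, axis_name='batch')
params = jnp.ones((n_devices,))          # one (scalar) copy per device
batch = jnp.ones((n_devices, 8, 4))      # leading dim = n_devices
per_device_loss = p_loss(params, batch)  # shape (n_devices,)
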
def mapped_update(i, opt_state, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  weights, slots, opt_params = opt_state
  rng, subrng = jax_random.split(rng)
  grad_fn = backend.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  # We divide by psum(1.0) instead of `n_devices` because `n_devices` is only
  # the number of devices on this host, whereas psum sums over all devices of
  # all hosts (e.g. a TPU pod), and we need to average over all of them.
  grads = jax.tree_util.tree_map(
      lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
  return optimizer.tree_update(
      i, grads, weights, slots, opt_params), state, subrng
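
# A minimal sketch of the psum-averaging trick above: inside a pmap'd update,
# divide by psum(1.0) rather than a host-local device count so the mean runs
# over every device participating in the collective. The update function is
# illustrative (not the Trax optimizer API), and backend.psum is assumed to
# correspond to jax.lax.psum.
import jax
import jax.numpy as jnp

def sketch_update(params, batch):
  grads = jax.grad(lambda p: jnp.mean((batch * p) ** 2))(params)
  # Average gradients across all devices in the 'batch' axis.
  grads = jax.lax.psum(grads, 'batch') / jax.lax.psum(1.0, 'batch')
  return params - 0.1 * grads

n_devices = jax.local_device_count()
p_update = jax.pmap(sketch_update, axis_name='batch')
new_params = p_update(jnp.ones((n_devices,)), jnp.ones((n_devices, 8)))
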
def body_fun(vals):
  """Performs attention for a single batch element and head."""
  batch_loop_idx = vals[0]
  if self._prng is None:
    hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
    hash_rng, slice_rng = backend.random.split(hash_slice_rng)
  else:
    # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
    hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
    slice_rng = jax.random.fold_in(rng, batch_loop_idx)
  # Slice out the current batch element; the loop index is traced, so we use
  # dynamic_index_in_dim rather than ordinary indexing.
  qk_slice = jax.lax.dynamic_index_in_dim(
      qk, batch_loop_idx, axis=0, keepdims=False)
  v_slice = jax.lax.dynamic_index_in_dim(
      v, batch_loop_idx, axis=0, keepdims=False)
  if buckets is None:
    buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
  else:
    buckets_slice = jax.lax.dynamic_index_in_dim(
        buckets, batch_loop_idx, axis=0, keepdims=False)
  if ct is None:
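
# A small sketch of the two JAX utilities the loop body above relies on:
# jax.random.fold_in derives a distinct key per batch element from one base
# key, and jax.lax.dynamic_index_in_dim slices out one element along an axis.
# The index is a plain int here for clarity, but the same calls work when it
# is a traced loop variable, which is why they appear inside body_fun.
import jax
import jax.numpy as jnp

base_rng = jax.random.PRNGKey(0)
x = jnp.arange(12.0).reshape(3, 4)       # pretend batch of 3 examples

i = 2                                                   # index of one example
rng_i = jax.random.fold_in(base_rng, i)                 # distinct key for i
x_i = jax.lax.dynamic_index_in_dim(x, i, axis=0, keepdims=False)  # shape (4,)
noisy_x_i = x_i + jax.random.normal(rng_i, x_i.shape)
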
def evaluate(self, n_eval_steps):
  """Evaluate the model and log metrics."""
  _, rng = jax_random.split(self._rngs[0])
  # TODO(lukaszkaiser): both model state and parameters by default include
  # the loss layer. Currently, we access the pure-model parameters by just
  # indexing [0] here. But we should make it more explicit in a better API.
  weights = (self._opt_state[0][0], self._metrics_weights)
  state = (self._model_state[0], self._metrics_state)
  self.log_step('Evaluation')
  train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps)
  train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state,
                                           rng)
  self.log_metrics(train_metrics, self._train_sw, 'train')
  eval_slice = itertools.islice(self._eval_stream, n_eval_steps)
  eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng)
  self.log_metrics(eval_metrics, self._eval_sw, 'eval')
  self.log_step('Finished evaluation')
  # Save the optimizer weights in the history.
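
# A minimal sketch of the evaluation pattern above: take a fixed number of
# batches from a (possibly infinite) stream with itertools.islice and average
# a metric over them. The stream and metric are illustrative stand-ins, not
# the Trax Trainer's evaluation_round.
import itertools
import jax.numpy as jnp

def batch_stream():
  while True:                        # endless stream of (prediction, target)
    yield jnp.ones((4, 3)), jnp.zeros((4, 3))

def evaluation_round_sketch(batches):
  total, count = 0.0, 0
  for preds, targets in batches:
    total += jnp.mean((preds - targets) ** 2)   # per-batch metric
    count += 1
  return total / count

eval_metric = evaluation_round_sketch(itertools.islice(batch_stream(), 10))
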
self._mask_id = mask_id
self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
# `loss_fn` and `inputs` are passed in as factories; instantiate them here.
loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
inputs = inputs(self._n_devices)
self._inputs = inputs

# Initialize the learning rate to a dummy value. It will be set in reset().
opt = optimizer(learning_rate=0.0)

# Setup the model.
model_train = model(mode='train')
model_predict_eval = model(mode='eval')

# Setup state.
rng, init_rng = jax_random.split(rng)
self._rngs = np.stack(jax_random.split(rng, self._n_devices))
first_shape = inputs.input_shape[0]
# If the inputs are a tuple/list, add [None] (batch) to each element.
if isinstance(first_shape, (list, tuple)):
  model_input_shape = tuple(
      tuple([None] + list(shape)) for shape in inputs.input_shape)
  model_target_shape = tuple(
      tuple([None] + list(shape)) for shape in inputs.target_shape)
else:  # Otherwise just add [None] to the input shape.
  model_input_shape = tuple([None] + list(inputs.input_shape))
  model_target_shape = tuple([None] + list(inputs.target_shape))
# Change all None to 1 in input and target shape.
model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
model_target_shape = backend.nested_map(lambda x: x or 1,
                                        model_target_shape)

def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
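
# A small sketch of the shape bookkeeping above: prepend a None batch
# dimension to each (possibly nested) shape, then swap every None for 1 so the
# shapes can be used to build concrete dummy inputs for initialization. The
# nested_map below is a tiny illustrative stand-in for backend.nested_map.
def nested_map(fn, obj):
  """Applies fn to every leaf of a nested tuple/list structure."""
  if isinstance(obj, (list, tuple)):
    return type(obj)(nested_map(fn, x) for x in obj)
  return fn(obj)

input_shape = ((32,), (32, 8))        # a tuple of two per-example shapes
with_batch = tuple(tuple([None] + list(s)) for s in input_shape)
# with_batch == ((None, 32), (None, 32, 8))
concrete = nested_map(lambda x: x or 1, with_batch)
# concrete == ((1, 32), (1, 32, 8))
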