# Fragment of a helper (as in numpyro's test suite) that draws a random value
# satisfying a given constraint. The opening branches and the signature are
# missing from the snippet; the restored lines below are an assumption based
# on how the body uses its arguments.
import jax.numpy as jnp
import scipy.stats as osp
from jax import random

from numpyro.distributions import constraints
from numpyro.distributions.util import multinomial, signed_stick_breaking_tril


def gen_values_within_bounds(constraint, size, key):
    # ... branches for boolean/integer constraints elided in the snippet ...
    if isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        # add a trailing singleton dim so the shift broadcasts over the event dim
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
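For example, drawing a batch of values inside an interval constraint looks as follows. This is a usage sketch assuming the restored signature above; it reuses the `random` and `constraints` imports from the snippet.

key = random.PRNGKey(0)
sample = gen_values_within_bounds(constraints.interval(0.0, 1.0), (2, 3), key)
assert ((sample > 0.0) & (sample < 1.0)).all()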
# Fragment of an LSH-attention hashing method (trax-style). The enclosing
# method and the initialization of `factor` are not shown; `factor` is
# decremented until it and its cofactor are even divisors of n_buckets.
    while not (
        self.n_buckets % factor == 0 and
        factor % 2 == 0 and
        (self.n_buckets // factor) % 2 == 0):
      factor -= 1
    if factor > 2:  # Factor of 2 does not warrant the effort.
      rot_size = factor + (self.n_buckets // factor)
      factor_list = [factor, self.n_buckets // factor]

  random_rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      rot_size // 2)

  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = jax.random.normal(
      rng, random_rotations_shape).astype('float32')
  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

  if self._rehash_each_round:
    if self._factorize_hash and len(factor_list) > 1:
      # We factorized self.n_buckets as the product of factor_list.
      # Get the buckets for them and combine.
      buckets, cur_sum, cur_product = None, 0, 1
      for factor in factor_list:
        rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
        cur_sum += factor // 2
        rv = np.concatenate([rv, -rv], axis=-1)
        if buckets is None:
          buckets = np.argmax(rv, axis=-1)
        else:
          buckets += cur_product * np.argmax(rv, axis=-1)
        cur_product *= factor
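The core bucketing step above can be read in isolation: project each vector onto rot_size // 2 random directions, then take an argmax over the projections and their negations. Below is a minimal standalone sketch of that idea for the single-round, unfactorized case; the function name and shapes are mine, not trax's API.

import jax
import jax.numpy as jnp

def simple_lsh_buckets(vecs, n_buckets, rng):
    # vecs: (seq_len, d). Argmax over [x, -x] of the random projections
    # yields a bucket id in [0, n_buckets) for each position.
    rotations = jax.random.normal(rng, (vecs.shape[-1], n_buckets // 2))
    rotated = jnp.dot(vecs, rotations)  # (seq_len, n_buckets // 2)
    return jnp.argmax(jnp.concatenate([rotated, -rotated], axis=-1), axis=-1)

buckets = simple_lsh_buckets(jax.random.normal(jax.random.PRNGKey(1), (16, 8)),
                             n_buckets=4, rng=jax.random.PRNGKey(0))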
# Fragment: tail of a factory that builds a jitted, chunked matrix-vector
# multiply; `row`, `_chunk_vmap`, and `dilation` come from the enclosing
# (unshown) scope.
    def do_mvm(rhs):
        @jit
        def compute_element(i):
            return np.dot(rhs, row(i))
        return _chunk_vmap(compute_element, np.arange(rhs.shape[-1]), rhs.shape[-1] // dilation)
    return do_mvm
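The `_chunk_vmap` helper itself is not shown in the snippet. Assuming its purpose is to bound peak memory by vmapping over one chunk of indices at a time, a minimal sketch of such a helper could look like this (the name and signature here are hypothetical):

import jax
import jax.numpy as jnp

def chunk_vmap(fn, xs, num_chunks):
    # vmap fn over xs one chunk at a time instead of all at once, trading
    # a little speed for a roughly num_chunks-fold smaller peak memory.
    chunks = jnp.array_split(xs, max(num_chunks, 1))
    return jnp.concatenate([jax.vmap(fn)(chunk) for chunk in chunks])

result = chunk_vmap(lambda i: i * 2.0, jnp.arange(10), num_chunks=3)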
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def _batch_mahalanobis(bL, bx):
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case bL.shape = (i, 1, n, n), bx.shape = (i, j, n),
    # because we don't want to broadcast bL to the shape (i, j, n, n).
    # Assume that bL.shape = (i, 1, n, n) and bx.shape = (..., i, j, n);
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve.
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = jnp.shape(bx)[:-1]  # shape of output
    # Reshape bx to the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = jnp.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (tuple(range(sample_ndim))
                    + tuple(range(sample_ndim, bx.ndim - 1, 2))
                    + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
                    + (bx.ndim - 1,))
    bx = jnp.transpose(bx, permute_dims)
    # reshape to (-1, i, 1, n)
    xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    xt = jnp.moveaxis(xt, 0, -1)
    solve_bL_bx = solve_triangular(bL, xt, lower=True)  # shape: (i, 1, n, -1)
    M = jnp.sum(solve_bL_bx ** 2, axis=-2)  # shape: (i, 1, -1)
    # permute back to (-1, i, 1), then reshape back to (..., 1, j, i, 1)
    M = jnp.reshape(jnp.moveaxis(M, -1, 0), bx.shape[:-1])
    # permute to (..., 1, i, j, 1) and flatten into out_shape
    permute_inv_dims = tuple(range(sample_ndim))
    for i in range(bL.ndim - 2):
        permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
    return jnp.reshape(jnp.transpose(M, permute_inv_dims), out_shape)
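A quick sanity check of the batched triangular solve against the direct quadratic form x^T (L L^T)^{-1} x; a usage sketch using the function as completed above.

from jax import random

key_L, key_x = random.split(random.PRNGKey(0))
A = random.normal(key_L, (3, 3))
L = jnp.linalg.cholesky(A @ A.T + 3.0 * jnp.eye(3))  # lower Cholesky factor of a PD matrix
x = random.normal(key_x, (5, 3))                     # batch of 5 vectors

fast = _batch_mahalanobis(L, x)  # shape (5,)
direct = jnp.einsum('bi,ij,bj->b', x, jnp.linalg.inv(L @ L.T), x)
assert jnp.allclose(fast, direct, atol=1e-4)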
# Test of numpyro's warmup adaptation (step size and mass matrix). The imports
# and the parametrize decorator below are assumptions about the test's setup.
import pytest
import jax.numpy as jnp
from jax import jit, random
from numpy.testing import assert_allclose

from numpyro.infer.hmc_util import build_adaptation_schedule, warmup_adapter


@pytest.mark.parametrize('jitted', [True, False])
def test_warmup_adapter(jitted):
    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3
    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    wa_update = jit(wa_update) if jitted else wa_update

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init((z, None, None, None), rng_key, init_step_size,
                       mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert step_size == find_reasonable_step_size(init_step_size, inverse_mass_matrix, z, rng_key)
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    window = adaptation_schedule[0]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.7 + 0.1 * t / (window.end - window.start), z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
# Test that model/guide param sites whose constraints are built from composed
# transforms are handled correctly. Imports are assumptions; the SVI fitting
# part of the test is truncated in the snippet.
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints


def test_param():
    # this tests the validity of model/guide sites having param constraints
    # that contain composed transforms
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        # ... remainder of the guide and the SVI assertions are truncated ...
# jaxnet test: a @parametrized module that calls an externally defined Dense
# layer. The imports are assumptions about jaxnet's test setup, and
# `assert_params_equal` is a helper from its test utilities.
import jax.numpy as np
from jax import jit
from jax.random import PRNGKey

from jaxnet import Dense, parametrized, zeros


def test_external_submodule2():
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = np.zeros((1, 2))
    params = net.init_params(PRNGKey(0), inputs)
    assert_params_equal(((np.zeros((2, 2)), np.zeros(2)),), params)

    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = jit(net.apply)(params, inputs)
    assert np.array_equal(out, out_)