Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
numpyro.sample('y', dist.Normal(c, d), obs=obs)
adam = optim.Adam(0.01)
svi = SVI(model, guide, adam, ELBO())
svi_state = svi.init(random.PRNGKey(0))
params = svi.get_params(svi_state)
assert_allclose(params['a'], a_init)
assert_allclose(params['b'], b_init)
assert_allclose(params['c'], c_init)
assert_allclose(params['d'], d_init)
actual_loss = svi.evaluate(svi_state)
assert jnp.isfinite(actual_loss)
expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(a_init, b_init).log_prob(obs)
# compare with a loose tolerance because of the transform / inverse-transform round trip
assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def model(a=None, b=None, z=None):
    """Model with an intercept, two optional scaled terms, and a Normal likelihood.

    The intercept is drawn at site ``'a'`` (note: this site name is unrelated
    to the ``a`` argument). When ``a`` is given, site ``'x'`` is sampled and
    scaled by ``a``; when ``b`` is given, site ``'y'`` is sampled and scaled
    by ``b``. Skipped terms contribute ``0.`` to the mean.

    :param a: optional multiplier for the ``'x'`` site; ``None`` skips it.
    :param b: optional multiplier for the ``'y'`` site; ``None`` skips it.
    :param z: passed as ``obs`` to the ``'obs'`` site.
    """
    intercept = numpyro.sample('a', dist.Normal(0., 0.2))
    # Sites are only created for the terms that were actually requested.
    x_contrib = a * numpyro.sample('x', dist.HalfNormal(0.5)) if a is not None else 0.
    y_contrib = b * numpyro.sample('y', dist.HalfNormal(0.5)) if b is not None else 0.
    noise_scale = numpyro.sample('sigma', dist.Exponential(1.))
    mean = intercept + x_contrib + y_contrib
    numpyro.sample('obs', dist.Normal(mean, noise_scale), obs=z)
def model(data):
    """Normal likelihood over ``data`` with sampled ``'mean'`` and ``'std'``.

    ``'mean'`` is drawn from ``Normal(0, 1)`` with ``.mask(False)`` applied;
    ``'std'`` is drawn from an ``ImproperUniform`` over the positive
    constraint with empty batch and event shapes.

    :param data: passed as ``obs`` to the ``'obs'`` site.
    :return: the value returned by the ``'obs'`` sample statement.
    """
    mean_prior = dist.Normal(0, 1).mask(False)
    mean = numpyro.sample('mean', mean_prior)
    std_prior = dist.ImproperUniform(dist.constraints.positive, (), ())
    std = numpyro.sample('std', std_prior)
    likelihood = dist.Normal(mean, std)
    return numpyro.sample('obs', likelihood, obs=data)
def schools_model():
    """Hierarchical Normal model reading ``J``, ``sigma`` and ``y`` from ``data``.

    ``data`` is a free variable resolved from the enclosing scope (not shown
    here); it is indexed with the keys ``'J'``, ``'sigma'`` and ``'y'``.
    """
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    # One draw of theta per unit counted by data['J'].
    theta_shape = (data['J'],)
    theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=theta_shape)
    likelihood = dist.Normal(theta, data['sigma'])
    numpyro.sample('obs', likelihood, obs=data['y'])
def model(data):
    """Unit-scale Normal likelihood over ``data`` with a sampled ``'loc'``.

    :param data: passed as ``obs`` to the ``'obs'`` site.
    """
    loc_prior = dist.Normal(0., 1.)
    loc = numpyro.sample('loc', loc_prior)
    likelihood = dist.Normal(loc, 1)
    numpyro.sample('obs', likelihood, obs=data)
def _sample():
    """Sample sites ``'x'`` and ``'y'`` and return them stacked, ``'x'`` first."""
    # List elements are evaluated left to right, so 'x' is sampled before 'y'.
    draws = [
        numpyro.sample('x', dist.Normal(0., 1.)),
        numpyro.sample('y', dist.Normal(1., 2.)),
    ]
    return jnp.stack(draws)
def test_mask(batch_shape, event_shape, mask_shape):
    """Check that the masked log-probability is non-zero exactly where the mask is set.

    A ``Normal`` is expanded to ``batch_shape + event_shape`` and its trailing
    ``len(event_shape)`` dimensions are converted to event dimensions. A
    Bernoulli(0.5) mask of ``mask_shape`` is drawn, and the pattern of non-zero
    masked log-probabilities is compared with the mask broadcast against
    ``batch_shape``.

    :param batch_shape: batch shape of the distribution under test.
    :param event_shape: event shape of the distribution under test.
    :param mask_shape: shape of the random mask; ``()`` yields a Python bool.
    """
    full_shape = batch_shape + event_shape
    jax_dist = dist.Normal().expand(full_shape).to_event(len(event_shape))
    mask = dist.Bernoulli(0.5).sample(random.PRNGKey(0), mask_shape)
    if mask_shape == ():
        # A scalar mask is converted to a plain Python bool before use.
        mask = bool(mask)
    samples = jax_dist.sample(random.PRNGKey(1))
    masked_log_prob = jax_dist.mask(mask).log_prob(samples)
    expected_shape = lax.broadcast_shapes(batch_shape, mask_shape)
    expected_nonzero = jnp.broadcast_to(mask, expected_shape)
    assert_allclose(masked_log_prob != 0, expected_nonzero)
def reparam_model(dim=10):
    """Sample ``'y'``, then ``'x'`` under a ``LocScaleReparam(0)`` reparam handler.

    ``'x'`` is Normal with a zero location vector of length ``dim - 1`` and
    scale ``exp(y / 2)``.

    :param dim: ``'x'`` is given ``dim - 1`` entries.
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    reparam_config = {'x': LocScaleReparam(0)}
    with numpyro.handlers.reparam(config=reparam_config):
        x_loc = jnp.zeros(dim - 1)
        x_scale = jnp.exp(y / 2)
        numpyro.sample('x', dist.Normal(x_loc, x_scale))