Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def model_nested_plates_1():
    """Draw sites under two nested plates and check their batch shapes.

    Site ``'y'`` is drawn in the outer plate only, and site ``'x'`` in both
    plates. A deterministic site ``'z'`` squares the outer-plate draw.
    """
    with numpyro.plate('outer', 10, dim=-2):
        # Only the outer plate (size 10, dim -2) is active: shape (10, 1).
        outer_draw = numpyro.sample('y', dist.Normal(0., 1.))
        assert outer_draw.shape == (10, 1)

        with numpyro.plate('inner', 5):
            # Both plates are active; the inner plate fills the rightmost dim.
            inner_draw = numpyro.sample('x', dist.Normal(0., 1.))
            assert inner_draw.shape == (10, 5)

            # The deterministic site keeps the shape of its value (10, 1).
            squared = numpyro.deterministic('z', outer_draw ** 2)
            assert squared.shape == (10, 1)
def dirichlet_categorical(data):
    """Categorical observations with a Dirichlet(1, 1, 1) prior on the probabilities.

    :param data: observed category labels; ``data.shape[0]`` sets the plate size.
    :return: the sampled probability vector from site ``'p'``.
    """
    flat_prior = dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))
    probs = numpyro.sample('p', flat_prior)

    num_obs = data.shape[0]
    with numpyro.plate('N', num_obs):
        numpyro.sample('obs', dist.Categorical(probs), obs=data)

    return probs
def model_subsample_1():
    """Use two subsampled plates separately and then together, checking shapes.

    ``'outer'`` has size 20 subsampled to 10 (no explicit dim).
    ``'inner'`` has size 10 subsampled to 5 on dim -3.
    """
    outer_plate = numpyro.plate('outer', 20, subsample_size=10)
    inner_plate = numpyro.plate('inner', 10, subsample_size=5, dim=-3)

    with outer_plate:
        x_draw = numpyro.sample('x', dist.Normal(0., 1.))
        assert x_draw.shape == (10,)

    with inner_plate:
        y_draw = numpyro.sample('y', dist.Normal(0., 1.))
        assert y_draw.shape == (5, 1, 1)
        # The deterministic value keeps x's own shape; the plate does not expand it.
        z_value = numpyro.deterministic('z', x_draw ** 2)
        assert z_value.shape == (10,)

    # With both plates active, inner sits on dim -3 and outer on the last dim.
    with outer_plate, inner_plate:
        xy_draw = numpyro.sample('xy', dist.Normal(0., 1.))
        assert xy_draw.shape == (5, 1, 10)
def model(data=None):
    """Bernoulli likelihood over a length-2 vector of Beta(1, 1) probabilities.

    ``N`` (the plate size) is read from module scope; it is not defined here.
    The plate is placed on dim -2, leaving the rightmost axis for ``beta``.
    NOTE(review): presumably ``data`` has shape ``(N, 2)`` - confirm with callers.
    NOTE(review): a later ``def model`` in this file rebinds this name.
    """
    prior = dist.Beta(np.ones(2), np.ones(2))
    beta = numpyro.sample("beta", prior)
    with numpyro.plate("plate", N, dim=-2):
        numpyro.sample("obs", dist.Bernoulli(beta), obs=data)
def neals_funnel(dim):
    """Neal's funnel: ``y ~ Normal(0, 3)`` and ``dim`` draws ``x ~ Normal(0, exp(y / 2))``.

    :param dim: number of ``x`` coordinates (the size of plate ``'D'``).
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    # The scale of every x coordinate depends on y, which creates the funnel shape.
    x_scale = jnp.exp(y / 2)
    with numpyro.plate('D', dim):
        numpyro.sample('x', dist.Normal(0, x_scale))
def model(data):
    """One Beta(1, 1) success probability shared by a plate of 10 Bernoulli observations.

    NOTE(review): the plate size is fixed at 10, so ``data`` presumably has
    10 entries - confirm with callers.
    """
    success_prob = numpyro.sample("beta", dist.Beta(1., 1.))
    with numpyro.plate("plate", 10):
        numpyro.sample("obs", dist.Bernoulli(success_prob), obs=data)
This is useful when doing inference for the usual NumPyro programs with
`numpyro.plate` statements. For example, to get trace of a `model` whose discrete
latent sites are enumerated, we can use
enum_model = numpyro.contrib.funsor.enum(model)
with plate_to_enum_plate():
model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
*model_args, **model_kwargs)
"""
try:
numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
yield
finally:
numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)
def partially_pooled(at_bats, hits=None):
    r"""
    Number of hits has a Binomial distribution with independent
    probability of success, $\phi_i$. Each $\phi_i$ follows a Beta
    distribution with concentration parameters $c_1$ and $c_2$, where
    $c_1 = m \kappa$, $c_2 = (1 - m) \kappa$, $m \sim \mathrm{Uniform}(0, 1)$,
    and $\kappa \sim \mathrm{Pareto}(1, 1.5)$.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
        When ``None``, site ``"obs"`` is sampled instead of observed.
    :return: Value of site ``"obs"``, i.e. the number of hits predicted by the
        model (the observed ``hits`` when they are provided).
    """
    # Population-level hyperpriors shared by all players.
    m = numpyro.sample("m", dist.Uniform(0, 1))
    kappa = numpyro.sample("kappa", dist.Pareto(1, 1.5))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        # Each player's success probability is drawn from the shared Beta;
        # m sets its mean and kappa its concentration.
        phi_prior = dist.Beta(m * kappa, (1 - m) * kappa)
        phi = numpyro.sample("phi", phi_prior)
        return numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
def not_pooled(at_bats, hits=None):
    r"""
    Number of hits in the given at bats for each player has a Binomial
    distribution with independent probability of success, $\phi_i$,
    where $\phi_i \sim \mathrm{Uniform}(0, 1)$ and no information is
    shared between players.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
        When ``None``, site ``"obs"`` is sampled instead of observed.
    :return: Value of site ``"obs"``, i.e. the number of hits predicted by the
        model (the observed ``hits`` when they are provided).
    """
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        # Flat, independent prior per player (no pooling).
        phi_prior = dist.Uniform(0, 1)
        phi = numpyro.sample("phi", phi_prior)
        return numpyro.sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)