def test_BOLFI():
    m, true_params = setup_ma2_with_informative_data()

    # Log discrepancy tends to work better
    log_d = NodeReference(m['d'], state=dict(_operation=np.log), model=m, name='log_d')

    bolfi = elfi.BOLFI(
        log_d,
        initial_evidence=20,
        update_interval=10,
        batch_size=5,
        bounds={'t1': (-2, 2),
                't2': (-1, 1)},
        acq_noise_var=.1)

    n = 300
    res = bolfi.infer(n)
    assert bolfi.target_model.n_evidence == n
    acq_x = bolfi.target_model._gp.X

    # check_inference_with_informative_data(res, 1, true_params, error_bound=.2)
    assert np.abs(res.x_min['t1'] - true_params['t1']) < 0.2
def test_name_argument(self):
    # This is important because it is used when passing NodeReferences as
    # InferenceMethod arguments
    em.set_current_model()
    ref = em.NodeReference(name='test')
    assert str(ref) == 'test'

def test_name_determination(self):
    em.set_current_model()
    node = em.NodeReference()
    assert node.name == 'node'

    # Works without spaces
    node2=em.NodeReference()
    assert node2.name == 'node2'

    # Does not give the same name
    node = em.NodeReference()
    assert node.name != 'node'

    # Works with sub classes
    pri = em.Prior('uniform')
    assert pri.name == 'pri'

    # Assigns random names when the name isn't self explanatory
    nodes = []
    for i in range(5):
        nodes.append(em.NodeReference())

    for i in range(1, 5):
        assert nodes[i - 1].name != nodes[i].name
def __getitem__(self, item):
    """Get item from the state dict of the node."""
    return self.state[item]

def __setitem__(self, item, value):
    """Set item into the state dict of the node."""
    self.state[item] = value

def __repr__(self):
    """Return a representation comprised of the names of the class and the node."""
    return "{}(name='{}')".format(self.__class__.__name__, self.name)

def __str__(self):
    """Return the name of the node."""
    return self.name
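
# Illustrative sketch (not part of the original source): the accessors above expose a
# node's state dict with dict syntax and give name-based str/repr. Assumes a current
# default model, e.g. after calling elfi.new_model().
import elfi

elfi.new_model()
c = elfi.Constant(5, name='c1')
assert c['_output'] == 5                    # __getitem__ reads the state dict
c['_custom_flag'] = True                    # __setitem__ writes to the state dict
assert str(c) == 'c1'                       # __str__ returns the node name
assert repr(c) == "Constant(name='c1')"     # __repr__ includes class and name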
class StochasticMixin(NodeReference):
    """Define the inheriting node as stochastic.

    Operations of stochastic nodes will receive a `random_state` keyword argument.
    """

    def __init__(self, *parents, state, **kwargs):
        # Flag that this node is stochastic
        state['_stochastic'] = True
        super(StochasticMixin, self).__init__(*parents, state=state, **kwargs)
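
# Sketch (not part of the original source): an operation for a stochastic node should
# draw all of its randomness from the provided random_state rather than the global
# NumPy state. The function below is a hypothetical example of such an operation.
import numpy as np

def noisy_identity(x, random_state=None):
    random_state = random_state or np.random.RandomState()
    return x + random_state.normal(scale=0.1, size=np.shape(x))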
class ObservableMixin(NodeReference):
    """Define the inheriting node as observable.

    Observable nodes accept an `observed` keyword argument. In addition, the compiled
    model will contain a sister node that contains the observed value or will compute the
    observed value from the observed values of its parents.
    """
def add_reduce_node(model, nodes, reduce_operation, name):
    """Add a node that reduces the outputs of the given nodes into a single value.

    Parameters
    ----------
    model : elfi.ElfiModel
    nodes : list
        Either a list of node names or a list of node reference objects
    reduce_operation : callable
    name : str
        Name for the reduce node

    Returns
    -------
    name : str
        name of the new node
    """
    name = '_reduce*' if name is None else name
    nodes = [n if isinstance(n, NodeReference) else model[n] for n in nodes]
    op = Operation(compose(partial(reduce, reduce_operation), args_to_tuple), *nodes,
                   model=model, name=name)
    return op.name
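
# Usage sketch (not part of the original module, node names are illustrative): reduce
# two parameter nodes into a single node holding their sum.
import operator
import elfi

m = elfi.new_model()
elfi.Prior('uniform', 0, 1, model=m, name='t1')
elfi.Prior('uniform', 0, 1, model=m, name='t2')
sum_name = add_reduce_node(m, ['t1', 't2'], operator.add, name='t_sum')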
"""Make this node become the `other_node`.
The children of this node will be preserved.
Parameters
----------
other_node : NodeReference
"""
if other_node.model is not self.model:
raise ValueError('The other node belongs to a different model')
self.model.update_node(self.name, other_node.name)
# Update the reference class
_class = self.state.get('_class', NodeReference)
if not isinstance(self, _class):
self.__class__ = _class
# Update also the other node reference
other_node.name = self.name
other_node.model = self.model
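
# Usage sketch (not part of the original module): `become` replaces the definition of
# an existing node while keeping its children, e.g. swapping a prior in an already
# built model. Names and distributions below are illustrative.
import elfi

m = elfi.new_model()
t1 = elfi.Prior('uniform', 0, 2, model=m, name='t1')
# ... nodes depending on t1 would be defined here ...
t1.become(elfi.Prior('normal', 0, 1, model=m))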
def draw(G, internal=False, param_names=False, filename=None, format=None):
    """Draw the model as a graphviz graph.

    Notes
    -----
    Requires the optional 'graphviz' library.

    Returns
    -------
    dot
        A GraphViz dot representation of the model.

    """
    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError("The graphviz library is required for this feature.")

    if isinstance(G, ElfiModel):
        G = G.source_net
    elif isinstance(G, NodeReference):
        G = G.model.source_net

    dot = Digraph(format=format)

    hidden = set()
    for n, state in G.nodes_iter(data=True):
        if not internal and n[0] == '_' and state.get('_class') == Constant:
            hidden.add(n)
            continue
        _format = {'shape': 'circle', 'fillcolor': 'gray80', 'style': 'solid'}
        if state.get('_observable'):
            _format['style'] = 'filled'
        dot.node(n, **_format)

    # add edges to graph
"""A node holding a constant value."""
def __init__(self, value, **kwargs):
"""Initialize a node holding a constant value.
Parameters
----------
value
The constant value of the node.
"""
state = dict(_output=value)
super(Constant, self).__init__(state=state, **kwargs)
class Operation(NodeReference):
    """A generic deterministic operation node."""

    def __init__(self, fn, *parents, **kwargs):
        """Initialize a node that performs an operation.

        Parameters
        ----------
        fn : callable
            The operation of the node.

        """
        state = dict(_operation=fn)
        super(Operation, self).__init__(*parents, state=state, **kwargs)
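
# Usage sketch (not part of the original module): Constant wraps a fixed value, while
# Operation applies a deterministic function to the outputs of its parents. Names are
# illustrative.
import numpy as np
import elfi

m = elfi.new_model()
mu = elfi.Constant(0.5, model=m, name='mu')
sigma = elfi.Constant(2.0, model=m, name='sigma')
ratio = elfi.Operation(np.divide, mu, sigma, model=m, name='ratio')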
class RandomVariable(StochasticMixin, NodeReference):
    """A random variable node that draws its values from a given distribution."""


class Prior(RandomVariable):
    """A prior node of an ELFI graph."""

    def __init__(self, distribution, *params, size=None, **kwargs):
        """Initialize a Prior.

        Notes
        -----
        The parameters of the `scipy` distributions (typically `loc` and `scale`) must be
        given as positional arguments.

        Many algorithms (e.g. SMC) also require a `pdf` method for the distribution. In
        general the definition of the distribution is a subset of
        `scipy.stats.rv_continuous`.

        Scipy distributions: https://docs.scipy.org/doc/scipy-0.19.0/reference/stats.html

        """
        super(Prior, self).__init__(distribution, *params, size=size, **kwargs)
        self['_parameter'] = True
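
# Usage sketch (not part of the original module): a prior can be given the name of a
# scipy.stats distribution or a scipy-like object, with its parameters (here loc and
# scale) passed positionally. Node names are illustrative.
import scipy.stats as ss
import elfi

m = elfi.new_model()
t1 = elfi.Prior('uniform', 0, 2, model=m, name='t1')    # uniform on [0, 2]
t2 = elfi.Prior(ss.norm, 0, 1, model=m, name='t2')      # standard normal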
class Simulator(StochasticMixin, ObservableMixin, NodeReference):
    """A simulator node of an ELFI graph.

    Simulator nodes are stochastic and may have observed data in the model.
    """

    def __init__(self, fn, *params, **kwargs):
        """Initialize a Simulator.

        Parameters
        ----------
        fn : callable
            Simulator function with a signature `sim(*params, batch_size, random_state)`
        params
            Input parameters for the simulator.
        kwargs

        """
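
# Sketch of a simulator function matching the signature documented above (not part of
# the original module): vectorized over batch_size and drawing randomness only from
# random_state. The output shape (batch_size, 30) is an arbitrary choice.
import numpy as np

def gauss_sim(mu, sigma, batch_size=1, random_state=None):
    random_state = random_state or np.random.RandomState()
    mu = np.asanyarray(mu).reshape((-1, 1))
    sigma = np.asanyarray(sigma).reshape((-1, 1))
    return random_state.normal(mu, sigma, size=(batch_size, 30))

# sim = elfi.Simulator(gauss_sim, t1, t2, observed=y_obs)   # t1, t2, y_obs illustrative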
class Summary(ObservableMixin, NodeReference):
    """A summary node of an ELFI graph."""

    def __init__(self, fn, *parents, **kwargs):
        """Initialize a Summary.

        Parameters
        ----------
        fn : callable
            Summary function with a signature `summary(*parents)`
        parents
            Input data for the summary function.
        kwargs

        """
        if not parents:
            raise ValueError('This node requires that at least one parent is specified.')

        state = dict(_operation=fn)
        super(Summary, self).__init__(*parents, state=state, **kwargs)
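
# Usage sketch (not part of the original module): a summary is a plain function of its
# parents' outputs and is evaluated for both simulated and observed data.
import numpy as np

def mean_summary(y):
    # one summary value per batch member
    return np.mean(y, axis=1)

# S1 = elfi.Summary(mean_summary, sim)   # `sim` would be a Simulator node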
class Discrepancy(NodeReference):
    """A discrepancy node of an ELFI graph.

    This class provides a convenience node for custom distance operations.
    """

    def __init__(self, discrepancy, *parents, **kwargs):
        """Initialize a Discrepancy.

        Parameters
        ----------
        discrepancy : callable
            Signature of the discrepancy function is of the form:
            `discrepancy(summary_1, summary_2, ..., observed)`, where summaries are
            arrays containing `batch_size` simulated values and observed is a tuple
            (observed_summary_1, observed_summary_2, ...). The callable object should
            return a vector of discrepancies between the simulated summaries and the
            observed summaries.
        """
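
# Sketch of a discrepancy function with the signature documented above (not part of
# the original module): a Euclidean distance between simulated and observed summaries.
import numpy as np

def euclidean_discrepancy(*summaries, observed):
    simulated = np.column_stack(summaries)      # shape (batch_size, n_summaries)
    observed = np.column_stack(observed)        # shape (1, n_summaries)
    return np.sqrt(np.sum((simulated - observed) ** 2, axis=1))

# d = elfi.Discrepancy(euclidean_discrepancy, S1, S2)   # S1, S2 would be Summary nodes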