Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def run_source(obj):
    """Check that ``dill.source`` can round-trip the callable ``obj``.

    Three paths are exercised, each compared against ``obj`` at 1.57:
    ``source._wrap``; re-executing the code from ``source.getimportable``
    (aliased to ``_f``) in this module's globals; and ``source.getname``.
    """
    sample = 1.57
    wrapped = source._wrap(obj)
    assert wrapped(sample) == obj(sample)
    # The exec target is the module globals, so ``_f`` is looked up below
    # as a global name (it is never assigned in this function).
    src = source.getimportable(obj, alias='_f')
    exec(src, globals())
    assert _f(sample) == obj(sample)
    name = source.getname(obj)
    # NOTE(review): ``or`` binds looser than ``==``, so this passes whenever
    # the text before the first "=" in ``src`` is non-blank, even if
    # ``name`` differs from ``obj.__name__``. The parentheses only make the
    # existing precedence explicit.
    assert (name == obj.__name__) or src.partition("=")[0].strip()
# Invert double_add; afterwards it must return -2*fx for (1,2,3).
double_add.invert()
assert double_add(1,2,3) == -2*fx
# A dill.copy of the inverted function must behave the same way.
_d = dill.copy(double_add)
assert _d(1,2,3) == -2*fx
#_d.invert() #FIXME: fails seemingly randomly
#assert _d(1,2,3) == 2*fx
# The undecorated callable, reachable via __wrapped__, returns plain fx.
assert _d.__wrapped__(1,2,3) == fx
# XXX: issue or feature? in python3.4, inverted is linked through copy
# Call invert() again when the inverted[0] flag is falsy. The flag's meaning
# is defined outside this excerpt (NOTE(review): presumably this puts
# double_add back into a known state for later tests -- confirm).
if not double_add.inverted[0]:
double_add.invert()
# test some stuff from source and pointers
ds = dill.source
dd = dill.detect
# getsource() on the function the wrapper captured as free variable 'f'
# must return the decorated definition, decorator line included.
assert ds.getsource(dd.freevars(quadish)['f']) == '@quad_factory(a=0,b=4,c=0)\ndef quadish(x):\n return x+1\n'
assert ds.getsource(dd.freevars(quadruple)['f']) == '@doubler\ndef quadruple(x):\n return 2*x\n'
# With source=False, importable() yields an import statement from this module.
assert ds.importable(quadish, source=False) == 'from %s import quadish\n' % __name__
assert ds.importable(quadruple, source=False) == 'from %s import quadruple\n' % __name__
assert ds.importable(quadratic, source=False) == 'from %s import quadratic\n' % __name__
assert ds.importable(double_add, source=False) == 'from %s import double_add\n' % __name__
# With source=True, importable() emits the decorator's source ahead of the
# decorated function.
assert ds.importable(quadruple, source=True) == 'def doubler(f):\n def inner(*args, **kwds):\n fx = f(*args, **kwds)\n return 2*fx\n return inner\n\n@doubler\ndef quadruple(x):\n return 2*x\n'
#***** #FIXME: this needs work
# For quadish, the first three lines are the a/b/c assignments. They are
# compared as a set, so their order is not checked. A fourth line is
# split off and discarded.
result = ds.importable(quadish, source=True)
a,b,c,_,result = result.split('\n',4)
assert result == 'def quad_factory(a=1,b=1,c=0):\n def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n return dec\n\n@quad_factory(a=0,b=4,c=0)\ndef quadish(x):\n return x+1\n'
assert set([a,b,c]) == set(['a = 0', 'c = 0', 'b = 4'])
# For quadratic, only the text after the first three lines is checked.
# The three split-off lines (a, b, c) are not asserted here.
result = ds.importable(quadratic, source=True)
a,b,c,result = result.split('\n',3)
assert result == '\ndef dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n'
def run_source(obj):
    """Check that ``dill.source`` can round-trip the callable ``obj``.

    Three paths are exercised, each compared against ``obj`` at 1.57:
    ``source._wrap``; executing the code from ``source.importable``
    (aliased to ``_f``); and ``source.getname``.
    """
    _obj = source._wrap(obj)
    assert _obj(1.57) == obj(1.57)
    src = source.importable(obj, alias='_f')
    # NOTE: inside a function, the mapping returned by locals() must not be
    # relied on for writes (see
    # https://docs.python.org/3/library/functions.html#exec). Exec into an
    # explicit, private namespace instead and read ``_f`` back from it.
    # This also keeps this function's own locals out of the exec'd code.
    namespace = {}
    exec(src, globals(), namespace)
    assert namespace["_f"](1.57) == obj(1.57)
    name = source.getname(obj)
    # NOTE(review): ``or`` binds looser than ``==``, so this passes whenever
    # the text before the first "=" in ``src`` is non-blank, even if
    # ``name`` differs from ``obj.__name__``. The parentheses only make the
    # existing precedence explicit.
    assert (name == obj.__name__) or src.split("=", 1)[0].strip()
def _hash_filter_fn(self, filter_fn, **kwargs):
""" Construct string representing state of filter_fn.
Used to cache filtered variants or effects uniquely depending on filter fn values.

Gathers a hash of filter_fn's source text, the keyword arguments it is
applied with (as sorted "key-value" strings), and recursive hashes of any
functions captured in filter_fn's closure.
"""
# "filter-none" is the fallback name passed to _get_function_name
# (presumably used when filter_fn is None -- confirm).
filter_fn_name = self._get_function_name(filter_fn, default="filter-none")
logger.debug("Computing hash for filter_fn: {} with kwargs {}".format(filter_fn_name, str(dict(**kwargs))))
# hash function source code
# NOTE(review): if filter_fn can be None (see default above), getsource()
# here and getclosurevars() below presumably raise -- confirm callers
# never pass None.
fn_source = str(dill.source.getsource(filter_fn))
pickled_fn_source = pickle.dumps(fn_source) ## encode as byte string
# SHA-1 of the pickled source, reduced modulo 10**11 to a short integer.
hashed_fn_source = int(hashlib.sha1(pickled_fn_source).hexdigest(), 16) % (10 ** 11)
# hash kwarg values
# ["default"] when no kwargs are given, otherwise "key-value" strings in
# key order. NOTE(review): despite its name, h is the raw kwarg value
# (formatted via str.format), not a hash. The list comprehension is
# evaluated only for its append side effect.
kw_dict = dict(**kwargs)
kw_hash = list()
if not kw_dict:
kw_hash = ["default"]
else:
[kw_hash.append("{}-{}".format(key, h)) for (key, h) in sorted(kw_dict.items())]
# hash closure vars - for functions captured in filter_fn's closure
closure = []
nonlocals = inspect.getclosurevars(filter_fn).nonlocals
for (key, val) in nonlocals.items():
## capture hash for any function within closure
# (recursive call, without kwargs; non-function values are skipped)
if inspect.isfunction(val):
closure.append(self._hash_filter_fn(val))
# NOTE(review): the rest of this method (combining hashed_fn_source,
# kw_hash and closure, and returning) is not visible in this excerpt.
# NOTE(review): this excerpt starts mid-function. `x` (the input sequence)
# and `missing` are presumably parameters of the enclosing def -- confirm.
import dill
# Normalize `missing` to a mapping of index -> value to insert:
# None -> {}; a str is evaluated as the body of a "{...}" literal;
# anything else is used as-is.
# SECURITY: eval() executes arbitrary code -- never pass an untrusted string.
if missing is None: _mask = {}
elif isinstance(missing, str): _mask = eval('{%s}' % missing)
else: _mask = missing
# raise KeyError if key out of bounds #XXX: also has *any* non-int object
# The seeds [0] and [-1] let an empty mask pass both checks.
first = min([0]+list(_mask.keys()))
if first < 0:
raise KeyError('invalid argument index: %s' % first)
# The largest index must be a valid position in the lengthened result,
# whose length is len(x)+len(_mask).
last = max([-1]+list(_mask.keys()))
if last > len(x)+len(_mask)-1:
raise KeyError('invalid argument index: %s' % last)
# preserve type(x)
# Default to type(x). If dill produces import code for x under the alias
# 'xtype', run it so that 'xtype' is rebound to the imported name.
_locals = {}
_locals['xtype'] = type(x)
code = "%s" % dill.source.getimport(x, alias='xtype')
if "import" in code:
code = compile(code, '', 'exec')
exec(code, _locals)
xtype = _locals['xtype']
# find the new indices due to the mask
# Inserting in ascending index order puts each value at its own index in
# the final sequence.
_x = list(x)
for (k,v) in sorted(_mask.items()):
_x.insert(k,v)
# get the new sequence
return xtype(_x)
"""
Render code object
"""
cobj = obj.__code__
global data
data["co"] = {
attr: getattr(cobj, attr)
for attr in dir(cobj)
if attr.startswith("co_")
}
data["co"]["co_code"] = data["co"]["co_code"].hex()
data["tpl_t"] = "CO"
data["ins"] = list(instructions)
(lines, start_line) = source.getsourcelines(obj)
src = "".join(lines)
tree = ast.parse(src, cobj.co_filename)
nodes = node_to_dict(tree, None)
data["nodes"] = dedupe_nodes(nodes)
data["src"] = src
data["last_line"] = start_line + len(lines)
start()
# Append the items of function_code joined by newline+tab, then a blank line.
code += '\n\t'.join(function_code)
code += '\n\n'
_recurse_code_update(module_tree)
# Generated class name: str.capitalize() of self.name (first char upper,
# rest lower), with '-' and '.' replaced by '_', suffixed with 'CLI'.
call_obj = self.name.capitalize().replace('-', '_').replace('.', '_') + 'CLI'
else:
# sys.ps1 is only defined when the interpreter is in interactive mode.
if hasattr(sys, 'ps1'):
# Running interactively
# NOTE(review): only ImportError is handled here. An exception raised by
# dill.source.getsource itself propagates.
try:
# dill works in interactive mode, inspect.getsource()
# doesn't
import dill
func_code = dill.source.getsource(self.module_or_function)
except ImportError:
# dill unavailable: fall back to inspect, and use '' if the source
# cannot be retrieved.
try:
func_code = inspect.getsource(self.module_or_function)
except OSError:
func_code = ''
else:
# Non-interactive: an OSError from inspect.getsource is not caught here.
func_code = inspect.getsource(self.module_or_function)
# Replace the 'ADDITIONAL_IMPORTS/:' placeholder with a newline followed by
# the joined extra import lines (the expression continues beyond this excerpt).
code = code.replace(
'ADDITIONAL_IMPORTS/:',
'\n' +
'\n'.join(
self.additional_imports) +
"""
Adds a forward model factory to the cost factory.
Inputs:
model -- a callable function factory object
inputs -- number of input arguments to model
name -- a string representing the model name
Example:
>>> import numpy as np
>>> C = CostFactory()
>>> C.addModel(np.poly, inputs=3)
"""
if name is None:
import dill
name = dill.source.getname(model)
if name is None:
for i in range(len(self._names)+1):
name = 'model'+str(i)
if name not in self._names: break
elif name in self._names:
print("Model [%s] already in database." % name)
raise AssertionError
self._names.append(name)
self._forwardFactories.append(model)
self._inputs.append(inputs)
self._outputFilters.append(outputFilter)
self._inputCheckers.append(inputChecker)
"""
warn_use_state = False
if FLAMBE_DIRECTORIES_KEY not in state_dict._metadata:
state_dict._metadata[FLAMBE_DIRECTORIES_KEY] = set()
warn_use_state = True
if KEEP_VARS_KEY not in state_dict._metadata:
state_dict._metadata[KEEP_VARS_KEY] = False
warn_use_state = True
if warn_use_state:
warn("Use '.get_state()' on flambe objects, not state_dict "
f"(from {type(self).__name__})")
# 1 need to add in any extras like config
local_metadata[VERSION_KEY] = self._flambe_version
local_metadata[FLAMBE_CLASS_KEY] = type(self).__name__
local_metadata[FLAMBE_SOURCE_KEY] = dill.source.getsource(type(self))
# All links should be relative to the current object `self`
with contextualized_linking(root_obj=self, prefix=prefix[:-1]):
try:
local_metadata[FLAMBE_CONFIG_KEY] = self._config_str
global _link_obj_stash
if len(_link_obj_stash) > 0:
local_metadata[FLAMBE_STASH_KEY] = copy.deepcopy(_link_obj_stash)
except AttributeError:
pass
# 2 need to recurse on Components
# Iterating over __dict__ does NOT include pytorch children
# modules, parameters or buffers
# torch.optim.Optimizer does exist so ignore mypy
for name, attr in self.__dict__.items():
if isinstance(attr, Component) and not isinstance(attr, (
torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler)): # type: ignore