Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_threadpool_limits_manual_unregister():
    # threadpool_limits does not have to be used as a context manager: the
    # object it returns holds the original state of the threadpools, and
    # calling its unregister method restores that state.
    info_before = _threadpool_info()

    limiter = threadpool_limits(limits=1)
    try:
        # Inspect an info object built after the limit has been applied.
        for module in _threadpool_info():
            # Modules flagged by is_old_openblas are exempt from the check.
            if not is_old_openblas(module):
                assert module.num_threads == 1
    finally:
        # Undo the limit even when an assertion above failed, so that this
        # test leaves no side-effect behind.
        limiter.unregister()

    assert _threadpool_info() == info_before
def test_shipped_openblas():
    # Inside the context manager, every OpenBLAS module must report exactly
    # the number of threads that was requested.
    info_before = _threadpool_info()
    openblas_only = info_before.get_modules("internal_api", "openblas")

    with threadpool_limits(1):
        for openblas_module in openblas_only:
            assert openblas_module.get_num_threads() == 1

    # Leaving the context manager must put the original state back.
    assert info_before == _threadpool_info()
def _threadpool_info():
    # Counterpart of threadpool_info that hands back the _ThreadpoolInfo
    # object itself (built for all user APIs) instead of a list of dicts.
    info = _ThreadpoolInfo(user_api=_ALL_USER_APIS)
    return info
def test_command_line_command_flag():
    # Run `python -m threadpoolctl -c "import numpy"` in a subprocess, parse
    # its standard output as JSON and check that every module it reports is
    # also reported by threadpool_info() in the current process.
    import sys

    # importorskip skips the test when numpy is unavailable; otherwise it
    # imports numpy in this process too, which makes the comparison with the
    # subprocess meaningful.
    pytest.importorskip("numpy")

    # Use sys.executable rather than a bare "python": the latter may be
    # missing from PATH or resolve to a different interpreter / environment
    # than the one running this test suite.
    output = subprocess.check_output(
        [sys.executable, "-m", "threadpoolctl", "-c", "import numpy"])
    cli_info = json.loads(output.decode("utf-8"))

    this_process_info = threadpool_info()
    for module in cli_info:
        assert module in this_process_info
def test_threadpool_limits_public_api():
    # The public threadpool_info() and the private _ThreadpoolInfo object
    # must agree: pairing their entries in order, each public dict has to be
    # equal to the todict() output of the matching private module.
    # NOTE(review): zip stops at the shorter of the two sequences, so a
    # difference in length alone would not be detected here.
    dicts_from_public_api = threadpool_info()
    modules_from_private_api = _threadpool_info()

    pairs = zip(dicts_from_public_api, modules_from_private_api)
    for public_dict, private_module in pairs:
        assert public_dict == private_module.todict()
def test_command_line_import_flag():
    # Run `python -m threadpoolctl -i <modules...>` in a subprocess with a mix
    # of importable and non-importable module names. Check that the JSON
    # written to stdout is consistent with threadpool_info() in this process
    # and that a warning is written to stderr for each failed import.
    import sys

    # Use sys.executable rather than a bare "python": the latter may be
    # missing from PATH or resolve to a different interpreter / environment
    # than the one running this test suite.
    result = subprocess.run([
        sys.executable, "-m", "threadpoolctl", "-i",
        "numpy",
        "scipy.linalg",
        # The next two names are deliberately not importable.
        "invalid_package",
        "numpy.invalid_sumodule",
    ], capture_output=True, check=True, encoding="utf-8")
    cli_info = json.loads(result.stdout)

    this_process_info = threadpool_info()
    for module in cli_info:
        assert module in this_process_info

    # One stripped entry per stderr line; named so that it does not shadow
    # the standard library `warnings` module.
    warning_lines = [line.strip() for line in result.stderr.splitlines()]
    assert "WARNING: could not import invalid_package" in warning_lines
    assert "WARNING: could not import numpy.invalid_sumodule" in warning_lines

    # `scipy` is a module-level name defined outside this block; presumably
    # it is None when scipy could not be imported -- verify against the
    # top of the file.
    if scipy is None:
        assert "WARNING: could not import scipy.linalg" in warning_lines
    else:
        assert "WARNING: could not import scipy.linalg" not in warning_lines
def test_ThreadpoolInfo_todicts():
    # The dicts produced by .todict() / .todicts() must be consistent with
    # each other and with the public threadpool_info(), and must contain
    # every key expected by the public api.
    info = _threadpool_info()

    # Iterating over the info object or over its .modules attribute must
    # yield the same dicts as todicts() and as the public function.
    assert threadpool_info() == [mod.todict() for mod in info.modules]
    assert info.todicts() == [mod.todict() for mod in info]
    assert info.todicts() == [mod.todict() for mod in info.modules]

    expected_keys = (
        "user_api",
        "internal_api",
        "prefix",
        "filepath",
        "version",
        "num_threads",
    )
    for mod in info:
        module_dict = mod.todict()
        for key in expected_keys:
            assert key in module_dict

        # These internal apis additionally expose a threading layer entry.
        if mod.internal_api in ("mkl", "blis", "openblas"):
            assert "threading_layer" in module_dict
def test_threadpool_limits_bad_input():
    # Invalid arguments must be rejected with the appropriate exception type
    # and message.

    # Unknown user_api -> ValueError. The message embeds _ALL_USER_APIS, so
    # it is escaped before being used as a regular expression.
    message = "user_api must be either in {} or None.".format(_ALL_USER_APIS)
    user_api_pattern = re.escape(message)
    with pytest.raises(ValueError, match=user_api_pattern):
        threadpool_limits(limits=1, user_api="wrong")

    # A tuple is not an accepted type for limits -> TypeError.
    limits_pattern = "limits must either be an int, a list or a dict"
    with pytest.raises(TypeError, match=limits_pattern):
        threadpool_limits(limits=(1, 2, 3))
def test_set_threadpool_limits_by_api(user_api, limit):
    # The maximum number of threads can be set per user_api.
    # NOTE(review): user_api and limit have to be supplied by pytest
    # (parametrization or fixtures); that part is not visible in this block.
    info_before = _threadpool_info()

    targeted_modules = info_before.get_modules("user_api", user_api)
    if not targeted_modules:
        if user_api is None:
            user_apis = _ALL_USER_APIS
        else:
            user_apis = [user_api]
        pytest.skip("Requires a library which api is in {}".format(user_apis))

    with threadpool_limits(limits=limit, user_api=user_api):
        for targeted in targeted_modules:
            # Modules flagged by is_old_openblas are exempt from the check.
            if not is_old_openblas(targeted):
                # The limit is only an upper bound: any positive number of
                # threads not exceeding it is acceptable.
                assert 0 < targeted.get_num_threads() <= limit

    # Leaving the context manager must put the original state back.
    assert _threadpool_info() == info_before
@pytest.mark.parametrize("prefix", _ALL_PREFIXES)
@pytest.mark.parametrize("limit", [1, 3])
def test_threadpool_limits_by_prefix(prefix, limit):
    # The maximum number of threads can be set per library prefix, by
    # passing a {prefix: limit} dict to threadpool_limits.
    info_before = _threadpool_info()

    targeted_modules = info_before.get_modules("prefix", prefix)
    if not targeted_modules:
        pytest.skip("Requires {} runtime".format(prefix))

    with threadpool_limits(limits={prefix: limit}):
        for targeted in targeted_modules:
            # Modules flagged by is_old_openblas are exempt from the check.
            if not is_old_openblas(targeted):
                # The limit is only an upper bound: any positive number of
                # threads not exceeding it is acceptable.
                assert 0 < targeted.get_num_threads() <= limit