Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Register the custom-call targets exported by jaxlib.cublas_kernels with
# XLA for the "gpu" platform. The extension is optional: when it cannot be
# imported, these targets are simply left unregistered.
try:
    from jaxlib import cublas_kernels
except ImportError:
    pass
else:
    # Only the import itself is guarded. An ImportError raised while
    # registering would indicate a genuine bug and must not be swallowed.
    for _name, _value in cublas_kernels.registrations().items():
        xla_client.register_custom_call_target(_name, _value, platform="gpu")
# Register the custom-call targets exported by jaxlib.cusolver_kernels with
# XLA for the "gpu" platform. The extension is optional: when it cannot be
# imported, these targets are simply left unregistered.
try:
    from jaxlib import cusolver_kernels
except ImportError:
    pass
else:
    # Only the import itself is guarded. An ImportError raised while
    # registering would indicate a genuine bug and must not be swallowed.
    for _name, _value in cusolver_kernels.registrations().items():
        xla_client.register_custom_call_target(_name, _value, platform="gpu")
# Module-private alias for ``xla_client.Shape``.
_Shape = xla_client.Shape
def _real_type(dtype):
    """Map a supported floating or complex dtype to its real counterpart.

    float32 and float64 map to themselves; complex64 maps to float32 and
    complex128 maps to float64.

    Raises:
      NotImplementedError: if ``dtype`` compares equal to none of float32,
        float64, complex64 or complex128.
    """
    # Matched in order with ``==`` rather than via a dict lookup, so NumPy's
    # equality rules decide the match (e.g. np.dtype('float32') == np.float32).
    candidates = (
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.complex64, np.float32),
        (np.complex128, np.float64),
    )
    for candidate, real in candidates:
        if dtype == candidate:
            return real
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
def _prod(xs):
    """Return the product of the elements of ``xs`` (1 for an empty iterable)."""
    # A named ``def`` instead of a lambda bound to a name (PEP 8, E731), so
    # tracebacks and reprs show ``_prod`` rather than ``<lambda>``.
    return reduce(operator.mul, xs, 1)