Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_dtype():
    """Check that `B.dtype` reports the element type of each matrix type.

    Every matrix type is built once from integer data and once from float
    data, and the reported dtype must follow the data.
    """
    # Test `Dense`.
    assert B.dtype(Dense(np.array([[1]]))) == np.int64
    assert B.dtype(Dense(np.array([[1.0]]))) == np.float64

    # Test `Diagonal`.
    diag_int = Diagonal(np.array([1]))
    diag_float = Diagonal(np.array([1.0]))
    assert B.dtype(diag_int) == np.int64
    assert B.dtype(diag_float) == np.float64

    # Test `LowRank`.
    lr_int = LowRank(left=np.array([[1]]),
                     right=np.array([[2]]),
                     middle=np.array([[3]]))
    lr_float = LowRank(left=np.array([[1.0]]),
                       right=np.array([[2.0]]),
                       middle=np.array([[3.0]]))
    assert B.dtype(lr_int) == np.int64
    assert B.dtype(lr_float) == np.float64

    # Test `Constant`: built from a Python scalar, so the expected dtype is
    # the scalar's Python type rather than a NumPy type.
    assert B.dtype(Constant(1, rows=1)) == int
    assert B.dtype(Constant(1.0, rows=1)) == float

    # Test `Woodbury`, composed from the `Diagonal` and `LowRank` matrices
    # constructed above.
    assert B.dtype(Woodbury(diag_int, lr_int)) == np.int64
    assert B.dtype(Woodbury(diag_float, lr_float)) == np.float64
def sum(a, axis=None):
    """Sum the elements of `a`, optionally along an axis.

    Only the entries returned by `B.diag(a)` contribute to the fast paths,
    so this implementation treats `a` as a diagonal matrix.

    Args:
        a (matrix): Matrix to sum.
        axis (int, optional): Axis to sum over. `None` sums all elements.

    Returns:
        tensor: A scalar if `axis` is `None`; for `axis` 0 or 1, a vector
            holding the diagonal followed by zeros up to the length of the
            remaining dimension.
    """
    # Efficiently handle a number of common cases.
    if axis is None:
        return B.sum(B.diag(a))

    # Compare with `==`, not `is`: identity checks on int literals are
    # unreliable (e.g. NumPy integer axes) and warn on Python 3.8+.
    if axis == 0:
        # Summing over rows leaves one entry per column.
        out_len = B.shape(a)[1]
    elif axis == 1:
        # Summing over columns leaves one entry per row.
        out_len = B.shape(a)[0]
    else:
        # Fall back to generic implementation.
        return B.sum.invoke(Dense)(a, axis=axis)

    # Positions beyond the diagonal receive only zeros.
    padding = B.zeros(B.dtype(a), out_len - B.diag_len(a))
    return B.concat(B.diag(a), padding, axis=0)
def diag(a):
    """Return the diagonal of `a`, resized to length `B.diag_len(a)`.

    The stored `a.diag` is cut off when it is longer than that length and
    padded with zeros of dtype `B.dtype(a)` when it is shorter.
    """
    target_len = B.diag_len(a)
    stored = a.diag
    # How many zeros must be appended; clamped so it is never negative.
    n_pad = B.maximum(target_len - B.shape(stored)[0], 0)
    head = stored[:target_len]
    padding = B.zeros(B.dtype(a), n_pad)
    return B.concat(head, padding, axis=0)
def elwise(self, x, y):
    """Return a zero column vector with one entry per row of `x`.

    `y` is accepted for interface compatibility but not used. The zeros
    take the dtype of `x`, and the row count is read after `uprank(x)`.
    """
    dtype = B.dtype(x)
    num_rows = B.shape(uprank(x))[0]
    return B.zeros(dtype, num_rows, 1)
def perturb(x):
    """Nudge a tensor by a tiny relative amount plus a tiny offset.

    Args:
        x (tensor): Tensor to perturb.

    Returns:
        tensor: `1e-20 + x * (1 + eps)`, where `eps` is `1e-14` for
            64-bit floats and `1e-7` for 32-bit floats.

    Raises:
        ValueError: If `x` is neither a 64-bit nor a 32-bit float tensor.
    """
    np_dtype = convert(B.dtype(x), B.NPDType)
    # The relative step is matched to the precision of the data type.
    if np_dtype == np.float64:
        eps = 1e-14
    elif np_dtype == np.float32:
        eps = 1e-7
    else:
        raise ValueError('Cannot perturb a tensor of data type {}.'
                         ''.format(B.dtype(x)))
    return 1e-20 + x * (1 + eps)
def elwise(self, x, y):
    """Compare two wrapped inputs by object identity.

    Both arguments are read through `.w` and `.get()`. If the wrapped
    values are the very same object, `B.uprank(1 / x.w)` is returned.
    Otherwise the result is a zero column with one entry per row of the
    wrapped `x` value.
    """
    # `y.w` is read (though unused) to keep attribute access identical.
    weight_x, _ = x.w, y.w
    val_x, val_y = x.get(), y.get()
    if val_x is not val_y:
        return B.zeros(B.dtype(val_x), B.shape(uprank(val_x))[0], 1)
    return B.uprank(1 / weight_x)
# NOTE(review): the enclosing `def` line of this body is not visible here;
# `a`, `b` and `tr_b` are presumably parameters of that function — confirm.
# Multiplies each column j of `dense(a)` by the j-th entry of `B.diag(b)`,
# then pads with zero columns up to the column count of `b`. This equals
# `a @ b` when `b` is diagonal (presumably the case, since only
# `B.diag(b)` is read — verify against the dispatch signature).
# Optionally work with the transpose of `b`.
b = B.transpose(b) if tr_b else b
# Get shape of `b`.
b_rows, b_cols = B.shape(b)
# If `b` is square, don't do complicated things: no truncation or padding
# is needed. The fast path is skipped when the row count is `None`.
if b_rows == b_cols and b_rows is not None:
return dense(a) * B.diag(b)[None, :]
# Compute the core part: only the first `cols` columns of `a` are paired
# with diagonal entries of `b`.
cols = B.minimum(B.shape(a)[1], b_cols)
core = dense(a)[:, :cols] * B.diag(b)[None, :cols]
# Compute extra zeros to be appended for the remaining columns of `b`,
# using the dtype of `b`.
extra_cols = b_cols - cols
extra_zeros = B.zeros(B.dtype(b), B.shape(a)[0], extra_cols)
return B.concat(core, extra_zeros, axis=1)
def diag(diag, rows, cols=None):
cols = rows if cols is None else cols
# Cut the diagonal to accommodate the size.
diag = diag[:B.minimum(rows, cols)]
diag_len, dtype = B.shape(diag)[0], B.dtype(diag)
# PyTorch incorrectly handles dimensions of size 0. Therefore, if the
# numbers of extra columns and rows are `Number`s, which will be the case if
# PyTorch is the backend, then perform a check to prevent appending tensors
# with dimensions of size 0.
# Start with just a diagonal matrix.
res = B.diag(diag)
# Pad extra columns if necessary.
extra_cols = cols - diag_len
if not (isinstance(extra_cols, Number) and extra_cols == 0):
zeros = B.zeros(dtype, diag_len, extra_cols)
res = B.concat(B.diag(diag), zeros, axis=1)
# Pad extra rows if necessary.