# Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_diag():
    # `Dense`: the extracted diagonal must agree with NumPy's.
    mat = np.random.randn(5, 3)
    allclose(B.diag(Dense(mat)), np.diag(mat))

    # `Diagonal`: without a size the diagonal comes back unchanged; with an
    # explicit size it is truncated or zero-padded to that size.
    allclose(B.diag(Diagonal(np.array([1, 2, 3]))), [1, 2, 3])
    for size, expected in ((2, [1, 2]), (4, [1, 2, 3, 0])):
        allclose(B.diag(Diagonal(np.array([1, 2, 3]), size)), expected)

    # `LowRank`: the diagonal of left @ right.T, including factors with
    # different numbers of rows.
    tall = np.random.randn(10, 3)
    for left, right in ((mat, mat), (mat, tall), (tall, tall)):
        allclose(B.diag(LowRank(left=left, right=right)),
                 np.diag(left.dot(right.T)))

    # `Constant`: a non-square constant matrix has min(rows, cols) ones on
    # its diagonal.
    allclose(B.diag(Constant(1, rows=3, cols=5)), np.ones(3))

    # `Woodbury`: conditioning on exact observations should reproduce them
    # in the posterior mean and leave (numerically) zero posterior variance.
    proc = GP(EQ(), graph=model)  # Process over one-dimensional inputs.
    proc_first = proc.select(0)  # Same process, reading input column 0.
    num = 5
    xs = np.linspace(0, 10, num)[:, None]
    xs_aug_1 = np.concatenate((xs, np.random.randn(num, 1)), axis=1)
    xs_aug_2 = np.concatenate((xs, np.random.randn(num, 1)), axis=1)
    obs = proc_first(xs).sample()

    # Observe through the column-selecting process, predict with the base one.
    for xs_aug in (xs_aug_1, xs_aug_2):
        post = proc.condition(proc_first(xs_aug), obs)
        allclose(post(xs).mean, obs)
        assert abs_err(B.diag(post(xs).var)) <= 1e-10

    # Observe through the base process, predict with the column-selecting one.
    post = proc_first.condition(proc(xs), obs)
    for xs_aug in (xs_aug_1, xs_aug_2):
        allclose(post(xs_aug).mean, obs)
    for xs_aug in (xs_aug_1, xs_aug_2):
        assert abs_err(B.diag(post(xs_aug).var)) <= 1e-10
def dense(a):
    """Materialise `a` as a dense matrix from its diagonal and shape.

    The diagonal of `a` is placed on the diagonal of a new matrix whose
    shape is unpacked from ``B.shape(a)``; all other entries come out as
    produced by ``B.diag(diag, rows, cols)``.
    """
    diagonal = B.diag(a)
    shape = B.shape(a)
    return B.diag(diagonal, *shape)
def marginals(self):
    """Get the marginals.

    Returns:
        tuple: ``(mean, lower, upper)``: the squeezed predictive means and the
            bounds ``mean - 2 * std`` and ``mean + 2 * std``, used as the
            (approximate) 95% central credible interval.
    """
    mean = B.squeeze(self.mean)
    if self.p is not None:
        # Evaluate the kernel elementwise at `self.x`; presumably this yields
        # only the marginal variances, avoiding the full covariance.
        variances = B.squeeze(B.dense(self.p.kernel.elwise(self.x)))
    else:
        variances = B.diag(self.var)
    half_width = 2 * B.sqrt(variances)
    lower = mean - half_width
    upper = mean + half_width
    return mean, lower, upper
def diag(a):
    """Diagonal of a sum-structured matrix.

    The result is the diagonal of the ``a.lr`` component plus the diagonal
    of the ``a.diag`` component.
    """
    low_rank_part = B.diag(a.lr)
    diagonal_part = B.diag(a.diag)
    return low_rank_part + diagonal_part
def diag(diag, rows, cols=None):
    """Construct a `rows` by `cols` matrix with `diag` on its diagonal.

    Args:
        diag (tensor): Diagonal entries. Entries beyond ``min(rows, cols)``
            are dropped.
        rows (number or scalar tensor): Number of rows.
        cols (number or scalar tensor, optional): Number of columns.
            Defaults to `rows`.

    Returns:
        tensor: Matrix with `diag` on its diagonal, padded with zero columns
            and zero rows up to the requested shape.
    """
    cols = rows if cols is None else cols

    # Cut the diagonal to accommodate the size.
    diag = diag[:B.minimum(rows, cols)]
    diag_len, dtype = B.shape(diag)[0], B.dtype(diag)

    # Start with just a diagonal matrix of side `diag_len`.
    res = B.diag(diag)

    # PyTorch incorrectly handles dimensions of size 0. Therefore, if the
    # numbers of extra columns and rows are `Number`s, which will be the case
    # if PyTorch is the backend, skip the padding step when nothing needs to
    # be added, so that no tensors with dimensions of size 0 are appended.

    # Pad extra columns if necessary.
    extra_cols = cols - diag_len
    if not (isinstance(extra_cols, Number) and extra_cols == 0):
        zeros = B.zeros(dtype, diag_len, extra_cols)
        # Reuse `res` rather than rebuilding the diagonal matrix.
        res = B.concat(res, zeros, axis=1)

    # Pad extra rows if necessary; they span the already-padded width.
    extra_rows = rows - diag_len
    if not (isinstance(extra_rows, Number) and extra_rows == 0):
        zeros = B.zeros(dtype, extra_rows, diag_len + extra_cols)
        res = B.concat(res, zeros, axis=0)

    return res
def matmul(a, b, tr_a=False, tr_b=False):
    """Multiply two diagonal-structured matrices into a `Diagonal`.

    Presumably both operands are `Diagonal`s (the dispatch signature is not
    visible here). Either operand is transposed first when its flag is set.
    The product's diagonal is the elementwise product of the overlapping
    parts of the two diagonals. Its shape is ``(rows of a, cols of b)``.
    """
    if tr_a:
        a = B.transpose(a)
    if tr_b:
        b = B.transpose(b)
    # Only the overlapping prefix of the two diagonals contributes.
    overlap = B.minimum(B.diag_len(a), B.diag_len(b))
    product = B.diag(a)[:overlap] * B.diag(b)[:overlap]
    return Diagonal(product, rows=B.shape(a)[0], cols=B.shape(b)[1])
def sum(a, axis=None):
    """Sum a diagonal matrix, either entirely or along one axis.

    Args:
        a (matrix): Matrix whose only non-zero entries are on its diagonal.
        axis (int, optional): Axis to sum along. `None` sums all entries.
            Negative axes are interpreted relative to the two matrix
            dimensions.

    Returns:
        tensor: Scalar total if `axis` is `None`, otherwise a vector with one
            entry per remaining column (``axis=0``) or row (``axis=1``).
    """
    # Efficiently handle a number of common cases. Compare axes by value with
    # `==`/`in`, not by identity with `is`. Integer identity is a CPython
    # implementation detail: `is` with a literal triggers a SyntaxWarning on
    # Python 3.8+, and NumPy integer axes would never match.
    if axis is None:
        # Only the diagonal can be non-zero.
        return B.sum(B.diag(a))
    elif axis in (0, -2):
        # Column sums: the diagonal, then zeros for the remaining columns.
        return B.concat(B.diag(a),
                        B.zeros(B.dtype(a), B.shape(a)[1] - B.diag_len(a)),
                        axis=0)
    elif axis in (1, -1):
        # Row sums: the diagonal, then zeros for the remaining rows.
        return B.concat(B.diag(a),
                        B.zeros(B.dtype(a), B.shape(a)[0] - B.diag_len(a)),
                        axis=0)
    else:
        # Fall back to generic implementation.
        return B.sum.invoke(Dense)(a, axis=axis)