# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_tensor_is_attached():
    """Check that a point sampled from a plain Euclidean manifold is attached to it.

    NOTE(review): ``test_tensor_is_attached`` is defined a second time later in
    this module. The later ``def`` rebinds the name, so pytest only collects
    that definition and this body never runs. Consider giving one of them a
    distinct name.
    """
    manifold = geoopt.Euclidean()
    assert manifold.is_attached(manifold.random(()))
def test_ismanifold():
    """Check that ``geoopt.ismanifold`` sees through ``Scaled`` wrappers and rejects bad arguments.

    A Euclidean manifold, and the same manifold wrapped twice in ``Scaled``,
    must both be recognised as ``geoopt.Euclidean``. Passing a non-manifold
    class (``int``) or a non-class (``1``) as the second argument must raise
    ``TypeError``. A non-manifold first argument simply yields ``False``.

    NOTE(review): an identical ``test_ismanifold`` is defined later in this
    module. Being the later binding, that one shadows this definition and is
    the only one pytest collects.
    """
    manifold = geoopt.Euclidean()
    assert geoopt.ismanifold(manifold, geoopt.Euclidean)

    # Stack two Scaled wrappers around the base manifold.
    for _ in range(2):
        manifold = geoopt.Scaled(manifold)
    assert geoopt.ismanifold(manifold, geoopt.Euclidean)

    # Invalid "class" arguments: a non-manifold type, then a non-type value.
    for bad_cls in (int, 1):
        with pytest.raises(TypeError):
            geoopt.ismanifold(manifold, bad_cls)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
# NOTE(review): the lines up to the next closing parenthesis appear to be the
# tail of a call matching the ``yield UnaryCase(...)`` below; its opening and
# the enclosing generator are not visible here.
manifold_shapes[geoopt.ProductManifold],
product_manifold.pack_point(*x),
product_manifold.pack_point(*ex),
product_manifold.pack_point(*v),
product_manifold.pack_point(*ev),
product_manifold,
)
# + 1 case without stiefel
# Second ProductManifold case: the same construction, with the Stiefel factor
# left out (it is commented out in ``manifolds`` below).
torch.manual_seed(42)
# Raw ambient samples, one per factor: a length-10 vector, a length-3 vector
# scaled by 1/10 (presumably to start close to the origin of the Poincare
# ball -- TODO confirm), and a 0-d scalar.
ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
# Raw ambient vectors with the same per-factor shapes as ``ex``; they are
# projected onto tangent spaces below.
ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
manifolds = [
geoopt.Sphere(),
geoopt.PoincareBall(),
# geoopt.Stiefel(),
geoopt.Euclidean(),
]
# Project each raw sample onto its factor manifold ...
x = [manifolds[i].projx(ex[i]) for i in range(len(manifolds))]
# ... and each raw vector onto the tangent space at the projected point.
v = [manifolds[i].proju(x[i], ev[i]) for i in range(len(manifolds))]
# Build the product from (manifold, shape) pairs, taking each shape from ``ex``.
product_manifold = geoopt.ProductManifold(
*((manifolds[i], ex[i].shape) for i in range(len(ex)))
)
# Positional arguments, in order: the shape spec for ProductManifold, the
# packed projected point, the packed raw point, the packed projected tangent,
# the packed raw vector, and the manifold itself.
yield UnaryCase(
manifold_shapes[geoopt.ProductManifold],
product_manifold.pack_point(*x),
product_manifold.pack_point(*ex),
product_manifold.pack_point(*v),
product_manifold.pack_point(*ev),
product_manifold,
)
def test_tensor_is_attached():
    """Check that sampling from a doubly ``Scaled`` Euclidean manifold gives an attached point.

    NOTE(review): this redefines an earlier ``test_tensor_is_attached`` in the
    same module. As the later binding it is the one pytest collects, and the
    earlier (unwrapped Euclidean) variant is silently skipped.
    """
    manifold = geoopt.Euclidean()
    # Wrap the base manifold in Scaled twice.
    for _ in range(2):
        manifold = geoopt.Scaled(manifold)
    sample = manifold.random(())
    assert manifold.is_attached(sample)
def test_ismanifold():
    """Check that ``geoopt.ismanifold`` sees through ``Scaled`` wrappers and rejects bad arguments.

    A Euclidean manifold, and the same manifold wrapped twice in ``Scaled``,
    must both be recognised as ``geoopt.Euclidean``. Passing a non-manifold
    class (``int``) or a non-class (``1``) as the second argument must raise
    ``TypeError``. A non-manifold first argument simply yields ``False``.

    NOTE(review): this body duplicates an earlier ``test_ismanifold`` in this
    module. Being the later definition, this is the one pytest collects.
    """
    manifold = geoopt.Euclidean()
    assert geoopt.ismanifold(manifold, geoopt.Euclidean)

    # Stack two Scaled wrappers around the base manifold.
    for _ in range(2):
        manifold = geoopt.Scaled(manifold)
    assert geoopt.ismanifold(manifold, geoopt.Euclidean)

    # Invalid "class" arguments: a non-manifold type, then a non-type value.
    for bad_cls in (int, 1):
        with pytest.raises(TypeError):
            geoopt.ismanifold(manifold, bad_cls)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
def test_inner_product():
    """Check the output shape of ``ProductManifold.inner`` for a batch of packed points.

    A batch of 5 points on Sphere(10) x Sphere((3, 2)) x Euclidean(()) is
    packed. A random ambient vector is projected onto the tangent space at
    those points. The inner product must then have one value per batch
    element: shape ``(5,)``, or ``(5, 1)`` with ``keepdim=True``.
    """
    pman = ProductManifold((Sphere(), 10), (Sphere(), (3, 2)), (Euclidean(), ()))
    parts = (
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    )
    packed = pman.pack_point(*parts)
    tangent = pman.proju(packed, torch.randn_like(packed))

    assert pman.inner(packed, tangent).shape == (5,)
    assert pman.inner(packed, tangent, keepdim=True).shape == (5, 1)
def test_product():
    """Check that random samples from a four-factor ``ProductManifold`` pass the on-manifold check."""
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    manifold = geoopt.ProductManifold(*factors)
    # Requested shape is (20, n_elements): a batch of 20 packed points.
    sample = manifold.random(20, manifold.n_elements)
    manifold.assert_check_point_on_manifold(sample)
def test_from_point_checks_shapes():
    """Check that ``ProductManifold.from_point`` infers part shapes and validates batch dims.

    Without ``batch_dims``, ``n_elements`` must equal the total number of
    elements across all three parts. With ``batch_dims=1``, the parts' leading
    sizes (5, 3 and 5) disagree, and construction must raise ``ValueError``
    with the batch-shape message.
    """
    sphere_a = Sphere().random_uniform(5, 10)
    sphere_b = Sphere().random_uniform(3, 3, 2)
    euclid = Euclidean().random_normal(5)

    pman = ProductManifold.from_point(sphere_a, sphere_b, euclid)
    assert pman.n_elements == (5 * 10 + 3 * 3 * 2 + 5 * 1)

    with pytest.raises(ValueError, match="Not all parts have same batch shape"):
        ProductManifold.from_point(sphere_a, sphere_b, euclid, batch_dims=1)
def test_fails_Euclidean():
    """Check that ``random_normal(())`` on ``Euclidean(ndim=1)`` raises ``ValueError``.

    The manifold is constructed outside ``pytest.raises`` so that only the
    sampling call is expected to raise. Previously a ``ValueError`` from the
    constructor itself would also have satisfied the check, letting the test
    pass for the wrong reason.
    """
    manifold = geoopt.Euclidean(ndim=1)
    with pytest.raises(ValueError):
        # Presumably rejected because shape () has fewer than ndim=1
        # dimensions -- TODO confirm against Euclidean's shape check.
        manifold.random_normal(())