Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_adam_poincare():
    """Check that RiemannianAdam moves a Poincare-ball parameter onto a target.

    A seeded random vector is passed through ``expmap0`` (``c=1.0``) and
    wrapped as a ``ManifoldParameter`` on ``PoincareBall``. The optimizer then
    minimises the squared ``dist`` to the fixed point ``[0.5, 0.5]`` for 2000
    closure-driven steps, after which the parameter must match the target
    within ``1e-5`` absolute and relative tolerance.
    """
    torch.manual_seed(44)
    ideal = torch.tensor([0.5, 0.5])
    initial = geoopt.manifolds.poincare.math.expmap0(torch.randn(2) / 2, c=1.0)
    start = geoopt.ManifoldParameter(initial, manifold=geoopt.PoincareBall())
    optim = geoopt.optim.RiemannianAdam([start], lr=1e-2)

    def closure():
        # Re-evaluate the objective from scratch on every optimizer step.
        optim.zero_grad()
        squared_dist = geoopt.manifolds.poincare.math.dist(start, ideal) ** 2
        squared_dist.backward()
        return squared_dist.item()

    for _ in range(2000):
        optim.step(closure)
    np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)
geoopt.manifolds.SphereSubspaceIntersection: functools.partial(
pymanopt.manifolds.SphereSubspaceIntersection,
U=np.random.RandomState(42).randn(10, 3),
),
geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial(
pymanopt.manifolds.SphereSubspaceComplementIntersection,
U=np.random.RandomState(42).randn(10, 3),
),
}
# Point shape used for each manifold type when verifying the unary element
# implementation. Keys are geoopt manifold classes; the unary_case fixture
# looks an entry up with shapes[type(manifold)].
shapes = {
    geoopt.manifolds.PoincareBall: (3,),
    geoopt.manifolds.EuclideanStiefel: (10, 5),
    geoopt.manifolds.CanonicalStiefel: (10, 5),
    geoopt.manifolds.Euclidean: (1,),
    geoopt.manifolds.Sphere: (10,),
    # The leading dimension of 10 matches the (10, 3) matrices drawn from
    # RandomState(42) for the subspace manifolds elsewhere in this file.
    geoopt.manifolds.SphereSubspaceIntersection: (10,),
    geoopt.manifolds.SphereSubspaceComplementIntersection: (10,),
}
# Record describing one unary test scenario. Field order is part of the
# interface because instances are built positionally.
# NOTE(review): x/v look like the geoopt-side point and tangent and ex/ev like
# their reference counterparts, with manopt_manifold the pymanopt reference
# (test_transport skips when it is None) - confirm against the unary_case
# fixture that fills them.
UnaryCase = collections.namedtuple(
    "UnaryCase",
    ["shape", "x", "ex", "v", "ev", "manifold", "manopt_manifold"],
)
@pytest.fixture()
def unary_case(manifold):
shape = shapes[type(manifold)]
np.random.seed(42)
torch.manual_seed(43)
if type(manifold) in mannopt:
def poincare_case():
    """Yield ``UnaryCase`` records for ``PoincareBall`` and ``PoincareBallExact``.

    Both cases share one seeded float64 sample: a random vector is rescaled to
    have norm ``tanh`` of its original norm, and a second random vector is
    used as ``v``/``ev``. The first case wraps the point as a
    ``ManifoldTensor`` on ``PoincareBall``; the second re-wraps that tensor on
    ``PoincareBallExact``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    raw = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    raw_norm = torch.norm(raw)
    point = torch.tanh(raw_norm) * raw / raw_norm
    # Keep plain-tensor copies alongside the manifold-wrapped point.
    ex = point.clone()
    v = ev.clone()

    manifold = geoopt.PoincareBall().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(point, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)

    manifold = geoopt.PoincareBallExact().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
def test_transport(unary_case, t):
    """Compare geoopt's ``transp`` result with pymanopt's ``transp``.

    The test is skipped when the case carries no pymanopt manifold, and for
    ``CanonicalStiefel``. Otherwise the point is retracted along ``v`` with
    step ``t``, ``v`` is transported the same way, and the result must match
    pymanopt's transport from ``x`` to the retracted point.
    """
    manifold = unary_case.manifold
    reference = unary_case.manopt_manifold
    if reference is None:
        pytest.skip("pymanopt does not have {}".format(manifold))
    if isinstance(manifold, geoopt.manifolds.CanonicalStiefel):
        pytest.skip("pymanopt uses euclidean Stiefel")

    x, v = unary_case.x, unary_case.v
    y = x.retr(v, t=t)
    u = x.transp(v, u=v, t=t)
    u_star = reference.transp(x.numpy(), y.numpy(), v.numpy())
    np.testing.assert_allclose(u, u_star)
def euclidean_stiefel_case():
    """Yield ``UnaryCase`` records for ``EuclideanStiefel`` and its exact variant.

    A seeded float64 matrix ``ex`` is turned into the point ``u @ v.t()`` from
    its SVD factors. The tangent ``v`` is ``ev`` minus ``x`` times the
    symmetric part of ``x.t() @ ev``. The same data is yielded first with
    ``EuclideanStiefel`` and then with ``EuclideanStiefelExact``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.EuclideanStiefel]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    left, _, right = torch.svd(ex)
    point = left @ right.t()
    nonsym = point.t() @ ev
    symmetrized = nonsym + nonsym.t()
    v = ev - (point @ symmetrized) / 2

    manifold = geoopt.manifolds.EuclideanStiefel()
    x = geoopt.ManifoldTensor(point, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)

    manifold = geoopt.manifolds.EuclideanStiefelExact()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
functools.partial(geoopt.manifolds.Stiefel, canonical=False),
functools.partial(geoopt.manifolds.Stiefel, canonical=True),
geoopt.manifolds.PoincareBall,
geoopt.manifolds.Euclidean,
geoopt.manifolds.Sphere,
functools.partial(
geoopt.manifolds.SphereSubspaceIntersection,
torch.from_numpy(np.random.RandomState(42).randn(10, 3)),
),
functools.partial(
geoopt.manifolds.SphereSubspaceComplementIntersection,
torch.from_numpy(np.random.RandomState(42).randn(10, 3)),
),
],
)
def manifold(request, retraction_order):
man = request.param()
def euclidean_stiefel_case():
    """Yield ``UnaryCase`` records for ``EuclideanStiefel`` and its exact variant.

    A seeded float64 matrix ``ex`` is turned into the point ``u @ v.t()`` from
    its SVD factors. The tangent ``v`` is ``ev`` minus ``x`` times the
    symmetric part of ``x.t() @ ev``. The same data is yielded first with
    ``EuclideanStiefel`` and then with ``EuclideanStiefelExact``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.EuclideanStiefel]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    left, _, right = torch.svd(ex)
    point = left @ right.t()
    nonsym = point.t() @ ev
    symmetrized = nonsym + nonsym.t()
    v = ev - (point @ symmetrized) / 2

    manifold = geoopt.manifolds.EuclideanStiefel()
    x = geoopt.ManifoldTensor(point, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)

    manifold = geoopt.manifolds.EuclideanStiefelExact()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
@pytest.fixture(autouse=True, params=[1, 2, 3, 4, 5])
def seed(request):
    """Seed torch's RNG with the current parameter before each test.

    The fixture is autouse and parametrized over five seeds, so each test
    runs once per seed value.
    """
    rng_seed = request.param
    torch.manual_seed(rng_seed)
    yield
@pytest.fixture(autouse=True, params=[torch.float64], ids=str)
def use_floatX(request):
    """Switch torch's default dtype to the parametrized one for the test.

    Yields the dtype in effect and restores the previous default dtype once
    the test has finished.
    """
    previous_default = torch.get_default_dtype()
    active_dtype = request.param
    torch.set_default_dtype(active_dtype)
    yield active_dtype
    # Teardown: restore whatever default was active before this test.
    torch.set_default_dtype(previous_default)
# Sample shape for each manifold type, keyed by geoopt manifold class.
# The *_case generators in this file read it as manifold_shapes[<class>] and
# pass the tuple to torch.randn(*shape).
manifold_shapes = {
    geoopt.manifolds.PoincareBall: (3,),
    geoopt.manifolds.EuclideanStiefel: (10, 5),
    geoopt.manifolds.CanonicalStiefel: (10, 5),
    geoopt.manifolds.Euclidean: (10,),
    geoopt.manifolds.Sphere: (10,),
    geoopt.manifolds.SphereExact: (10,),
    # NOTE(review): the unevaluated sum presumably lists the sizes of the
    # product's component manifolds - confirm against the product-manifold case.
    geoopt.manifolds.ProductManifold: (10 + 3 + 6 + 1,),
}
# Record yielded by the *_case generators. Instances are built positionally,
# so the field order is part of the interface. In the generators visible here
# x is a geoopt.ManifoldTensor, ex a plain clone of the same point, and v a
# tangent derived from ev.
UnaryCase = collections.namedtuple(
    "UnaryCase",
    ["shape", "x", "ex", "v", "ev", "manifold"],
)
def canonical_stiefel_case():
torch.manual_seed(42)
shape = manifold_shapes[geoopt.manifolds.CanonicalStiefel]
ex = torch.randn(*shape)
],
)
def manifold(request, retraction_order):
    """Build the parametrized manifold configured for ``retraction_order``.

    ``request.param`` is called to create the manifold, then
    ``set_default_order(retraction_order)`` and ``.double()`` are applied and
    the result is returned. If that chain raises ``ValueError`` the test is
    skipped instead.
    """
    man = request.param()
    try:
        configured = man.set_default_order(retraction_order).double()
    except ValueError:
        pytest.skip("not supported retraction order for {}".format(man))
    else:
        return configured
# Maps a geoopt manifold class to the pymanopt callable paired with it.
# The unary_case fixture checks membership with `type(manifold) in mannopt`.
# Both Stiefel variants map to the single pymanopt Stiefel class.
mannopt = {
    geoopt.manifolds.EuclideanStiefel: pymanopt.manifolds.Stiefel,
    geoopt.manifolds.CanonicalStiefel: pymanopt.manifolds.Stiefel,
    geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean,
    geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere,
    # The subspace manifolds take a matrix U, pre-bound here to a (10, 3)
    # sample from a RandomState seeded with 42. Each entry creates its own
    # RandomState, so both receive the same values.
    geoopt.manifolds.SphereSubspaceIntersection: functools.partial(
        pymanopt.manifolds.SphereSubspaceIntersection,
        U=np.random.RandomState(42).randn(10, 3),
    ),
    geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial(
        pymanopt.manifolds.SphereSubspaceComplementIntersection,
        U=np.random.RandomState(42).randn(10, 3),
    ),
}
# shapes to verify unary element implementation
shapes = {
geoopt.manifolds.PoincareBall: (3,),
geoopt.manifolds.EuclideanStiefel: (10, 5),
geoopt.manifolds.CanonicalStiefel: (10, 5),
geoopt.manifolds.Euclidean: (1,),
geoopt.manifolds.Sphere: (10,),
geoopt.manifolds.Euclidean: pymanopt.manifolds.Euclidean,
geoopt.manifolds.Sphere: pymanopt.manifolds.Sphere,
geoopt.manifolds.SphereSubspaceIntersection: functools.partial(
pymanopt.manifolds.SphereSubspaceIntersection,
U=np.random.RandomState(42).randn(10, 3),
),
geoopt.manifolds.SphereSubspaceComplementIntersection: functools.partial(
pymanopt.manifolds.SphereSubspaceComplementIntersection,
U=np.random.RandomState(42).randn(10, 3),
),
}
# Point shape used for each manifold type when verifying the unary element
# implementation. Keys are geoopt manifold classes; the unary_case fixture
# looks an entry up with shapes[type(manifold)].
shapes = {
    geoopt.manifolds.PoincareBall: (3,),
    geoopt.manifolds.EuclideanStiefel: (10, 5),
    geoopt.manifolds.CanonicalStiefel: (10, 5),
    geoopt.manifolds.Euclidean: (1,),
    geoopt.manifolds.Sphere: (10,),
    # The leading dimension of 10 matches the (10, 3) matrices drawn from
    # RandomState(42) for the subspace manifolds elsewhere in this file.
    geoopt.manifolds.SphereSubspaceIntersection: (10,),
    geoopt.manifolds.SphereSubspaceComplementIntersection: (10,),
}
# Record describing one unary test scenario. Field order is part of the
# interface because instances are built positionally.
# NOTE(review): x/v look like the geoopt-side point and tangent and ex/ev like
# their reference counterparts, with manopt_manifold the pymanopt reference
# (test_transport skips when it is None) - confirm against the unary_case
# fixture that fills them.
UnaryCase = collections.namedtuple(
    "UnaryCase",
    ["shape", "x", "ex", "v", "ev", "manifold", "manopt_manifold"],
)
@pytest.fixture()
def unary_case(manifold):
shape = shapes[type(manifold)]
np.random.seed(42)