Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
return x + t * u
def _inner(self, x, u, v):
    """Return the elementwise product of ``u`` and ``v``.

    The base point ``x`` does not take part in the computation, and the
    product is returned as-is, without any reduction/summation.
    """
    inner = u * v
    return inner
def _proju(self, x, u):
    """Identity projection: hand back ``u`` itself (same object, no copy).

    The base point ``x`` is ignored.
    """
    projected = u
    return projected
def _projx(self, x):
    """Identity point projection: return ``x`` itself, unchanged and uncopied."""
    point = x
    return point
def _transp(self, x, u, v, t):
    """Return ``v`` unchanged; ``x``, ``u`` and ``t`` are not used."""
    transported = v
    return transported
class Stiefel(Manifold):
    """Stiefel manifold; inputs are treated as (batches of) matrices.

    Every operation below acts on the last two dimensions of its tensor
    arguments (see ``check_dims`` and the ``transpose(-1, -2)`` calls).
    """

    name = "Stiefel"
    ndim = 2
    reversible = True

    def check_dims(self, x):
        """Return ``True`` when ``x`` has at least two dimensions."""
        return 2 <= x.dim()

    def amat(self, x, u, project=True):
        """Build the matrix ``A = u @ x^T - x @ u^T``.

        Parameters
        ----------
        x : tensor
            base point
        u : tensor
            direction; it is first passed through ``self.proju(x, u)`` when
            ``project`` is truthy
        project : bool
            whether to project ``u`` before forming ``A``

        Returns
        -------
        tensor
            ``A``, computed over the last two dimensions; by construction it
            satisfies ``A^T = -A`` (skew-symmetric)
        """
        tangent = self.proju(x, u) if project else u
        x_t = x.transpose(-1, -2)
        return tangent @ x_t - x @ tangent.transpose(-1, -2)

    def _proju(self, x, u):
        """Return ``(I - 0.5 * x @ x^T) @ u``.

        ``-0.5 * x @ x^T`` is computed into a fresh tensor, and the identity
        is then added in place on its diagonal. The ``...`` index applies
        this to every matrix in any leading batch dimensions.
        """
        diag = range(x.shape[-2])
        proj = -0.5 * x @ x.transpose(-1, -2)
        proj[..., diag, diag] += 1
        return proj @ u
Parameters
----------
instance : geoopt.Manifold
check if a given manifold is compatible with cls API
cls : type
manifold type
Returns
-------
bool
comparison result
"""
if not issubclass(cls, geoopt.manifolds.Manifold):
raise TypeError("`cls` should be a subclass of geoopt.manifolds.Manifold")
if not isinstance(instance, geoopt.manifolds.Manifold):
return False
else:
# this is the case to care about, Scaled class is a proxy, but fails instance checks
while isinstance(instance, geoopt.Scaled):
instance = instance.base
return isinstance(instance, cls)
@abc.abstractmethod
def _proju(self, x, u):
    """Abstract hook taking ``x`` and ``u``; concrete subclasses must override it.

    Calling this base version always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
@abc.abstractmethod
def _projx(self, x):
    """Abstract hook taking ``x``; concrete subclasses must override it.

    Calling this base version always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def __repr__(self):
    """Return ``"<name> manifold"``, built from the ``name`` attribute.

    Uses plain ``+`` concatenation, so ``name`` is expected to be a ``str``.
    """
    suffix = " manifold"
    return self.name + suffix
def __eq__(self, other):
    """Compare equal iff both objects have exactly the same type.

    Subclass instances are not equal to base-class instances, and instance
    attributes are never compared.
    """
    # NOTE(review): a class that defines __eq__ without also defining __hash__
    # gets __hash__ set to None (unhashable); the class header is not visible
    # here, so confirm whether hashing these objects is needed.
    same_type = type(other) is type(self)
    return same_type
class Rn(Manifold):
    """Flat manifold whose operations are plain elementwise arithmetic.

    ``ndim = 0`` and ``check_dims`` accepts every input.
    """

    # NOTE(review): no ``_projx`` is defined in this span, although an
    # abstract ``_projx`` appears elsewhere in the file; confirm that the
    # class definition is complete.

    name = "Rn"
    ndim = 0
    reversible = True

    def check_dims(self, x):
        """Accept inputs of any shape."""
        return True

    def _retr(self, x, u, t):
        """Move from ``x`` along ``u`` scaled by ``t``: ``x + t * u``."""
        step = t * u
        return x + step

    def _inner(self, x, u, v):
        """Elementwise product ``u * v`` (no reduction); ``x`` is unused."""
        inner = u * v
        return inner

    def _proju(self, x, u):
        """Identity projection: ``u`` is returned unchanged; ``x`` is unused."""
        projected = u
        return projected
@insert_docs(Manifold.proju.__doc__, r"\s+x : .+\n.+", "")
def proju(self, u: torch.Tensor, **kwargs) -> torch.Tensor:
    # Intentionally no docstring: ``insert_docs`` is given
    # ``Manifold.proju.__doc__`` and presumably supplies this method's
    # documentation (minus the ``x`` parameter entry) — confirm against
    # its implementation before adding one here.
    # Delegate to the owning manifold, passing ``self`` as its first
    # (``x``) argument and forwarding any extra keyword arguments.
    manifold = self.manifold
    return manifold.proju(self, u, **kwargs)