Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
return Add([self, input])
else:
raise NotImplementedError
def __neg__(self):
return -1 * self
def __sub__(self, input):
return self.__add__(-input)
def __repr__(self):
return '<{oshape}x{ishape}> {repr_str} Linop>'.format(
oshape=self.oshape, ishape=self.ishape, repr_str=self.repr_str)
class Identity(Linop):
"""Identity linear operator.
Returns input directly.
Args:
shape (tuple of ints): Input shape
"""
def __init__(self, shape):
    """Create the identity operator.

    Args:
        shape (tuple of ints): Input shape. It is handed to the parent
            constructor for both of its shape arguments.
    """
    oshape = ishape = shape
    super().__init__(oshape, ishape)
def _apply(self, input):
return input
def _adjoint_linop(self):
def _apply(self, input):
    """Convolve the stored ``data`` array with ``input``.

    Args:
        input (array): Second operand handed to ``conv.convolve``.

    Returns:
        The value returned by ``conv.convolve``.
    """
    target = backend.get_device(input)
    # Bring the stored array onto the device that holds the input.
    local_data = backend.to_device(self.data, target)
    with target:
        options = dict(mode=self.mode,
                       strides=self.strides,
                       multi_channel=self.multi_channel)
        return conv.convolve(local_data, input, **options)
def _adjoint_linop(self):
    """Build the adjoint operator.

    Returns:
        ConvolveFilterAdjoint: Constructed with this operator's input shape
            as the filter shape, the stored ``data`` array, and the same
            ``mode``, ``strides`` and ``multi_channel`` settings.
    """
    options = dict(mode=self.mode,
                   strides=self.strides,
                   multi_channel=self.multi_channel)
    return ConvolveFilterAdjoint(self.ishape, self.data, **options)
class ConvolveFilterAdjoint(Linop):
r"""Adjoint convolution operator for filter arrays.
Args:
filt_shape (tuple of ints): filter array shape:
:math:`[n_1, \ldots, n_D]` if multi_channel is False
:math:`[c_o, c_i, n_1, \ldots, n_D]` otherwise.
data (array): data array of shape:
:math:`[\ldots, m_1, \ldots, m_D]` if multi_channel is False,
:math:`[\ldots, c_i, m_1, \ldots, m_D]` otherwise.
mode (str): {'full', 'valid'}.
strides (None or tuple of ints): convolution strides of length D.
multi_channel (bool): specify if input/output has multiple channels.
"""
def __init__(self, filt_shape, data,
mode='full', strides=None,
raise Exception(
'Shapes must have the same lengths to concatenate.')
for i in range(ndim):
if i == axis:
ishape[i] += shape[i]
indices.append(idx)
idx += shape[i]
elif shape[i] != ishape[i]:
raise RuntimeError(
'Shapes not along axis must be the same to concatenate.')
return ishape, indices
class Hstack(Linop):
"""Horizontally stack linear operators.
Creates a Linop that splits the input, applies Linops independently,
and sums outputs.
In matrix form, given matrices {A1, ..., An}, this is equivalent to
returning [A1, ..., An].
Input and output devices must be the same.
Args:
linops (list of Linops): list of linops with the same output shape.
axis (int or None): If None, inputs are vectorized and concatenated.
Otherwise, inputs are stacked along axis.
"""
super().__init__(oshape, ishape)
def _apply(self, input):
    """Apply the adjoint NUFFT to ``input``.

    The stored coordinates are transferred to the device that holds
    ``input`` before the transform is called, so both operands are on the
    same device.

    Args:
        input (array): Array passed to ``fourier.nufft_adjoint``.

    Returns:
        The value returned by ``fourier.nufft_adjoint`` for the output
        shape ``self.oshape``.
    """
    device = backend.get_device(input)
    with device:
        # Previously ``self.coord`` was used as stored, which mixes devices
        # when the coordinates were created somewhere other than ``input``.
        coord = backend.to_device(self.coord, device)
        return fourier.nufft_adjoint(
            input, coord, self.oshape,
            oversamp=self.oversamp, width=self.width)
def _adjoint_linop(self):
    """Build the adjoint operator.

    Returns:
        NUFFT: Constructed from this operator's output shape, the stored
            coordinates, and the same ``oversamp`` and ``width`` settings.
    """
    options = dict(oversamp=self.oversamp, width=self.width)
    return NUFFT(self.oshape, self.coord, **options)
class ConvolveData(Linop):
r"""Convolution operator for data arrays.
Args:
data_shape (tuple of ints): data array shape:
:math:`[\ldots, m_1, \ldots, m_D]` if multi_channel is False,
:math:`[\ldots, c_i, m_1, \ldots, m_D]` otherwise.
filt (array): filter array of shape:
:math:`[n_1, \ldots, n_D]` if multi_channel is False
:math:`[c_o, c_i, n_1, \ldots, n_D]` otherwise.
mode (str): {'full', 'valid'}.
strides (None or tuple of ints): convolution strides of length D.
multi_channel (bool): specify if input/output has multiple channels.
"""
def __init__(self, data_shape, filt, mode='full', strides=None,
multi_channel=False):
def _apply(self, input):
    """Interpolate ``input`` at the stored coordinates.

    Args:
        input (array): Array passed to ``interp.interpolate``.

    Returns:
        The value returned by ``interp.interpolate``.
    """
    target = backend.get_device(input)
    with target:
        # Coordinates are moved to the device that holds the input.
        local_coord = backend.to_device(self.coord, target)
        options = dict(kernel=self.kernel,
                       width=self.width,
                       param=self.param)
        return interp.interpolate(input, local_coord, **options)
def _adjoint_linop(self):
    """Build the adjoint operator.

    Returns:
        Gridding: Constructed from this operator's input shape, the stored
            coordinates, and the same ``kernel``, ``width`` and ``param``.
    """
    options = dict(kernel=self.kernel,
                   width=self.width,
                   param=self.param)
    return Gridding(self.ishape, self.coord, **options)
class Gridding(Linop):
"""Gridding linear operator.
Args:
oshape (tuple of ints): Output shape = batch_shape + pts_shape
ishape (tuple of ints): Input shape = batch_shape + grd_shape
coord (array): Coordinates, values from - nx / 2 to nx / 2 - 1.
ndim can only be 1, 2 or 3. of shape pts_shape + [ndim]
width (float): Width of interp. kernel in grid size.
kernel (str): Interpolation kernel, {'spline', 'kaiser_bessel'}.
param (float): Kernel parameter.
See Also:
:func:`sigpy.gridding`
"""
"""
I = Identity(ishape)
ndim = len(ishape)
axes = util._normalize_axes(axes, ndim)
linops = []
for i in axes:
D = I - Circshift(ishape, [1], axes=[i])
R = Reshape([1] + list(ishape), ishape)
linops.append(R * D)
G = Vstack(linops, axis=0)
return G
class NUFFT(Linop):
"""NUFFT linear operator.
Args:
ishape (tuple of int): Input shape.
coord (array): Coordinates, with values [-ishape / 2, ishape / 2]
oversamp (float): Oversampling factor.
width (float): Kernel width.
n (int): Kernel sampling number.
"""
def __init__(self, ishape, coord, oversamp=1.25, width=4):
self.coord = coord
self.oversamp = oversamp
self.width = width
ndim = coord.shape[-1]
if not (i == m or i == 1 or m == 1):
raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
ishape=ishape, mshape=mshape))
oshape.append(max(i, m))
if ishape_exp[-1] != mshape_exp[-2]:
raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
ishape=ishape, mshape=mshape))
oshape += [ishape_exp[-2], mshape_exp[-1]]
return oshape
class RightMatMul(Linop):
"""Matrix multiplication on the right.
Args:
ishape (tuple of ints): Input shape.
It must be able to broadcast with mat.shape.
mat (array): Matrix of shape [..., m, n]
adjoint (bool): Toggle adjoint.
If True, performs conj(mat).swapaxes(-1, -2)
before performing matrix multiplication.
"""
def __init__(self, ishape, mat, adjoint=False):
self.mat = mat
self.adjoint = adjoint
self.expanded_ishape.append(oshape[d])
self.reps.append(1)
super().__init__(oshape, ishape)
def _apply(self, input):
    """Reshape ``input`` and tile it.

    Args:
        input (array): Array to reshape to ``self.expanded_ishape`` and
            repeat according to ``self.reps``.

    Returns:
        The tiled array produced by the device's array module.
    """
    target = backend.get_device(input)
    array_module = target.xp
    with target:
        expanded = input.reshape(self.expanded_ishape)
        return array_module.tile(expanded, self.reps)
def _adjoint_linop(self):
    """Build the adjoint operator.

    Returns:
        Sum: Constructed from this operator's output shape and ``self.axes``.
    """
    adjoint = Sum(self.oshape, self.axes)
    return adjoint
class ArrayToBlocks(Linop):
"""Extract blocks from an array in a sliding window manner.
Args:
ishape (array): input array of shape [..., N_1, ..., N_D]
blk_shape (tuple): block shape of length D, with D <= 4.
blk_strides (tuple): block strides of length D.
See Also:
:func:`sigpy.block.array_to_blocks`
"""
def __init__(self, ishape, blk_shape, blk_strides):
self.blk_shape = blk_shape
self.blk_strides = blk_strides
D = len(blk_shape)
self.ishift = ishift
self.oshift = oshift
super().__init__(oshape, ishape)
def _apply(self, input):
    """Resize ``input`` to this operator's output shape.

    Args:
        input (array): Array passed to ``util.resize``.

    Returns:
        The value returned by ``util.resize`` for ``self.oshape`` with the
        stored input and output shifts.
    """
    target = backend.get_device(input)
    with target:
        resized = util.resize(
            input, self.oshape,
            ishift=self.ishift, oshift=self.oshift)
    return resized
def _adjoint_linop(self):
    """Build the adjoint operator.

    Returns:
        Resize: Constructed with the two shapes in swapped order and with
            the input and output shifts exchanged.
    """
    # The shifts trade places: this operator's output shift becomes the
    # adjoint's input shift and vice versa.
    swapped_shifts = dict(ishift=self.oshift, oshift=self.ishift)
    return Resize(self.ishape, self.oshape, **swapped_shifts)
class Flip(Linop):
"""Flip linear operator.
Args:
shape (tuple of int): Input shape
"""
def __init__(self, shape, axes=None):
    """Create the flip operator.

    Args:
        shape (tuple of int): Input shape. It is handed to the parent
            constructor for both of its shape arguments.
        axes: Stored on the instance as ``self.axes``. Defaults to None.
    """
    oshape = ishape = shape
    self.axes = axes
    super().__init__(oshape, ishape)
def _apply(self, input):
    """Flip ``input`` along the stored axes.

    Args:
        input (array): Array passed to ``util.flip``.

    Returns:
        The value returned by ``util.flip`` for ``self.axes``.
    """
    with backend.get_device(input):
        flipped = util.flip(input, self.axes)
    return flipped
def _apply(self, input):
return input.transpose(self.axes)
def _adjoint_linop(self):
    """Build the adjoint operator, a transpose that undoes this one.

    Returns:
        Transpose: When ``self.axes`` is None, constructed from the reversed
            input shape with ``axes=None``. Otherwise constructed from the
            input shape permuted by ``self.axes``, with the argsort of
            ``self.axes`` as its axes.
    """
    if self.axes is not None:
        # argsort of a permutation yields the permutation that reverses it.
        inverse_axes = np.argsort(self.axes)
        permuted_shape = [self.ishape[axis] for axis in self.axes]
        return Transpose(permuted_shape, axes=inverse_axes)

    return Transpose(self.ishape[::-1], axes=None)
class FFT(Linop):
"""FFT linear operator.
Args:
ishape (tuple of int): Input shape
axes (None or tuple of int): Axes to perform FFT.
If None, applies on all axes.
center (bool): Toggle center FFT.
"""
def __init__(self, shape, axes=None, center=True):
    """Create the FFT operator.

    Args:
        shape (tuple of int): Input shape. It is handed to the parent
            constructor for both of its shape arguments.
        axes (None or tuple of int): Axes to perform FFT.
            If None, applies on all axes.
        center (bool): Toggle center FFT.
    """
    self.axes, self.center = axes, center
    oshape = ishape = shape
    super().__init__(oshape, ishape)