How to use the trax.layers.base.layer function in trax

To help you get started, we’ve selected a few trax examples that show popular ways trax.layers.base.layer is used in public projects.
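
Every example below follows the same pattern: a plain Python function, decorated with base.layer, becomes a trax layer, and the optional n_in and n_out arguments declare how many inputs it consumes and how many outputs it produces. As a minimal sketch (the import path is an assumption based on the modules these snippets come from, and may differ across trax versions):

# A minimal sketch of a custom layer built with base.layer.
# NOTE: the import path is an assumption; it follows the modules the
# snippets below are taken from and may vary across trax versions.
from trax.layers import base

@base.layer()  # defaults to one input and one output
def Square(x, **unused_kwargs):
  """Elementwise square, written in the same style as the layers below."""
  return x * x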

From google/trax: trax/layers/combinators.py

@base.layer(n_in=3)
def Gate(xs, **unused_kwargs):
  """Implements a gating function on a (memory, gate, candidate) tuple.

  Final update is memory * gate + (1-gate) * candidate

  This is the gating equation used in Highway Networks.
  Highway Networks: https://arxiv.org/abs/1505.00387

  Args:
    xs: A tuple of memory, gate, candidate

  Returns:
    The result of applying gating.
  """
  state, gate, candidate = xs
  return gate * state + (1.0 - gate) * candidate
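
With n_in=3, the decorator hands the layer its three inputs as the tuple xs. The gating arithmetic itself is easy to check with plain NumPy:

import numpy as onp  # plain NumPy stand-in for the backend numpy above

state = onp.array([1.0, 1.0])
gate = onp.array([0.9, 0.1])
candidate = onp.array([0.0, 0.0])

# A gate value near 1 keeps the old memory; near 0 takes the candidate.
print(gate * state + (1.0 - gate) * candidate)  # -> [0.9 0.1]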

From google/trax: trax/layers/core.py

@base.layer()
def LogSoftmax(x, axis=-1, **unused_kwargs):
  """Apply log softmax to x: log-normalize along the given axis."""
  return x - backend.logsumexp(x, axis, keepdims=True)
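
Subtracting logsumexp along the axis log-normalizes x, so exponentiating the result gives probabilities that sum to 1. A quick check with plain NumPy, SciPy's logsumexp standing in for backend.logsumexp:

import numpy as onp
from scipy.special import logsumexp  # stand-in for backend.logsumexp

x = onp.array([1.0, 2.0, 3.0])
log_probs = x - logsumexp(x, axis=-1, keepdims=True)
print(onp.exp(log_probs).sum())  # -> 1.0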

From google/trax: trax/layers/core.py

@base.layer()
def ParametricRelu(x, a=1., **unused_kwargs):
  """Scaled ReLU: returns max(a * x, 0) elementwise."""
  return np.maximum(a * x, np.zeros_like(x))

From google/trax: trax/layers/combinators.py

@base.layer(n_in=2, n_out=2)
def Swap(xs, **unused_kwargs):
  """Swaps the top two stack elements."""
  return (xs[1], xs[0])
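
Note the stack discipline these arguments encode: n_in is how many elements a layer pops off the data stack (delivered as a single tuple, as in Gate above), and n_out is how many elements the returned tuple pushes back, so Swap's (xs[1], xs[0]) really does exchange the top two stack elements.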

From google/trax: trax/layers/core.py

@base.layer()
def LeakyRelu(x, a=0.01, **unused_kwargs):
  """Leaky ReLU: x where x >= 0, else a * x."""
  return np.where(x >= 0, x, a * x)

From google/trax: trax/layers/metrics.py

@base.layer(n_in=2, n_out=1)
def Accuracy(x, axis=-1, **kw):
  """Computes elementwise accuracy: whether argmax(prediction) equals target."""
  del kw
  prediction, target = x
  predicted_class = np.argmax(prediction, axis=axis)
  return np.equal(predicted_class, target)
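
With n_in=2 the layer unpacks its inputs into (prediction, target); the boolean result is usually averaged downstream to get an accuracy rate. In plain NumPy:

import numpy as onp

prediction = onp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
target = onp.array([1, 0, 0])

predicted_class = onp.argmax(prediction, axis=-1)  # -> [1, 0, 1]
print(onp.equal(predicted_class, target).mean())   # -> 0.666...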

From google/trax: trax/layers/rnn.py

@base.layer(n_in=3, n_out=2)
def InnerSRUCell(x, **unused_kwargs):
  """The inner (non-parallel) computation of an SRU."""
  cur_x_times_one_minus_f, cur_f, cur_state = x
  res = cur_f * cur_state + cur_x_times_one_minus_f
  return res, res
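
This computes one step of the SRU state update, c_t = f_t * c_{t-1} + (1 - f_t) * x_t (Simple Recurrent Units, https://arxiv.org/abs/1709.02755), with x_t * (1 - f_t) precomputed outside the cell so the elementwise work parallelizes. The result is returned twice because, with n_out=2, the same value serves as both the step's output and the carried state. A worked step in plain NumPy:

import numpy as onp

cur_f = onp.array([0.5])      # forget gate f_t
cur_state = onp.array([2.0])  # carried state c_{t-1}
cur_x_times_one_minus_f = onp.array([4.0]) * (1.0 - cur_f)  # precomputed x_t * (1 - f_t)

res = cur_f * cur_state + cur_x_times_one_minus_f
print(res, res)  # -> [3.] [3.] ; output and new state are the same value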

From google/trax: trax/layers/core.py

@base.layer()
def Relu(x, **unused_kwargs):
  """Rectified linear unit: elementwise max(x, 0)."""
  return np.maximum(x, np.zeros_like(x))

From google/trax: trax/layers/core.py

@base.layer()
def Elu(x, a=1., **unused_kwargs):
  """Exponential linear unit: x where x > 0, else a * expm1(x)."""
  return np.where(x > 0, x, a * np.expm1(x))
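
np.expm1(x) computes e^x - 1 while staying accurate for small |x|, where exp(x) - 1 would lose precision. As a quick side-by-side of the four ReLU-family activations above, with plain NumPy standing in for the backend:

import numpy as onp

x = onp.array([-2.0, -0.5, 0.0, 1.5])

print(onp.maximum(x, onp.zeros_like(x)))        # Relu: zeroes negatives
print(onp.where(x >= 0, x, 0.01 * x))           # LeakyRelu: small negative slope
print(onp.maximum(1.0 * x, onp.zeros_like(x)))  # ParametricRelu with a=1 (same as Relu)
print(onp.where(x > 0, x, 1.0 * onp.expm1(x)))  # Elu: smooth saturation for x < 0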

From google/trax: trax/layers/combinators.py

@base.layer(n_in=0)
def FlattenList(xs, **unused_kwargs):
  """Flatten lists."""
  # TODO(jonni): Consider renaming layer to DeepFlatten.
  return tuple(_deep_flatten(xs))
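
_deep_flatten is a private helper in combinators.py that the snippet doesn't show. A plausible reimplementation, for illustration only (an assumption, not the repository's exact code), is a recursive generator over nested lists:

def _deep_flatten(items):  # hypothetical reimplementation for illustration
  """Yield the leaves of an arbitrarily nested list/tuple structure."""
  for item in items:
    if isinstance(item, (list, tuple)):
      for sub_item in _deep_flatten(item):
        yield sub_item
    else:
      yield item

print(tuple(_deep_flatten([1, [2, [3, 4]], 5])))  # -> (1, 2, 3, 4, 5)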