How to use the gast.Param context node in gast

To help you get started, we’ve selected a few gast examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tensorflow / tensorflow / tensorflow / python / autograph / pyct / testing / codegen.py View on Github external
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    args = gast.arguments(arg_vars, None, [], [], None, [])
github serge-sans-paille / pythran / pythran / transformations / remove_comprehension.py View on Github external
def visit_GeneratorExp(self, node):
        """Outline a generator expression into a standalone generator function.

        A new ``FunctionDef`` named ``generator_expression<N>`` is built whose
        body folds ``node.generators`` (last one first) through
        ``self.nest_reducer`` around a final ``yield node.elt``. The function
        takes the identifiers gathered by ``ImportedIds`` as parameters, is
        tagged with ``metadata.Local()`` and appended to the module body.

        Returns:
            An ``ast.Call`` to the outlined function, passing the same
            identifiers as positional arguments.
        """
        self.update = True
        node.elt = self.visit(node.elt)
        name = "generator_expression{0}".format(self.count)
        self.count += 1
        args = self.gather(ImportedIds, node)
        self.count_iter = 0

        # Start from the innermost statement and let nest_reducer wrap it,
        # consuming the generators from the last one to the first.
        innermost = ast.Expr(ast.Yield(node.elt))
        body = reduce(self.nest_reducer, reversed(node.generators), innermost)

        params = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        signature = ast.arguments(params, [], None, [], [], None, [])
        fd = ast.FunctionDef(name, signature, [body], [], None, None)
        metadata.add(fd, metadata.Local())
        self.ctx.module.body.append(fd)

        # Fresh Load-context Name nodes for the call site rather than reusing
        # the Param nodes above: no sharing of AST nodes between trees.
        callee = ast.Name(name, ast.Load(), None, None)
        call_args = [ast.Name(param.id, ast.Load(), None, None)
                     for param in params]
        return ast.Call(callee, call_args, [])
github zylo117 / tensorflow-gpu-macosx / tensorflow / contrib / autograph / pyct / static_analysis / type_info.py View on Github external
def visit_Name(self, node):
    """Attach inferred type annotations to a Name node.

    Names in a Param context are forwarded to `_process_function_arg`.
    Names in a Load context whose qualified name has a value in the current
    scope copy that value's 'type' and 'type_fqn' annotations (when 'type'
    is present) and its 'element_type' annotation (when present).

    Args:
      node: the Name node being visited.

    Returns:
      The same node, possibly with annotations added.
    """
    self.generic_visit(node)
    qn = anno.getanno(node, anno.Basic.QN)
    ctx = node.ctx

    if isinstance(ctx, gast.Param):
      self._process_function_arg(qn)
      return node

    if not isinstance(ctx, gast.Load) or not self.scope.hasval(qn):
      return node

    # Aliasing: after `a = b`, later loads of `a` should carry the
    # annotations recorded on the definition `b`.
    source = self.scope.getval(qn)
    if anno.hasanno(source, 'type'):
      for key in ('type', 'type_fqn'):
        anno.setanno(node, key, anno.getanno(source, key))
    if anno.hasanno(source, 'element_type'):
      anno.setanno(node, 'element_type',
                   anno.getanno(source, 'element_type'))
    return node
github serge-sans-paille / pythran / pythran / optimizations / comprehension_patterns.py View on Github external
if len(iters) == 1:
                iterAST = iters[0]
                varAST = ast.arguments([variables[0]], [], None, [], [], None, [])
            else:
                self.use_itertools = True
                prodName = ast.Attribute(
                    value=ast.Name(id=mangle('itertools'),
                                   ctx=ast.Load(),
                                   annotation=None, type_comment=None),
                    attr='product', ctx=ast.Load())

                varid = variables[0].id  # retarget this id, it's free
                renamings = {v.id: (i,) for i, v in enumerate(variables)}
                node.elt = ConvertToTuple(varid, renamings).visit(node.elt)
                iterAST = ast.Call(prodName, iters, [])
                varAST = ast.arguments([ast.Name(varid, ast.Param(), None, None)],
                                       [], None, [], [], None, [])

            ldBodymap = node.elt
            ldmap = ast.Lambda(varAST, ldBodymap)

            return make_attr(ldmap, iterAST)

        else:
            return self.generic_visit(node)
github serge-sans-paille / pythran / pythran / transformations / remove_lambdas.py View on Github external
if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import([ast.alias('functools', mangle('functools'))])
            self.imports.append(import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(
            self.prefix,
            len(self.lambda_functions))

        ii = self.gather(ImportedIds, node)
        ii.difference_update(self.lambda_functions)  # remove current lambdas

        binded_args = [ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)]
        node.args.args = ([ast.Name(iin, ast.Param(), None, None)
                           for iin in sorted(ii)] +
                          node.args.args)
        forged_fdef = ast.FunctionDef(
            forged_name,
            copy(node.args),
            [ast.Return(node.body)],
            [], None, None)
        metadata.add(forged_fdef, metadata.Local())
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef
        proxy_call = ast.Name(forged_name, ast.Load(), None, None)
        if binded_args:
            return ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None, None),
                    "partial",
github serge-sans-paille / beniget / beniget / beniget.py View on Github external
def visit_Name(self, node):

        if isinstance(node.ctx, (ast.Param, ast.Store)):
            dnode = self.chains.setdefault(node, Def(node))
            if node.id in self._promoted_locals[-1]:
                self.extend_definition(node.id, dnode)
                if dnode not in self.locals[self.module]:
                    self.locals[self.module].append(dnode)
            else:
                self.set_definition(node.id, dnode)
                if dnode not in self.locals[self._currenthead[-1]]:
                    self.locals[self._currenthead[-1]].append(dnode)

            if node.annotation is not None:
                self.visit(node.annotation)

        elif isinstance(node.ctx, (ast.Load, ast.Del)):
            node_in_chains = node in self.chains
            if node_in_chains:
github serge-sans-paille / pythran / pythran / transformations / remove_comprehension.py View on Github external
)
                      )
        # add extra metadata to this node
        metadata.add(body, metadata.Comprehension(starget))
        init = ast.Assign(
            [ast.Name(starget, ast.Store(), None, None)],
            ast.Call(
                ast.Attribute(
                    ast.Name('__builtin__', ast.Load(), None, None),
                    comp_type,
                    ast.Load()
                    ),
                [], [],)
            )
        result = ast.Return(ast.Name(starget, ast.Load(), None, None))
        sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
        fd = ast.FunctionDef(name,
                             ast.arguments(sargs, [], None, [], [], None, []),
                             [init, body, result],
                             [], None, None)
        metadata.add(fd, metadata.Local())
        self.ctx.module.body.append(fd)
        return ast.Call(
            ast.Name(name, ast.Load(), None, None),
            [ast.Name(arg.id, ast.Load(), None, None) for arg in sargs],
            [],
            )  # no sharing !
github google / tangent / tangent / reverse_ad.py View on Github external
# We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
      y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    if self.check_dims:

      def shape_match_template(primal, adjoint):
        assert tangent.shapes_match(
            primal, adjoint
        ), 'Shape mismatch between return value (%s) and seed derivative (%s)' % (
            numpy.shape(primal), numpy.shape(adjoint))

      shape_check = template.replace(shape_match_template, primal=y, adjoint=dy)
      adjoint_body = shape_check + adjoint_body

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
github serge-sans-paille / pythran / pythran / optimizations / list_comp_to_map.py View on Github external
def make_Iterator(self, gen):
        """Return the iterable to use for comprehension generator ``gen``.

        Without ``if`` clauses this is simply ``gen.iter``. Otherwise the
        conditions (joined with ``and`` when there are several) become the
        body of a one-parameter lambda, and the result is a call
        ``ASMODULE.IFILTER(lambda, gen.iter)``.

        The lambda is built with the gast >= 0.3 field layout
        (``Name(id, ctx, annotation, type_comment)`` and
        ``arguments(args, posonlyargs, vararg, kwonlyargs, kw_defaults,
        kwarg, defaults)``). The older 3-field ``Name`` / 6-field
        ``arguments`` calls no longer line up with gast's field list.
        """
        if gen.ifs:
            # NOTE(review): assumes gen.target is a plain Name (uses .id).
            # Keyword arguments pin every field explicitly so positional
            # shifts between gast versions cannot misplace values.
            ldFilter = ast.Lambda(
                ast.arguments(
                    args=[ast.Name(id=gen.target.id, ctx=ast.Param(),
                                   annotation=None, type_comment=None)],
                    posonlyargs=[], vararg=None, kwonlyargs=[],
                    kw_defaults=[], kwarg=None, defaults=[]),
                ast.BoolOp(ast.And(), gen.ifs)
                if len(gen.ifs) > 1 else gen.ifs[0])
            # Record whether the filter function comes from itertools.
            self.use_itertools |= MODULE == 'itertools'
            ifilterName = ast.Attribute(
                value=ast.Name(id=ASMODULE,
                               ctx=ast.Load(),
                               annotation=None, type_comment=None),
                attr=IFILTER, ctx=ast.Load())
            return ast.Call(ifilterName, [ldFilter, gen.iter], [])
        else:
            return gen.iter
github serge-sans-paille / pythran / pythran / optimizations / comprehension_patterns.py View on Github external
def make_Iterator(self, gen):
        """Return the iterable to use for comprehension generator ``gen``.

        When ``gen`` has no ``if`` clauses its iterable is returned as is.
        Otherwise the clauses are turned into a one-parameter lambda
        predicate (an ``and`` of all clauses when there are several), and
        ``ASMODULE.IFILTER(predicate, gen.iter)`` is returned.
        """
        if not gen.ifs:
            return gen.iter

        conditions = gen.ifs
        lambda_args = ast.arguments(
            [ast.Name(gen.target.id, ast.Param(), None, None)],
            [], None, [], [], None, [])
        if len(conditions) > 1:
            predicate = ast.BoolOp(ast.And(), conditions)
        else:
            predicate = conditions[0]
        ldFilter = ast.Lambda(lambda_args, predicate)

        # Remember whether the filter helper lives in itertools.
        self.use_itertools |= MODULE == 'itertools'

        module_ref = ast.Name(id=ASMODULE,
                              ctx=ast.Load(),
                              annotation=None, type_comment=None)
        ifilterName = ast.Attribute(value=module_ref, attr=IFILTER,
                                    ctx=ast.Load())
        return ast.Call(ifilterName, [ldFilter, gen.iter], [])