Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def replace_op(self, graph: Graph, node: Node):
    """Rebuild the matched node as ``Pow(Add(in0, Mul(in1, -1)), 2)``.

    The connection feeding input port 0 of ``node`` is re-routed to an
    ``Add``; the connection feeding input port 1 goes through a ``Mul`` with
    a ``-1`` constant and then into the second input of the same ``Add``.
    The sum is finally raised to the power of ``2``.

    :param graph: graph in which the new nodes are created.
    :param node: matched node whose input ports 0 and 1 are re-routed.
    :return: one-element list holding the id of the final ``Pow`` node.
    """
    base_name = node.name

    # Create nodes (creation order is kept as-is: -1 const, Mul, Add, 2 const, Pow).
    minus_one = Const(graph, {'value': np.array(-1), 'name': base_name + '/negate_const_'}).create_node()
    negated = Mul(graph, dict(name=base_name + '/negate_')).create_node()
    difference = Add(graph, dict(name=base_name + '/add_')).create_node()
    exponent = Const(graph, dict(value=np.array(2))).create_node()
    power = Pow(graph, dict(name=base_name + '/squared_')).create_node()

    # Re-route the two original inputs: in0 -> Add, in1 -> Mul(-1) -> Add.
    node.in_port(0).get_connection().set_destination(difference.in_port(0))
    node.in_port(1).get_connection().set_destination(negated.in_port(0))
    minus_one.out_port(0).connect(negated.in_port(1))
    negated.out_port(0).connect(difference.in_port(1))

    # Square the sum.
    difference.out_port(0).connect(power.in_port(0))
    exponent.out_port(0).connect(power.in_port(1))

    # The "explicit" form of the return value would be [(power.id, 0)].
    return [power.id]
from mo.graph.graph import Node, Graph
from extensions.ops.elementwise import Add, Maximum, Mul
class EltwiseNReplacement(FrontReplacementOp):
"""
This replacer substitutes an elementwise operation with more than 2 inputs with a number of simple elementwise
operations with 2 inputs. The replacer supports operations supported by the Eltwise layer.
"""
# Operation name this replacer is registered for.
op = 'EltwiseN'
enabled = True
# Maps the value of the node's `operation` attribute to the op class instantiated for each extra input.
op_to_class_map = {
'sum': Add,
'max': Maximum,
'mul': Mul,
}
def replace_op(self, graph: Graph, node: Node):
# NOTE(review): the body of this method is cut off after the first connection inside the loop below;
# `last_node` is not read again in the visible lines -- confirm the remainder against the full file.
last_node = node
operation = node.operation
# Only the operations listed in op_to_class_map can be decomposed.
assert operation in EltwiseNReplacement.op_to_class_map
op_class = EltwiseNReplacement.op_to_class_map[operation]
# Connection that currently feeds input port 0 of the matched node; it becomes the left operand
# of the first two-input node created in the loop.
left_connect = node.in_port(0).get_connection()
# Visit every entry of node.in_ports() except the first one.
for ind in list(node.in_ports())[1:]:
attrs = {'name': node.name + '/' + operation + '_' + str(ind)}
# Propagate `axis` to the new node only when the matched node has a valid one.
attrs.update({'axis': node.axis} if node.has_valid('axis') else {})
# Create node
eltwise_op = op_class(graph, attrs).create_node()
# Connect nodes
left_connect.set_destination(eltwise_op.in_port(0))
# NOTE(review): fragment -- `op` and `graph` are bound by code outside the visible lines.
# Check that weights and biases are not useless
has_bias, has_weights = True, True
# A scale consisting only of ones needs no Mul node.
if all([x == 1 for x in np.nditer(op.scale)]):
has_weights = False
# A bias consisting only of zeros needs no Add node.
if all([x == 0 for x in np.nditer(op.bias)]):
has_bias = False
# The rewiring below handles exactly one input port.
assert len(op.in_ports()) == 1
# `last_port` is the port that is finally made the source of the output connection of `op`;
# it starts as the producer of `op` and is advanced past every node inserted below.
last_port = op.in_port(0).get_source()
# Create Mul & Add nodes
if has_weights:
mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
# Redirect the input connection of `op` into the Mul and feed the scale constant to its second input.
op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
mul_weights.out_port(0).connect(mul_op.in_port(1))
last_port = mul_op.out_port(0)
if has_bias:
add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
# Chain the Add after whatever `last_port` currently is and feed the bias constant to its second input.
last_port.get_connection().set_destination(add_op.in_port(0))
add_bias.out_port(0).connect(add_op.in_port(1))
last_port = add_op.out_port(0)
# Detach `op` from its input and make `last_port` the source of its former output connection.
op.in_port(0).disconnect()
op.out_port(0).get_connection().set_source(last_port)
# NOTE(review): fragment -- `graph` and `match` are bound by code outside the visible lines.
# Reshape node named 'do_reshape_classes' with the constant [0, -1] as its second input;
# presumably it is attached to the node feeding sub-graph input 1 -- confirm against
# create_op_node_with_second_input.
reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
dict(name='do_reshape_classes'),
match.single_input_node(1)[0])
priors_node = match.single_input_node(2)[0]
# First node in the graph whose op is 'Parameter'.
placeholder = [Node(graph, node_id) for node_id in graph.nodes() if Node(graph, node_id).op == 'Parameter'][0]
# NOTE(review): shape indices 1 and 2 are read as height and width, which presumes a layout with
# height and width on those axes -- confirm.
im_height = placeholder.shape[1]
im_width = placeholder.shape[2]
# scale prior boxes to the [0, 1] interval
priors_scale_const_node = Const(graph, {'value': np.array([1 / im_width,
1 / im_height,
1 / im_width,
1 / im_height])}).create_node([])
priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node(
[priors_node, priors_scale_const_node])
# calculate prior boxes widths and heights
# The scaled priors are split into four pieces of size 1 along axis 2; width is built from
# outputs 2 and 0, height from outputs 3 and 1 of the split.
split_node = SplitV(graph, {'axis': 2, 'size_splits': [1, 1, 1, 1],
'out_ports_count': 4}).create_node([priors_scale_node])
priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
).create_node([(split_node, 2), (split_node, 0)])
priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
).create_node([(split_node, 3), (split_node, 1)])
# concat widths and heights into a single tensor and multiply with the box coordinates regression values
concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height',
'axis': -1, 'in_ports_count': 4}).create_node(
[priors_width_node, priors_height_node, priors_width_node, priors_height_node])
# NOTE(review): the call below is cut off in this excerpt; its arguments are not visible.
applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
def replace_pattern(self, graph: Graph, match: dict):
    """Replace the matched ``minimum`` node with ``-1 * Maximum(-1 * a, -1 * b)``.

    Nothing is changed when both inputs of the node already carry a value
    (the constant propagation case).

    :param graph: graph to which the replacement operations are added.
    :param match: pattern match; its ``'minimum'`` entry is the node to replace.
    """
    minimum = match['minimum']

    # Constant propagation case: both inputs are known, leave the node alone.
    if all(minimum.in_node(idx).value is not None for idx in (0, 1)):
        return

    base_name = minimum.name

    # Operation descriptors; the actual nodes are created further below.
    minus_one_a = Const(graph, {'value': np.array(-1), 'name': base_name + '/negate1_const'})
    minus_one_b = Const(graph, {'value': np.array(-1), 'name': base_name + '/negate2_const'})
    negate_a = Mul(graph, {'name': base_name + '/negate1_'})
    negate_b = Mul(graph, {'name': base_name + '/negate2_'})
    max_op = Maximum(graph, {'name': base_name + '/Max_'})
    minus_one_out = Const(graph, {'value': np.array(-1), 'name': base_name + '/negate_out_const_'})
    negate_out = Mul(graph, {'scale': -1, 'name': base_name + '/negate_out_'})

    # Negate both inputs, take their maximum, then negate the result.
    negated_a = negate_a.create_node_with_data(
        [minimum.in_node(0), minus_one_a.create_node_with_data()])
    negated_b = negate_b.create_node_with_data(
        [minimum.in_node(1), minus_one_b.create_node_with_data()])
    max_of_negated = max_op.create_node_with_data([negated_a, negated_b])

    # The existing output data node of the minimum is handed over as the
    # output of the final negation.
    negate_out.create_node_with_data(
        inputs=[max_of_negated, minus_one_out.create_node_with_data()],
        data_nodes=minimum.out_node())

    # Delete minimum vertex
    minimum.graph.remove_node(minimum.id)
def replace_sub_graph(self, graph: Graph, match: dict):
    """Fold the ``bias`` attribute of the matched node into a trailing ``Mul``.

    When the node has a valid ``bias`` different from 1, its ``alpha`` is
    divided by ``bias``, a ``Mul`` by the constant ``1 / bias ** beta`` is
    inserted after the node, and the ``bias`` attribute is removed.
    Otherwise the graph is left untouched.

    :param graph: graph in which the new nodes are created.
    :param match: pattern match; its ``'op'`` entry is the node to process.
    """
    target = match['op']

    # Nothing to do without a valid bias or with a bias equal to 1.
    if not target.has_valid('bias') or target.bias == 1:
        return

    # Calculate scale value and rescale alpha by the bias.
    scale = np.array(1. / target.bias ** target.beta)
    target.alpha /= target.bias

    name_prefix = target.name
    scale_const = Const(graph, dict(value=scale, shape=scale.shape,
                                    name=name_prefix + "/Const_Mul_")).create_node()
    scale_mul = Mul(graph, dict(name=name_prefix + "/Mul_")).create_node()

    # Connect nodes: the Mul takes over the output connection of the node,
    # then the node itself is plugged into the first input of the Mul.
    scale_const.out_port(0).connect(scale_mul.in_port(1))
    target.out_port(0).get_connection().set_source(scale_mul.out_port(0))
    target.out_port(0).connect(scale_mul.in_port(0))

    # Delete bias, if it is not deleted it will appear in IR v6
    del target['bias']
def extract(node):
    """Update ``node`` with the ``Mul`` attributes and its data type.

    The data type is obtained by passing ``node.pb.attr["T"].type`` to
    ``tf_dtype_extractor`` and is stored under the ``'data_type'`` key.

    :param node: node being extracted.
    :return: the ``enabled`` flag of the enclosing class.
    """
    data_type = tf_dtype_extractor(node.pb.attr["T"].type)
    Mul.update_node_stat(node, {'data_type': data_type})
    return __class__.enabled
# NOTE(review): fragment -- `input`, `fbn`, `match`, `graph`, `log` and `Op` are bound by code outside the
# visible lines; `input` also shadows the Python builtin of the same name.
log.debug('Found potential MVN pattern after {} with name {}'.format(input.op, input.name))
# Give up unless `input` is the first input of both the matched 'mean' and 'sqdiff' nodes.
if input.id != match['mean'].in_node(0).id or input.id != match['sqdiff'].in_node(0).id:
return
log.debug('Confirmed MVN pattern after {} with name {}'.format(input.op, input.name))
MVN = Op.get_op_class_by_name('MVN')
# Reduction indices are [1, 2] when the data format is b'NHWC' and [2, 3] otherwise.
mvn = MVN(graph, dict(
name=fbn.name + '/MVN_',
eps=fbn.eps,
required_reduction_indices=[1, 2] if fbn.data_format == b'NHWC' else [2, 3]
))
# Keep the original infer function under 'old_infer' and install the enclosing class's own infer.
mvn.attrs['old_infer'] = mvn.attrs['infer']
mvn.attrs['infer'] = __class__.infer
mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))
# Inputs 1 and 2 of `fbn` are used as gamma and beta respectively.
input_gamma = fbn.in_node(1)
input_beta = fbn.in_node(2)
# Second inputs of the matched 'mean' and 'variance' nodes.
mean_reduction = match['mean'].in_node(1)
variance_reduction = match['variance'].in_node(1)
# Resulting sub-graph: Add(Mul(MVN(input, mean_reduction, variance_reduction), gamma), beta).
new_subgraph = add.create_node([
mul.create_node([
mvn.create_node([input, mean_reduction, variance_reduction]),
input_gamma
]),
input_beta
])
fbn.replace_node(new_subgraph)
# NOTE(review): fragment -- `gemm`, `bias_node` and `graph` are bound by code outside the visible lines;
# the original indentation is lost, so the nesting of the statements below must be confirmed against the full file.
gemm.out_port(0).connect(bias_node.in_port(0))
if graph.graph['cmd_params'].generate_experimental_IR_V10:
gemm.type = 'MatMul'
# alpha: when it is not close to 1, a Mul by the alpha constant is inserted in front of
# input port 0 of `bias_node`; the attribute is then deleted from `gemm`.
if gemm.has_valid('alpha'):
if not math.isclose(gemm.alpha, 1):
mul_node = Mul(graph, {'name': 'MatMulAlpha_'}).create_node()
const = Const(graph, {'value': np.array(gemm.alpha)}).create_node()
bias_node.in_port(0).get_connection().set_destination(mul_node.in_port(0))
bias_node.in_port(0).connect(mul_node.out_port(0))
mul_node.in_port(1).connect(const.out_port(0))
del gemm['alpha']
# beta: same treatment as alpha, applied to input port 1 of `bias_node`.
if gemm.has_valid('beta'):
if not math.isclose(gemm.beta, 1):
mul_node = Mul(graph, {'name': 'MatMulBeta_'}).create_node()
const = Const(graph, {'value': np.array(gemm.beta)}).create_node()
bias_node.in_port(1).get_connection().set_destination(mul_node.in_port(0))
bias_node.in_port(1).connect(mul_node.out_port(0))
mul_node.in_port(1).connect(const.out_port(0))
del gemm['beta']
# Weight dimensions are assigned only when the experimental IR V10 flag is not set.
if not graph.graph['cmd_params'].generate_experimental_IR_V10:
assign_dims_to_weights(gemm.in_node(1), None, 1, 0, 2)
# Do not transpose weights in this pass, it will be done as a separate pass
def extract(node: Node):
    """Update ``node`` through ``Mul.update_node_stat`` without extra attributes.

    :param node: node being extracted.
    :return: the ``enabled`` flag of the enclosing class.
    """
    Mul.update_node_stat(node)
    return __class__.enabled