Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
if isinstance(value, attr_value_pb2.AttrValue.ListValue):
result = list(value.ListFields()[0][1])
for i in range(len(result)):
if isinstance(result[i], int):
result[i] = int(result[i])
try:
if isinstance(result[i], long):
result[i] = int(result[i])
except:
pass
return result
else:
return value
class TFGraph(Graph):
def __init__(self, model, data_format="NHWC"):
    """Wrap a TensorFlow model as a graph.

    Args:
        model: TensorFlow model handed to the ``Graph`` base constructor.
            ``build`` later walks ``self.model.node``.
        data_format: Layout string stored on ``self.tf_data_format``.
            ``build`` forwards it to each ``TFGraphNode`` it creates.
            Defaults to ``"NHWC"``.
    """
    base_init = super(TFGraph, self).__init__
    base_init(model)

    # Starts empty; filled by code outside this method.
    self.identity_map = {}
    # Op type names listed as multi-output ops (presumably ops whose
    # outputs must be tracked individually; confirm with consumers).
    split_ops = ['Split', 'SplitV']
    self.multi_out_ops = split_ops
    self.tf_data_format = data_format
def build(self):
for layer in self.model.node:
self.node_map[layer.name.replace('/', '_').replace(
'-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)
for layer_name, node in self.node_map.items():
for in_node in node.layer.input:
in_node = in_node.replace('/',
'_').replace('-',
'_').replace('^', '')
super(CaffeGraphNode,
self).__init__(layer,
layer.name.replace('/', '_').replace('-', '_'))
else:
super(CaffeGraphNode,
self).__init__(layer,
layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = type_str
self.fluid_code = FluidCode()
self.data = None
def set_params(self, params):
    """Attach ``params`` to this node as its ``data`` attribute.

    The object is stored by reference (no copy is made). It replaces any
    value previously held in ``self.data``.

    Args:
        params: Parameter payload for the layer. Its structure is not
            checked here (presumably the layer's weight blobs; confirm
            against callers).
    """
    payload = params
    self.data = payload
class CaffeGraph(Graph):
def __init__(self, model, params, caffe_pb):
    """Wrap a Caffe model as a graph.

    Args:
        model: Caffe model handed to the ``Graph`` base constructor.
        params: Parameter data kept on ``self.params``.
        caffe_pb: Kept on ``self.caffe_pb`` (presumably the Caffe protobuf
            module used for parsing; confirm against callers).
    """
    # Both attributes are set before the base constructor runs, keeping the
    # original ordering in case ``Graph.__init__`` relies on them
    # (NOTE(review): not verifiable from this chunk).
    self.params, self.caffe_pb = params, caffe_pb
    super(CaffeGraph, self).__init__(model)
def filter_layers(self, layers):
'''Filter out layers based on the current phase.'''
phase_map = {0: 'train', 1: 'test'}
filtered_layer_names = set()
filtered_layers = []
for layer in layers:
if hasattr(layer, 'input'):
continue
type_str = self.get_layer_type(layer)
phase = 'test'
if len(layer.include):
values = self.layer.dims
out_shapes = list()
out_shapes.append(values)
return out_shapes
@property
def dtype(self):
    """NumPy dtype of the tensor described by ``self.layer``.

    For a ``ValueInfoProto`` layer, the integer type code is read from
    ``type.tensor_type.elem_type``. For any other layer it is read from the
    ``data_type`` field. The code is then translated through
    ``TENSOR_TYPE_TO_NP_TYPE``. A code missing from that mapping makes the
    lookup fail.
    """
    layer = self.layer
    if isinstance(layer, ValueInfoProto):
        type_code = layer.type.tensor_type.elem_type
    else:
        type_code = layer.data_type
    return TENSOR_TYPE_TO_NP_TYPE[type_code]
class ONNXGraph(Graph):
def __init__(self, onnx_model):
    """Wrap an ONNX model as a graph.

    The model's ``graph`` field is handed to the ``Graph`` base constructor.
    The full model object is kept on ``self.onnx_model``.

    Args:
        onnx_model: ONNX model object exposing a ``graph`` attribute.
    """
    super(ONNXGraph, self).__init__(onnx_model.graph)
    self.onnx_model = onnx_model
    self.initializer = {}
    self.place_holder_nodes = []
    # Called after the containers above are created, in the original order,
    # since it may read or fill them (NOTE(review): defined outside this
    # chunk; confirm which fields it touches).
    self.get_place_holder_nodes()
    self.value_infos = self.inferred_model_value_info(self.model)
    self.results_of_inference = {}
def get_inner_nodes(self):
"""
generate inner node of ONNX model
"""
inner_nodes = []
if not isinstance(self.model, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')