# This file is a part of OpenCV project.
# It is a subject to the license terms in the LICENSE file found in the top-level directory
# of this distribution and at http://opencv.org/license.html.
#
# Copyright (C) 2020, Intel Corporation, all rights reserved.
# Third party copyrights are property of their respective owners.
#
# Use this script to get the text graph representation (.pbtxt) of EfficientDet
# deep learning network trained in https://github.com/google/automl.
# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
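#
# A minimal usage sketch (script and model file names below are illustrative):
#     python tf_text_graph_efficientdet.py --input efficientdet-d0_frozen.pb --output efficientdet-d0.pbtxt
# and then, from Python:
#     import cv2 as cv
#     net = cv.dnn.readNetFromTensorflow('efficientdet-d0_frozen.pb', 'efficientdet-d0.pbtxt')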
import argparse
import re
from math import sqrt
from tf_text_graph_common import *


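# Generates anchor box widths and heights for each feature pyramid level from the
# anchor configuration (minimum level, aspect ratio pairs, number of scales per
# octave and anchor scale) used when training EfficientDet in google/automl.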
class AnchorGenerator:
    def __init__(self, min_level, aspect_ratios, num_scales, anchor_scale):
        self.min_level = min_level
        self.aspect_ratios = aspect_ratios
        self.anchor_scale = anchor_scale
        self.scales = [2**(float(s) / num_scales) for s in range(num_scales)]

    def get(self, layer_id):
        widths = []
        heights = []
        for s in self.scales:
            for a in self.aspect_ratios:
                base_anchor_size = 2**(self.min_level + layer_id) * self.anchor_scale
                heights.append(base_anchor_size * s * a[1])
                widths.append(base_anchor_size * s * a[0])
        return widths, heights


def createGraph(modelPath, outputPath, min_level, aspect_ratios, num_scales,
                anchor_scale, num_classes, image_width, image_height):
    print('Min level: %d' % min_level)
    print('Anchor scale: %f' % anchor_scale)
    print('Num scales: %d' % num_scales)
    print('Aspect ratios: %s' % str(aspect_ratios))
    print('Number of classes: %d' % num_classes)
    print('Input image size: %dx%d' % (image_width, image_height))

    # Read the graph.
    _inpNames = ['image_arrays']
    outNames = ['detections']

    writeTextGraph(modelPath, outputPath, outNames)
    graph_def = parseTextGraph(outputPath)

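    # Return names of non-Const nodes that no other node consumes
    # (i.e. the current graph outputs); used later to prune dangling branches.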
    def getUnconnectedNodes():
        unconnected = []
        for node in graph_def.node:
            if node.op == 'Const':
                continue
            unconnected.append(node.name)
            for inp in node.input:
                if inp in unconnected:
                    unconnected.remove(inp)
        return unconnected

    nodesToKeep = ['truediv']  # Keep preprocessing nodes

    removeIdentity(graph_def)

    scopesToKeep = ('image_arrays', 'efficientnet', 'resample_p6', 'resample_p7',
                    'fpn_cells', 'class_net', 'box_net', 'Reshape', 'concat')

    addConstNode('scale_w', [2.0], graph_def)
    addConstNode('scale_h', [2.0], graph_def)
    nodesToKeep += ['scale_w', 'scale_h']

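    # Walk the graph and patch nodes so that the exported text graph maps onto
    # layers supported by OpenCV's TensorFlow importer.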
    for node in graph_def.node:
        # Swap the two inputs of the squeeze-and-excitation multiplication.
        if re.match(r'efficientnet-(.*)/blocks_\d+/se/mul_1', node.name):
            node.input[0], node.input[1] = node.input[1], node.input[0]

        # Turn the Reshape used for nearest-neighbor upsampling into an explicit
        # ResizeNearestNeighbor node with a fixed 2x scale in both dimensions.
        if re.match(r'fpn_cells/cell_\d+/fnode\d+/resample(.*)/nearest_upsampling/Reshape_1$', node.name):
            node.op = 'ResizeNearestNeighbor'
            node.input[1] = 'scale_w'
            node.input.append('scale_h')

            for inpNode in graph_def.node:
                if inpNode.name == node.name[:node.name.rfind('_')]:
                    node.input[0] = inpNode.input[0]

        # Flag box-prediction convolutions whose location outputs are transposed.
        if re.match(r'box_net/box-predict(_\d)*/separable_conv2d$', node.name):
            node.addAttr('loc_pred_transposed', True)

        # Replace RealDiv with Mul by the inverse scale for compatibility.
        if node.op == 'RealDiv':
            for inpNode in graph_def.node:
                if inpNode.name != node.input[1] or 'value' not in inpNode.attr:
                    continue

                tensor = inpNode.attr['value']['tensor'][0]
                if 'float_val' not in tensor:
                    continue
                scale = float(tensor['float_val'][0])

                addConstNode(inpNode.name + '/inv', [1.0 / scale], graph_def)
                nodesToKeep.append(inpNode.name + '/inv')
                node.input[1] = inpNode.name + '/inv'
                node.op = 'Mul'
                break

    def to_remove(name, op):
        if name in nodesToKeep:
            return False
        return op == 'Const' or not name.startswith(scopesToKeep)

    removeUnusedNodesAndAttrs(to_remove, graph_def)

    # Attach the unconnected preprocessing nodes to the input.
    assert graph_def.node[1].name == 'truediv' and graph_def.node[1].op == 'RealDiv'
    graph_def.node[1].input.insert(0, 'image_arrays')
    graph_def.node[2].input.insert(0, 'truediv')

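    # Append a PriorBox layer for each of the five feature levels; each one takes the
    # tensor that feeds the corresponding Reshape node and the network input as inputs.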
    priors_generator = AnchorGenerator(min_level, aspect_ratios, num_scales, anchor_scale)
    priorBoxes = []
    for i in range(5):
        inpName = ''
        for node in graph_def.node:
            if node.name == 'Reshape_%d' % (i * 2 + 1):
                inpName = node.input[0]
                break

        priorBox = NodeDef()
        priorBox.name = 'PriorBox_%d' % i
        priorBox.op = 'PriorBox'
        priorBox.input.append(inpName)
        priorBox.input.append(graph_def.node[0].name)  # image_arrays

        priorBox.addAttr('flip', False)
        priorBox.addAttr('clip', False)

        widths, heights = priors_generator.get(i)

        priorBox.addAttr('width', widths)
        priorBox.addAttr('height', heights)
        priorBox.addAttr('variance', [1.0, 1.0, 1.0, 1.0])

        graph_def.node.extend([priorBox])
        priorBoxes.append(priorBox.name)

    addConstNode('concat/axis_flatten', [-1], graph_def)

    def addConcatNode(name, inputs, axisNodeName):
        concat = NodeDef()
        concat.name = name
        concat.op = 'ConcatV2'
        for inp in inputs:
            concat.input.append(inp)
        concat.input.append(axisNodeName)
        graph_def.node.extend([concat])

    addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')

    sigmoid = NodeDef()
    sigmoid.name = 'concat/sigmoid'
    sigmoid.op = 'Sigmoid'
    sigmoid.input.append('concat')
    graph_def.node.extend([sigmoid])

    addFlatten(sigmoid.name, sigmoid.name + '/Flatten', graph_def)
    addFlatten('concat_1', 'concat_1/Flatten', graph_def)

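    # DetectionOutput decodes the box regressions with the prior boxes, applies
    # non-maximum suppression and produces the final detections.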
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out'
    detectionOut.op = 'DetectionOutput'

    detectionOut.input.append('concat_1/Flatten')
    detectionOut.input.append(sigmoid.name + '/Flatten')
    detectionOut.input.append('PriorBox/concat')

    detectionOut.addAttr('num_classes', num_classes)
    detectionOut.addAttr('share_location', True)
    detectionOut.addAttr('background_label_id', num_classes + 1)
    detectionOut.addAttr('nms_threshold', 0.6)
    detectionOut.addAttr('confidence_threshold', 0.2)
    detectionOut.addAttr('top_k', 100)
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    graph_def.node.extend([detectionOut])

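    # Fix the input placeholder shape to NHWC with the requested spatial size.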
    graph_def.node[0].attr['shape'] = {
        'shape': {
            'dim': [
                {'size': -1},
                {'size': image_height},
                {'size': image_width},
                {'size': 3}
            ]
        }
    }

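    # Iteratively remove nodes left without consumers after the graph surgery,
    # keeping only the detection output.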
    while True:
        unconnectedNodes = getUnconnectedNodes()
        unconnectedNodes.remove(detectionOut.name)
        if not unconnectedNodes:
            break

        for name in unconnectedNodes:
            for i in range(len(graph_def.node)):
                if graph_def.node[i].name == name:
                    del graph_def.node[i]
                    break

    # Save as text
    graph_def.save(outputPath)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                                 'an EfficientDet model trained with https://github.com/google/automl. '
                                                 'Then pass it together with the frozen .pb graph to the '
                                                 'cv::dnn::readNetFromTensorflow function.')
    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
    parser.add_argument('--output', required=True, help='Path to output text graph.')
    parser.add_argument('--min_level', default=3, type=int, help='Parameter from training config')
    parser.add_argument('--num_scales', default=3, type=int, help='Parameter from training config')
    parser.add_argument('--anchor_scale', default=4.0, type=float, help='Parameter from training config')
    parser.add_argument('--aspect_ratios', default=[1.0, 1.0, 1.4, 0.7, 0.7, 1.4],
                        nargs='+', type=float, help='Parameter from training config')
    parser.add_argument('--num_classes', default=90, type=int, help='Number of classes to detect')
    parser.add_argument('--width', default=512, type=int, help='Network input width')
    parser.add_argument('--height', default=512, type=int, help='Network input height')
    args = parser.parse_args()

    # Aspect ratios are given as a flat list of (width, height) factor pairs.
    ar = args.aspect_ratios
    assert len(ar) % 2 == 0, 'Aspect ratios must come in pairs'
    ar = list(zip(ar[::2], ar[1::2]))

    createGraph(args.input, args.output, args.min_level, ar, args.num_scales,
                args.anchor_scale, args.num_classes, args.width, args.height)