|
26 | 26 | from PIL import Image
|
27 | 27 | import time
|
28 | 28 | import ilit
|
| 29 | +from ilit.adaptor.tf_utils.util import _parse_ckpt_bn_input |
29 | 30 |
|
30 | 31 | flags = tf.flags
|
31 | 32 | flags.DEFINE_string('style_images_paths', None, 'Paths to the style images'
|
@@ -77,39 +78,6 @@ def image_style_transfer(sess, content_img_path, style_img_path):
|
77 | 78 | # saves stylized image.
|
78 | 79 | save_image(stylized_image_res, os.path.join(FLAGS.output_dir, 'stylized_image.jpg'))
|
79 | 80 |
|
80 |
| -def _parse_ckpt_bn_input(graph_def): |
81 |
| - """parse ckpt batch norm inputs to match correct moving mean and variance |
82 |
| - Args: |
83 |
| - graph_def (graph_def): original graph_def |
84 |
| - Returns: |
85 |
| - graph_def: well linked graph_def |
86 |
| - """ |
87 |
| - for node in graph_def.node: |
88 |
| - if node.op == 'FusedBatchNorm': |
89 |
| - moving_mean_op_name = node.input[3] |
90 |
| - moving_var_op_name = node.input[4] |
91 |
| - moving_mean_op = _get_nodes_from_name(moving_mean_op_name, graph_def)[0] |
92 |
| - moving_var_op = _get_nodes_from_name(moving_var_op_name, graph_def)[0] |
93 |
| - |
94 |
| - if moving_mean_op.op == 'Const': |
95 |
| - name_part = moving_mean_op_name.rsplit('/', 1)[0] |
96 |
| - real_moving_mean_op_name = name_part + '/moving_mean' |
97 |
| - if len(_get_nodes_from_name(real_moving_mean_op_name, graph_def)) > 0: |
98 |
| - # replace the real moving mean op name |
99 |
| - node.input[3] = real_moving_mean_op_name |
100 |
| - |
101 |
| - if moving_var_op.op == 'Const': |
102 |
| - name_part = moving_var_op_name.rsplit('/', 1)[0] |
103 |
| - real_moving_var_op_name = name_part + '/moving_variance' |
104 |
| - if len(_get_nodes_from_name(real_moving_var_op_name, graph_def)) > 0: |
105 |
| - # replace the real moving mean op name |
106 |
| - node.input[4] = real_moving_var_op_name |
107 |
| - |
108 |
| - return graph_def |
109 |
| - |
110 |
| -def _get_nodes_from_name(node_name, graph_def): |
111 |
| - return [node for node in graph_def.node if node.name==node_name] |
112 |
| - |
113 | 81 | def main(args=None):
|
114 | 82 | tf.logging.set_verbosity(tf.logging.INFO)
|
115 | 83 | if not tf.gfile.Exists(FLAGS.output_dir):
|
|
0 commit comments