Skip to content

Commit c468259

Browse files
committed
[green] remove _parse_bn_input in style_tune, use ilit internal API
1 parent e3f9389 commit c468259

File tree

1 file changed

+1
-33
lines changed

1 file changed

+1
-33
lines changed

examples/tensorflow/style_transfer/style_tune.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from PIL import Image
2727
import time
2828
import ilit
29+
from ilit.adaptor.tf_utils.util import _parse_ckpt_bn_input
2930

3031
flags = tf.flags
3132
flags.DEFINE_string('style_images_paths', None, 'Paths to the style images'
@@ -77,39 +78,6 @@ def image_style_transfer(sess, content_img_path, style_img_path):
7778
# saves stylized image.
7879
save_image(stylized_image_res, os.path.join(FLAGS.output_dir, 'stylized_image.jpg'))
7980

80-
def _parse_ckpt_bn_input(graph_def):
81-
"""parse ckpt batch norm inputs to match correct moving mean and variance
82-
Args:
83-
graph_def (graph_def): original graph_def
84-
Returns:
85-
graph_def: well linked graph_def
86-
"""
87-
for node in graph_def.node:
88-
if node.op == 'FusedBatchNorm':
89-
moving_mean_op_name = node.input[3]
90-
moving_var_op_name = node.input[4]
91-
moving_mean_op = _get_nodes_from_name(moving_mean_op_name, graph_def)[0]
92-
moving_var_op = _get_nodes_from_name(moving_var_op_name, graph_def)[0]
93-
94-
if moving_mean_op.op == 'Const':
95-
name_part = moving_mean_op_name.rsplit('/', 1)[0]
96-
real_moving_mean_op_name = name_part + '/moving_mean'
97-
if len(_get_nodes_from_name(real_moving_mean_op_name, graph_def)) > 0:
98-
# replace the real moving mean op name
99-
node.input[3] = real_moving_mean_op_name
100-
101-
if moving_var_op.op == 'Const':
102-
name_part = moving_var_op_name.rsplit('/', 1)[0]
103-
real_moving_var_op_name = name_part + '/moving_variance'
104-
if len(_get_nodes_from_name(real_moving_var_op_name, graph_def)) > 0:
105-
# replace the real moving mean op name
106-
node.input[4] = real_moving_var_op_name
107-
108-
return graph_def
109-
110-
def _get_nodes_from_name(node_name, graph_def):
111-
return [node for node in graph_def.node if node.name==node_name]
112-
11381
def main(args=None):
11482
tf.logging.set_verbosity(tf.logging.INFO)
11583
if not tf.gfile.Exists(FLAGS.output_dir):

0 commit comments

Comments
 (0)