In the new paper, Google uses a filter width of 3 to increase the receptive field.
How can we then run inference (the queue-based incremental generator) with a filter width of 3?
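For reference, a causal layer with filter width K and dilation d looks back (K - 1) * d samples, so a stack with dilations d_1, ..., d_n has a receptive field of (K - 1) * (d_1 + ... + d_n) + 1 samples. The sketch below only evaluates that formula. The dilation list 1, 2, ..., 512 is an assumed example configuration, not a value taken from the paper.

```python
def receptive_field(filter_width, dilations):
    """Number of input samples one output of a causal dilated stack depends on.

    Each layer with dilation d reaches (filter_width - 1) * d samples further
    into the past; the current sample adds 1.
    """
    return (filter_width - 1) * sum(dilations) + 1


# Assumed example configuration: one block of dilations 1, 2, 4, ..., 512.
dilations = [2 ** i for i in range(10)]

print(receptive_field(2, dilations))  # 1024
print(receptive_field(3, dilations))  # 2047
```

With the same dilations, width 3 roughly doubles the receptive field (1024 to 2047 samples here), but every layer now depends on two past inputs, from d and 2d steps back, instead of one. That is why the single queue per layer used for width 2 is no longer enough.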
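My idea is to use two queues per layer, each of length `dilation`. Since the dilations still double from layer to layer, the first queue stores the first half of the intermediate values (the layer inputs from the last `dilation` steps, so the value it dequeues is the input from `dilation` steps back), and the second queue stores the second half (the `dilation` inputs before those, so the value it dequeues is the input from `2 * dilation` steps back).
The output of the first queue is then enqueued into the second queue.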
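The two-queue idea can be checked without TensorFlow. The sketch below uses Python's `collections.deque` as a stand-in for `tf.FIFOQueue` (an assumption made only for illustration): both queues have length `dilation` and start full of zeros, the first is fed the current input, and the second is fed whatever leaves the first. The value leaving the first queue is then x[t - d] and the value leaving the second is x[t - 2d], with zeros before the sequence starts.

```python
from collections import deque


def delayed_pairs(xs, dilation):
    """For each step t, return (x[t - dilation], x[t - 2 * dilation]).

    Values before the start of the sequence are 0.0, matching queues that are
    pre-filled with `dilation` zeros (like init / init2 in the code below).
    """
    q = deque([0.0] * dilation)   # first queue
    q2 = deque([0.0] * dilation)  # second queue, fed by the first
    pairs = []
    for x in xs:
        current_state = q.popleft()  # x[t - d]
        q.append(x)                  # push the current input
        pre_state = q2.popleft()     # x[t - 2d]
        q2.append(current_state)     # output of q is enqueued into q2
        pairs.append((current_state, pre_state))
    return pairs


print(delayed_pairs([1.0, 2.0, 3.0, 4.0, 5.0], dilation=2))
# [(0.0, 0.0), (0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (3.0, 1.0)]
```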
For example (the causal layer first, then the dilated stack):
```python
# q and init are created just above, as in the width-2 generator (not shown).
current_state = q.dequeue()
push = q.enqueue([current_layer])
init_ops.append(init)
push_ops.append(push)

pre_state = None
if self.filter_width == 3:
    # Second queue, fed with the output of q: it delays that value by one
    # more step, so pre_state is the input from two steps back.
    q2 = tf.FIFOQueue(
        1,
        dtypes=tf.float32,
        shapes=(self.batch_size, self.quantization_channels))
    init2 = q2.enqueue_many(
        tf.zeros((1, self.batch_size, self.quantization_channels)))
    pre_state = q2.dequeue()
    push2 = q2.enqueue([current_state])
    # init_ops2 / push_ops2 must be run alongside init_ops / push_ops.
    init_ops2.append(init2)
    push_ops2.append(push2)

if self.filter_width == 2:
    current_layer = self._generator_causal_layer(
        current_layer, current_state)
elif self.filter_width == 3:
    current_layer = self._generator_causal_layer(
        current_layer, current_state, pre_state)
```
...
```python
with tf.name_scope('dilated_stack'):
    for layer_index, dilation in enumerate(self.dilations):
        with tf.name_scope('layer{}'.format(layer_index)):
            # First queue: dequeues the layer input from `dilation` steps ago.
            q = tf.FIFOQueue(
                dilation,
                dtypes=tf.float32,
                shapes=(self.batch_size, self.residual_channels))
            init = q.enqueue_many(
                tf.zeros((dilation, self.batch_size,
                          self.residual_channels)))
            current_state = q.dequeue()
            push = q.enqueue([current_layer])
            init_ops.append(init)
            push_ops.append(push)

            pre_state = None
            if self.filter_width == 3:
                # Second queue, fed with the output of the first one:
                # dequeues the layer input from 2 * dilation steps ago.
                q2 = tf.FIFOQueue(
                    dilation,
                    dtypes=tf.float32,
                    shapes=(self.batch_size, self.residual_channels))
                init2 = q2.enqueue_many(
                    tf.zeros((dilation, self.batch_size,
                              self.residual_channels)))
                pre_state = q2.dequeue()
                push2 = q2.enqueue([current_state])
                init_ops2.append(init2)
                push_ops2.append(push2)

            output, current_layer = self._generator_dilation_layer(
                current_layer, current_state, layer_index, dilation,
                global_condition_batch, local_condition, pre_state)
            outputs.append(output)
```
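One way to check the whole scheme is to compare it with a direct, non-incremental computation. The sketch below is a simplified stand-in, not the tensorflow-wavenet code: scalar channels, a purely linear layer with no gates, skip outputs or conditioning, and made-up weights. It assumes the filter taps are ordered oldest-first, as in a left-padded causal convolution, so `w[0]` multiplies the input from `2 * dilation` steps back (`pre_state`), `w[1]` the input from `dilation` steps back (`current_state`) and `w[2]` the current input. Each layer keeps two zero-initialised queues of length `dilation`, and the output of the first is pushed into the second, as proposed above.

```python
from collections import deque


def causal_dilated_conv(xs, w, dilation):
    """Direct width-3 causal dilated convolution with zero left padding."""
    def at(i):
        return xs[i] if i >= 0 else 0.0
    return [w[0] * at(t - 2 * dilation) + w[1] * at(t - dilation) + w[2] * at(t)
            for t in range(len(xs))]


def full_stack(xs, weights, dilations):
    """Run every layer over the whole sequence (training-style)."""
    for w, d in zip(weights, dilations):
        xs = causal_dilated_conv(xs, w, d)
    return xs


def incremental_stack(xs, weights, dilations):
    """Generate one sample at a time using two queues per layer."""
    queues = [(deque([0.0] * d), deque([0.0] * d)) for d in dilations]
    outputs = []
    for x in xs:
        current_layer = x
        for w, (q, q2) in zip(weights, queues):
            current_state = q.popleft()  # layer input from d steps back
            q.append(current_layer)      # push this layer's current input
            pre_state = q2.popleft()     # layer input from 2d steps back
            q2.append(current_state)     # output of q feeds q2
            current_layer = (w[0] * pre_state + w[1] * current_state
                             + w[2] * current_layer)
        outputs.append(current_layer)
    return outputs


# Made-up weights; the first dilation-1 layer plays the role of the causal layer.
dilations = [1, 1, 2, 4]
weights = [(0.5, -1.0, 2.0), (0.25, 0.5, 1.0), (-0.5, 1.5, 0.75), (1.0, 0.5, -0.25)]
xs = [float((7 * t) % 5 - 2) for t in range(32)]

fast = incremental_stack(xs, weights, dilations)
slow = full_stack(xs, weights, dilations)
print(all(abs(a - b) < 1e-9 for a, b in zip(fast, slow)))  # True
```

In the TensorFlow version this suggests the width-3 `_generator_causal_layer` and `_generator_dilation_layer` need a third weight term applied to `pre_state`, and `init_ops2` / `push_ops2` have to be run together with `init_ops` / `push_ops`.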
Does that make sense?