
If the filter width is 3, how can fast inference be done? #19

@weixsong

In the new paper, Google uses a filter width of 3 to increase the receptive field.

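How, then, can we do fast inference with a filter width of 3?
My idea is to use two queues per layer. The dilation still doubles from one layer to the next and only the filter width changes, so a layer with dilation d now needs the values at t, t - d and t - 2d. The first queue (length d) holds the intermediate values from t - d to t - 1, and the second queue (also length d) holds those from t - 2d to t - d - 1.
Whatever is dequeued from the first queue is then enqueued into the second queue.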
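The two-queue idea can be sketched without TensorFlow. In this minimal sketch, Python's `collections.deque` stands in for `tf.FIFOQueue`, samples are plain floats, and `DelayLine` is a hypothetical name. Both queues hold `d` values and start filled with zeros, and whatever leaves the first queue is pushed into the second, so at step `t` the two dequeued values should be `x[t - d]` and `x[t - 2d]`:

```python
from collections import deque


class DelayLine:
    """Hypothetical helper: two chained FIFO queues of length `dilation`,
    mimicking the proposed pair of tf.FIFOQueue objects for one layer."""

    def __init__(self, dilation):
        # Both queues start with `dilation` zeros, like the
        # enqueue_many(tf.zeros(...)) init ops.
        self.q1 = deque([0.0] * dilation)
        self.q2 = deque([0.0] * dilation)

    def step(self, x_t):
        """Push the current input and return (x[t-2d], x[t-d], x[t])."""
        current_state = self.q1.popleft()   # x[t - d]
        self.q1.append(x_t)
        pre_state = self.q2.popleft()       # x[t - 2d]
        self.q2.append(current_state)       # first queue's output feeds the second
        return pre_state, current_state, x_t


if __name__ == "__main__":
    line = DelayLine(2)
    for t, x in enumerate([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]):
        print(t, line.step(x))
```

Samples before the start of the sequence come out as zeros, which matches zero padding on the left of a causal convolution.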

For example, for the initial causal layer, whose queue has length 1:

        current_state = q.dequeue()           # x[t - 1]
        push = q.enqueue([current_layer])
        init_ops.append(init)
        push_ops.append(push)

        pre_state = None
        if self.filter_width == 3:
            # Second queue, fed with whatever leaves the first queue,
            # so pre_state is x[t - 2].
            q2 = tf.FIFOQueue(
                1,
                dtypes=tf.float32,
                shapes=(self.batch_size, self.quantization_channels))
            init2 = q2.enqueue_many(
                tf.zeros((1, self.batch_size, self.quantization_channels)))

            pre_state = q2.dequeue()
            push2 = q2.enqueue([current_state])

            # init_ops2 / push_ops2 must also be run, alongside
            # init_ops / push_ops, or q2 is never filled or advanced.
            init_ops2.append(init2)
            push_ops2.append(push2)

        if self.filter_width == 2:
            current_layer = self._generator_causal_layer(
                current_layer, current_state)
        elif self.filter_width == 3:
            current_layer = self._generator_causal_layer(
                current_layer, current_state, pre_state)
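With a third input, `_generator_causal_layer` also needs a third kernel tap. The sketch below assumes the taps are ordered oldest to newest, as in a left-padded causal conv1d: `weights[0]` for `x[t-2]`, `weights[1]` for `x[t-1]` and `weights[2]` for `x[t]`. The helper names are hypothetical and plain lists stand in for tensors; the step-by-step loop is checked against a full zero-padded causal convolution:

```python
def matvec(v, m):
    """Multiply row vector v (length C_in) by matrix m (C_in x C_out)."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


def generator_conv3(input_vec, state_vec, pre_state_vec, weights):
    """Hypothetical single step of a width-3 causal convolution.

    Assumed tap order: weights[0] -> x[t-2] (pre_state),
    weights[1] -> x[t-1] (current_state), weights[2] -> x[t] (new input).
    """
    terms = (matvec(pre_state_vec, weights[0]),
             matvec(state_vec, weights[1]),
             matvec(input_vec, weights[2]))
    return [a + b + c for a, b, c in zip(*terms)]


def causal_conv3_full(xs, weights):
    """Reference: y[t] = sum_k x[t - (2 - k)] @ weights[k], zero-padded on the left."""
    c_out = len(weights[0][0])
    ys = []
    for t in range(len(xs)):
        y = [0.0] * c_out
        for k in range(3):
            src = t - (2 - k)        # tap k looks back (2 - k) steps
            if src >= 0:             # earlier samples count as zeros
                y = [a + b for a, b in zip(y, matvec(xs[src], weights[k]))]
        ys.append(y)
    return ys


def causal_conv3_incremental(xs, weights):
    """Generation-style loop: two length-1 'queues' hold x[t-1] and x[t-2]."""
    zero = [0.0] * len(weights[0])
    current_state, pre_state = zero, zero   # like the zero-filled init ops
    ys = []
    for x in xs:
        ys.append(generator_conv3(x, current_state, pre_state, weights))
        # The old current_state moves into the second queue.
        pre_state, current_state = current_state, x
    return ys
```

If the tap order used at generation time differs from the one used by the training-time convolution, the queues can be correct and the generated samples will still not match the trained model.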
 


...
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):

                    q = tf.FIFOQueue(
                        dilation,
                        dtypes=tf.float32,
                        shapes=(self.batch_size, self.residual_channels))
                    init = q.enqueue_many(
                        tf.zeros((dilation, self.batch_size,
                                  self.residual_channels)))

                    current_state = q.dequeue()       # h[t - dilation]
                    push = q.enqueue([current_layer])
                    init_ops.append(init)
                    push_ops.append(push)

                    pre_state = None
                    if self.filter_width == 3:
                        # Second queue of the same length, fed by the first
                        # queue's output, so pre_state is h[t - 2 * dilation].
                        q2 = tf.FIFOQueue(
                            dilation,
                            dtypes=tf.float32,
                            shapes=(self.batch_size, self.residual_channels))
                        init2 = q2.enqueue_many(
                            tf.zeros((dilation, self.batch_size,
                                      self.residual_channels)))

                        pre_state = q2.dequeue()
                        push2 = q2.enqueue([current_state])

                        init_ops2.append(init2)
                        push_ops2.append(push2)

                    output, current_layer = self._generator_dilation_layer(
                        current_layer, current_state, layer_index, dilation,
                        global_condition_batch, local_condition, pre_state)
                    outputs.append(output)
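One way to check the per-layer queue pair is a small pure-Python simulation that compares sample-by-sample generation with a full-sequence computation. This is a simplified sketch, not the project's model: it assumes one channel, dilations that double (1, 2, 4, 8), and a hypothetical residual layer `h + tanh(w0*h[t-2d] + w1*h[t-d] + w2*h[t])` in place of WaveNet's gated activation and skip outputs:

```python
import math
from collections import deque


def layer_full(h, w, dilation):
    """Reference: one hypothetical width-3 dilated causal layer over a whole sequence."""
    get = lambda t: h[t] if t >= 0 else 0.0   # zero padding on the left
    return [h[t] + math.tanh(w[0] * get(t - 2 * dilation)
                             + w[1] * get(t - dilation)
                             + w[2] * h[t])
            for t in range(len(h))]


def stack_full(xs, layer_weights, dilations):
    h = list(xs)
    for w, d in zip(layer_weights, dilations):
        h = layer_full(h, w, d)
    return h


def stack_incremental(xs, layer_weights, dilations):
    """Sample-by-sample generation with two zero-initialised queues per layer."""
    q1s = [deque([0.0] * d) for d in dilations]
    q2s = [deque([0.0] * d) for d in dilations]
    outputs = []
    for x in xs:
        current_layer = x
        for w, q1, q2 in zip(layer_weights, q1s, q2s):
            current_state = q1.popleft()      # h[t - d]
            q1.append(current_layer)
            pre_state = q2.popleft()          # h[t - 2d]
            q2.append(current_state)          # first queue feeds the second
            current_layer = current_layer + math.tanh(
                w[0] * pre_state + w[1] * current_state + w[2] * current_layer)
        outputs.append(current_layer)
    return outputs
```

If the two functions agree on a random input, the queue bookkeeping (lengths, initial zeros, and the order of dequeue and enqueue) is consistent with a width-3 dilated causal convolution; the real layer internals are a separate question.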

Does that make sense?
