IndRNNs #7

@OverLordGoldDragon

Description

I'd like to suggest support for IndRNNs; in my experiments on EEG seizure classification with very long sequences, they've dominated LSTMs & GRUs. While already much faster, IndRNNs would benefit further from a CuDNN-like speedup in large stacks, and from Layer Normalization when working with 1000+ timesteps.

Minimal tf.keras code below; default weight initialization should be handled differently - can clarify post-approval.


IndRNN Cell
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin



@keras_export(v1=['keras.layers.IndRNNCell'])
class IndRNNCell(DropoutRNNCellMixin, Layer):
  def __init__(self,
               units,
               activation='tanh',
               use_bias=True,
               recurrent_clip_min=-1,  # -1: resolve clip bounds automatically in build
               recurrent_clip_max=-1,
               kernel_initializer='glorot_normal',
               recurrent_initializer=None,
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               **kwargs):
    super(IndRNNCell, self).__init__(**kwargs)
    
    # passing None for either bound disables recurrent clipping entirely
    if recurrent_clip_min is None or recurrent_clip_max is None:
        recurrent_clip_min = None
        recurrent_clip_max = None

    self.units = units
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.recurrent_clip_min = recurrent_clip_min
    self.recurrent_clip_max = recurrent_clip_max

    self.kernel_initializer = initializers.get(kernel_initializer)
    if recurrent_initializer is None:
        # default per the IndRNN paper: recurrent weights ~ U(-1, 1)
        self.recurrent_initializer = initializers.RandomUniform(-1.0, 1.0)
    else:
        self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))

    self.state_size = data_structures.NoDependency([self.units])
    self.output_size = self.units

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    input_dim = input_shape[-1]
    # NOTE: assumes the wrapping layer passes the full (batch, timesteps,
    # channels) shape; the stock `tf.keras.layers.RNN` wrapper passes only
    # (batch, input_dim), so `timesteps` may need to be supplied explicitly.
    self.timesteps = input_shape[1]
    self._process_recurrent_clip()  # resolve recurrent clipping defaults (sketch below)

    self.kernel = self.add_weight(
        shape=(input_dim, self.units),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units,),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.units,),
          name='bias',
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None    
    self.built = True
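
  def _process_recurrent_clip(self):
    # Hedged sketch: this helper is referenced in `build` but wasn't included
    # in the snippet. A plausible version, following the IndRNN paper's recipe
    # of bounding |recurrent_kernel| by roughly 2 ** (1 / timesteps) so the
    # recurrent gain stays controlled over the full sequence:
    if self.recurrent_clip_min == -1 or self.recurrent_clip_max == -1:
      # -1 means "auto": derive the bounds from the sequence length
      self.recurrent_clip_min = 0.
      self.recurrent_clip_max = float(2. ** (1. / max(1, self.timesteps or 1)))
    if self.recurrent_clip_min is not None:
      # apply the clip element-wise via a weight constraint (constraints are
      # just callables mapping a weight tensor to its constrained value)
      cmin, cmax = self.recurrent_clip_min, self.recurrent_clip_max
      self.recurrent_constraint = lambda w: K.clip(w, cmin, cmax)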


  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous hidden state

    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=1)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=1)

    if 0. < self.dropout < 1.:
      inputs = inputs * dp_mask[0]
    if 0. < self.recurrent_dropout < 1.:
      h_tm1 = h_tm1 * rec_dp_mask[0]
    
    # IndRNN update: h_t = act(W x_t + u * h_{t-1} + b), where the recurrent
    # weight u is a vector applied element-wise (one scalar per unit)
    h = K.dot(inputs, self.kernel)
    h += math_ops.multiply(h_tm1, self.recurrent_kernel)
    if self.use_bias:
      h = K.bias_add(h, self.bias)

    h = self.activation(h)
    return h, [h]
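
Usage sketch (hedged): the cell wrapped in the stock tf.keras.layers.RNN wrapper. Shapes, unit counts, and the dummy data are placeholders, and this assumes a TF version where tensorflow.python.keras backs tf.keras, with IndRNNCell from above in scope.

import numpy as np
import tensorflow as tf

timesteps, channels = 1000, 16  # placeholder shapes

# Two stacked IndRNN layers followed by a binary classification head
model = tf.keras.Sequential([
    tf.keras.layers.RNN(IndRNNCell(128), return_sequences=True,
                        input_shape=(timesteps, channels)),
    tf.keras.layers.RNN(IndRNNCell(128)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile('adam', 'binary_crossentropy')

# random dummy batch, just to confirm the cell trains end-to-end
x = np.random.randn(8, timesteps, channels).astype('float32')
y = np.random.randint(0, 2, (8, 1))
model.fit(x, y, epochs=1, verbose=0)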
