I'd like to suggest support for IndRNNs; in my experiments on EEG seizure classification w/ very long sequences, they've dominated LSTMs & GRUs. While already much faster, IndRNNs would further benefit from a CuDNN-like speedup in large stacks, and from Layer Normalization when working w/ 1000+ timesteps.
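For context, the defining change from a vanilla RNN is that the recurrent weight is a per-unit vector applied element-wise, rather than a full units x units matrix, so each unit has an independent recurrence. A minimal NumPy sketch of one step (names and shapes are mine, not from any library):

import numpy as np

def indrnn_step(x_t, h_tm1, W, u, b):
  # x_t: (batch, input_dim), h_tm1: (batch, units)
  # W: (input_dim, units) input kernel, same as in a simple RNN
  # u: (units,) independent recurrent weights -- the Hadamard product
  #    with h_tm1 replaces the (units, units) recurrent matmul
  return np.maximum(x_t @ W + u * h_tm1 + b, 0.)  # relu activation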
Minimal tf.keras code below; default weight initialization should be handled differently - can clarify post-approval.
IndRNN Cell
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
@keras_export(v1=['keras.layers.IndRNNCell'])
class IndRNNCell(DropoutRNNCellMixin, Layer):
  def __init__(self,
               units,
               activation='tanh',
               use_bias=True,
               recurrent_clip_min=-1,
               recurrent_clip_max=-1,
               kernel_initializer='glorot_normal',
               recurrent_initializer=None,
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               **kwargs):
    super(IndRNNCell, self).__init__(**kwargs)
    # If either clip bound is disabled, disable both.
    if recurrent_clip_min is None or recurrent_clip_max is None:
      recurrent_clip_min = None
      recurrent_clip_max = None

    self.units = units
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.recurrent_clip_min = recurrent_clip_min
    self.recurrent_clip_max = recurrent_clip_max

    self.kernel_initializer = initializers.get(kernel_initializer)
    # Default recurrent init: uniform in [-1, 1].
    if recurrent_initializer is None:
      self.recurrent_initializer = initializers.RandomUniform(
          minval=-1.0, maxval=1.0)
    else:
      self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    self.state_size = data_structures.NoDependency([self.units])
    self.output_size = self.units
  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    input_dim = input_shape[-1]
    # NOTE: when wrapped in keras.layers.RNN, the cell's build() receives
    # the per-timestep input shape, so the sequence length may not be
    # recoverable here; the recurrent clipping depends on it.
    self.timesteps = input_shape[1] if len(input_shape) > 2 else None
    self._process_recurrent_clip()

    self.kernel = self.add_weight(
        shape=(input_dim, self.units),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)
    # The recurrent kernel is a vector: one independent recurrent
    # weight per unit.
    self.recurrent_kernel = self.add_weight(
        shape=(self.units,),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.units,),
          name='bias',
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True
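  def _process_recurrent_clip(self):
    # Hypothetical sketch of this elided helper -- an assumption, not the
    # author's implementation. Treats the -1 default as "derive the bound
    # from sequence length": the IndRNN paper keeps |u| <= 2**(1/T) for
    # relu so the recurrent gradient stays bounded over T timesteps.
    if self.recurrent_clip_min == -1 or self.recurrent_clip_max == -1:
      T = self.timesteps if self.timesteps else 1
      bound = 2. ** (1. / T)
      self.recurrent_clip_min = -bound
      self.recurrent_clip_max = bound
    # The bounds would then be enforced on recurrent_kernel, e.g. via a
    # clipping Constraint, or K.clip inside call().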
  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=1)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=1)

    if 0. < self.dropout < 1.:
      inputs = inputs * dp_mask[0]
    if 0. < self.recurrent_dropout < 1.:
      h_tm1 = h_tm1 * rec_dp_mask[0]

    h = K.dot(inputs, self.kernel)
    # Element-wise (not matrix) recurrence: the defining IndRNN operation.
    h += math_ops.multiply(h_tm1, self.recurrent_kernel)
    if self.use_bias:
      h = K.bias_add(h, self.bias)
    h = self.activation(h)
    return h, [h]
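A minimal usage sketch, wrapping the cell in the stock keras.layers.RNN wrapper; shapes, layer sizes, and the toy binary-classification setup are placeholders, not from the proposal:

import numpy as np
import tensorflow as tf

# Two stacked IndRNN layers over 1000-timestep sequences of 64 features.
model = tf.keras.Sequential([
    tf.keras.layers.RNN(IndRNNCell(128), return_sequences=True,
                        input_shape=(1000, 64)),
    tf.keras.layers.RNN(IndRNNCell(128)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile('adam', 'binary_crossentropy')

# Random stand-in data just to exercise the layer.
x = np.random.randn(32, 1000, 64).astype('float32')
y = np.random.randint(0, 2, size=(32, 1))
model.fit(x, y, epochs=1)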