Description
I'm sorry, but I don't understand how to pass `initial_state` to an LSTM layer in Keras.
For example, I have `init_h` and `init_c`, and I want to pass them to `tf.keras.layers.CuDNNLSTM`.
What should I do?
I defined the model below and pass them to `call()`, but I don't know how to hand them on to the LSTM layer. I need some help. Thanks in advance!
```python
import tensorflow as tf

# NUM_WORDS, EMBEDDING_UNITS, LSTM_UNITS, DENSE1_UNITS, DENSE2_UNITS are defined elsewhere.

class LSTMModel(tf.keras.Model):
    def __init__(self):
        super(LSTMModel, self).__init__()
        self.embedding = tf.keras.layers.Embedding(input_dim=NUM_WORDS,
                                                   output_dim=EMBEDDING_UNITS,
                                                   input_shape=(1,))
        self.lstm = tf.keras.layers.CuDNNLSTM(units=LSTM_UNITS,
                                              return_sequences=True,
                                              return_state=True,
                                              recurrent_initializer='glorot_uniform')
        self.fc1 = tf.keras.layers.Dense(units=DENSE1_UNITS)
        self.fc2 = tf.keras.layers.Dense(units=DENSE2_UNITS)

    def call(self, inputs, context, initial_state=None, training=None, mask=None):
        embedding = self.embedding(inputs)
        context = tf.expand_dims(context, axis=1)
        x = tf.concat([embedding, context], axis=-1)
        # initial_state is received here but not yet passed to self.lstm,
        # so the LSTM always starts from a zero state.
        y, state_h, state_c = self.lstm(x)
        y = self.fc1(y)
        y = tf.reshape(y, (-1, y.shape[2]))
        y = self.fc2(y)
        return y, state_h, state_c

# Called like this (lstm is an LSTMModel instance):
y, h, c = lstm(dec_input, z, initial_state=[init_h, init_c])
```
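A minimal sketch of one way to do this, assuming TensorFlow 2.x: forward the `initial_state` argument received by `call()` to the LSTM layer with `self.lstm(x, initial_state=initial_state)`, where `initial_state` is the list `[init_h, init_c]`, each of shape `(batch, LSTM_UNITS)`. The sketch swaps in `tf.keras.layers.LSTM` for `CuDNNLSTM`, because in TF 2.x `LSTM` uses the cuDNN kernel on a GPU when its default activations are kept, and it also runs on CPU. All layer sizes (including the hypothetical `CONTEXT_UNITS`) are placeholders, not values from the issue. On TF 1.x, `CuDNNLSTM` should accept the same `initial_state=` keyword, but treat that as an assumption.

```python
import tensorflow as tf

# Hypothetical sizes; the issue does not give the real values.
NUM_WORDS = 100
EMBEDDING_UNITS = 16
CONTEXT_UNITS = 8
LSTM_UNITS = 32
DENSE1_UNITS = 64
DENSE2_UNITS = NUM_WORDS


class LSTMModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(NUM_WORDS, EMBEDDING_UNITS)
        # tf.keras.layers.LSTM stands in for CuDNNLSTM in this sketch.
        self.lstm = tf.keras.layers.LSTM(LSTM_UNITS,
                                         return_sequences=True,
                                         return_state=True,
                                         recurrent_initializer='glorot_uniform')
        self.fc1 = tf.keras.layers.Dense(DENSE1_UNITS)
        self.fc2 = tf.keras.layers.Dense(DENSE2_UNITS)

    def call(self, inputs, context, initial_state=None, training=None):
        embedding = self.embedding(inputs)          # (batch, 1, EMBEDDING_UNITS)
        context = tf.expand_dims(context, axis=1)   # (batch, 1, CONTEXT_UNITS)
        x = tf.concat([embedding, context], axis=-1)
        # The key change: hand [h, c] on to the LSTM layer.
        # With initial_state=None the layer starts from zeros, as before.
        y, state_h, state_c = self.lstm(x, initial_state=initial_state)
        y = self.fc1(y)                             # (batch, 1, DENSE1_UNITS)
        y = tf.reshape(y, (-1, y.shape[2]))         # (batch, DENSE1_UNITS)
        y = self.fc2(y)                             # (batch, DENSE2_UNITS)
        return y, state_h, state_c


# Usage with random inputs and an explicit zero initial state.
batch = 4
dec_input = tf.random.uniform((batch, 1), maxval=NUM_WORDS, dtype=tf.int32)
z = tf.random.uniform((batch, CONTEXT_UNITS))
init_h = tf.zeros((batch, LSTM_UNITS))
init_c = tf.zeros((batch, LSTM_UNITS))

model = LSTMModel()
y, h, c = model(dec_input, z, initial_state=[init_h, init_c])
```

Leaving `initial_state` out (or passing `None`) keeps the original behavior of starting from zeros, so existing calls keep working; the returned `h` and `c` can be fed back in as `initial_state` for the next decoding step.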