
Commit a818733

drasmuss authored and tbekolay committed
Swap order of LMU states
Having the memory state come first is more intuitive, as it is both always present and comes first in the computational flow.
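The reordering can be illustrated with a minimal sketch (plain Python lists stand in for the actual TensorFlow tensors; `keras_lmu` itself is not required):

```python
# Sketch of the state reordering introduced by this commit.
# Before: states = hidden_states + [memory_state]
# After:  states = [memory_state] + hidden_states

memory_state = "m"     # LMU memory state (always present)
hidden_states = ["h"]  # states from the hidden cell (may be empty)

# New packing: memory state first, hidden cell states after.
states = [memory_state] + hidden_states

# Unpacking mirrors LMUCell.call after this change:
m = states[0]   # state for the LMU memory
h = states[1:]  # state(s) for the hidden cell

assert m == "m" and h == ["h"]
```

Because the memory state is always present, indexing it at a fixed position (`states[0]`) is robust even when the hidden cell contributes no states.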
1 parent 41fdc58 commit a818733

File tree

4 files changed: +13 −11 lines

.nengobones.yml (2 additions, 2 deletions)

```diff
@@ -108,6 +108,6 @@ pyproject_toml: {}
 version_py:
   type: semver
   major: 0
-  minor: 5
-  patch: 1
+  minor: 6
+  patch: 0
   release: false
```

CHANGES.rst (3 additions, 1 deletion)

```diff
@@ -19,7 +19,7 @@ Release history
 - Removed
 - Fixed

-0.5.1 (unreleased)
+0.6.0 (unreleased)
 ==================

 *Compatible with TensorFlow 2.4 - 2.11*
@@ -31,6 +31,8 @@ Release history
   are met, as before). (`#52`_)
 - Allow ``input_to_hidden=True`` with ``hidden_cell=None``. This will act as a skip
   connection. (`#52`_)
+- Changed order of LMU states so that the LMU memory state always comes first, and
+  any states from the hidden cell come afterwards. (`#52`_)

 **Fixed**

```

keras_lmu/layers.py (7 additions, 7 deletions)

```diff
@@ -171,9 +171,9 @@ def __init__(
             self.hidden_output_size = self.hidden_cell.units
             self.hidden_state_size = [self.hidden_cell.units]

-        self.state_size = tf.nest.flatten(self.hidden_state_size) + [
-            self.memory_d * self.order
-        ]
+        self.state_size = [self.memory_d * self.order] + tf.nest.flatten(
+            self.hidden_state_size
+        )
         self.output_size = self.hidden_output_size

     @property
@@ -329,10 +329,10 @@ def call(self, inputs, states, training=None):  # noqa: C901

         states = tf.nest.flatten(states)

-        # state for the hidden cell
-        h = states[:-1]
         # state for the LMU memory
-        m = states[-1]
+        m = states[0]
+        # state for the hidden cell
+        h = states[1:]

         # compute memory input
         u = (
@@ -403,7 +403,7 @@ def call(self, inputs, states, training=None):  # noqa: C901
             o = self.hidden_cell(h_in, training=training)
             h = [o]

-        return o, h + [m]
+        return o, [m] + h

     def reset_dropout_mask(self):
         """Reset dropout mask for memory and hidden components."""
```
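The new `state_size` packing in keras_lmu/layers.py can be sketched with hypothetical values (`memory_d`, `order`, and the hidden cell's units are made up for illustration; `tf.nest.flatten` on an already-flat list is equivalent to `list`, so TensorFlow is not needed here):

```python
# Sketch of the new state_size computation, with assumed example values.
memory_d, order = 4, 8     # hypothetical LMU memory dimensions
hidden_state_size = [10]   # e.g. a hidden cell with 10 units

# Memory state size comes first, hidden state sizes follow
# (mirrors: [self.memory_d * self.order] + tf.nest.flatten(self.hidden_state_size)).
state_size = [memory_d * order] + list(hidden_state_size)

assert state_size == [32, 10]
```

With `hidden_cell=None` the hidden contribution is empty, and the memory size is still reliably at index 0.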

keras_lmu/version.py (1 addition, 1 deletion)

```diff
@@ -11,7 +11,7 @@
 tagged with the version.
 """

-version_info = (0, 5, 1)
+version_info = (0, 6, 0)

 name = "keras-lmu"
 dev = 0
```
