Issue type
- Bug Report
- Feature Request
- Help wanted
- Other
SpikingJelly version
0.0.0.0.2
Description
Hi, I am trying to implement RLIF neurons in the last layer of an SNN, and I also want to record the membrane potential and spiking activity of the network. But when I wrap the last layer of the SNN with LinearRecurrentContainer, the membrane potential monitor returns a list of None, and the spike activity monitor returns the spike records, but in an incorrect format.
I have provided two code snippets: one for the normal implementation (without the container) and one that reproduces the bug.
Code Snippet for normal implementation
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron, monitor

net = nn.Sequential(
    nn.Linear(in_features=8, out_features=4),
    neuron.LIFNode(),
    torch.nn.Linear(in_features=4, out_features=2),
    neuron.LIFNode()
)
functional.set_step_mode(net, step_mode='m')

spike_seq_monitor = monitor.OutputMonitor(net, neuron.LIFNode)
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.LIFNode)
for m in net.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True

T = 4
B = 1
N = 8
x_seq = torch.rand([T, B, N]) + 1.0
with torch.no_grad():
    net(x_seq)
print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
Output:

Code Snippet for bug
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron, monitor

net = nn.Sequential(
    nn.Linear(in_features=8, out_features=4),
    neuron.LIFNode(),
    layer.LinearRecurrentContainer(
        nn.Sequential(
            torch.nn.Linear(in_features=4, out_features=2),
            neuron.LIFNode()
        ),
        in_features=4, out_features=2
    ),
)
functional.set_step_mode(net, step_mode='m')

spike_seq_monitor = monitor.OutputMonitor(net, neuron.LIFNode)
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.LIFNode)
for m in net.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True

T = 4
B = 1
N = 8
x_seq = torch.rand([T, B, N]) + 1.0
with torch.no_grad():
    net(x_seq)
print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
Output:

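A minimal pure-Python sketch of what may be going on, assuming (from a reading of the SpikingJelly source, not verified) that set_step_mode leaves the modules inside LinearRecurrentContainer in single-step mode, that the container loops over time and calls them once per step, and that LIFNode only fills v_seq during a multi-step call. ToyLIF, ToyRecurrentContainer and attach_monitors are hypothetical stand-ins, not SpikingJelly APIs, and the toy neuron has no leak.

```python
# Hypothetical stand-ins, NOT SpikingJelly classes: they only mimic the
# assumed step-mode behaviour, using a toy leak-free neuron.


class ToyLIF:
    """Toy neuron: integrates input, spikes at v >= 1.0, hard-resets to 0.0."""

    def __init__(self, step_mode="s"):
        self.step_mode = step_mode
        self.store_v_seq = False
        self.v = 0.0
        self.v_seq = None  # assumed: only filled by a multi-step call
        self.hooks = []    # called as hook(module, output), like a forward hook

    def single_step(self, x):
        self.v += x
        spike = 1.0 if self.v >= 1.0 else 0.0
        if spike:
            self.v = 0.0
        return spike

    def multi_step(self, x_seq):
        spikes, vs = [], []
        for x in x_seq:
            spikes.append(self.single_step(x))
            vs.append(self.v)
        if self.store_v_seq:
            self.v_seq = vs
        return spikes

    def __call__(self, x):
        y = self.multi_step(x) if self.step_mode == "m" else self.single_step(x)
        for hook in self.hooks:
            hook(self, y)  # hooks fire once per __call__, not once per step
        return y


class ToyRecurrentContainer:
    """Loops over time itself and calls the wrapped neuron once per step."""

    def __init__(self, sub_module, feedback=0.5):
        self.sub_module = sub_module  # assumed to stay in single-step mode
        self.feedback = feedback
        self.y = 0.0

    def __call__(self, x_seq):
        out = []
        for x in x_seq:
            # toy recurrence: previous output is fed back into the input
            self.y = self.sub_module(x + self.feedback * self.y)
            out.append(self.y)
        return out


def attach_monitors(module):
    """Records outputs, v_seq and (workaround) v after every call."""
    spikes, v_seqs, vs = [], [], []
    module.hooks.append(lambda m, y: spikes.append(y))
    module.hooks.append(lambda m, y: v_seqs.append(m.v_seq))
    module.hooks.append(lambda m, y: vs.append(m.v))
    return spikes, v_seqs, vs


x_seq = [0.6, 0.6, 0.6, 0.6]  # T = 4, one scalar input per step

# Neuron used directly in multi-step mode: one record holding all T steps.
direct = ToyLIF(step_mode="m")
direct.store_v_seq = True
spikes_1, v_seqs_1, _ = attach_monitors(direct)
direct(x_seq)

# Same neuron inside the container: T single-step records, v_seq stays None.
inner = ToyLIF(step_mode="s")
inner.store_v_seq = True
spikes_2, v_seqs_2, vs_2 = attach_monitors(inner)
ToyRecurrentContainer(inner)(x_seq)

print(spikes_1, v_seqs_1)  # [[0.0, 1.0, 0.0, 1.0]] [[0.6, 0.0, 0.6, 0.0]]
print(spikes_2, v_seqs_2)  # [0.0, 1.0, 1.0, 1.0] [None, None, None, None]
print(vs_2)                # [0.6, 0.0, 0.0, 0.0]
```

Under these assumptions the wrapped neuron produces T separate spike records instead of one record covering all T steps, and its v_seq stays None, which matches both symptoms above. Recording v after every step and stacking the T records (torch.stack in the real code) recovers the membrane trace; with the real library this would presumably mean an AttributeMonitor on 'v' for the neuron inside the container, which has not been tested here.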
I hope the code is enough to reproduce the issue. Please help me solve this. Thank you!