Unable to monitor membrane potential of LIFNode() when wrapped in LinearRecurrentContainer() #631

@ArchitMukherjee

Issue type

  • Bug Report
  • Feature Request
  • Help wanted
  • Other

SpikingJelly version

0.0.0.0.2

Description

Hi, I am trying to implement RLIF neurons in the last layer of an SNN, and I also want to record the membrane potential and spiking activity of the network. However, when I wrap the last layer of the SNN with LinearRecurrentContainer, the membrane potential monitor (AttributeMonitor on v_seq) returns a list of None. The spike monitor (OutputMonitor) does return spike records, but they are not in the expected format.
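My guess at the cause, stated as an assumption rather than something I have confirmed in the SpikingJelly source for this version: `functional.set_step_mode(net, 'm')` may skip modules nested inside `LinearRecurrentContainer`, so the inner `LIFNode` stays in single-step mode. The container then calls it once per time step, and `v_seq` is only filled on the multi-step path, so a post-forward `AttributeMonitor('v_seq', ...)` would read `None` once per step, while `OutputMonitor` would collect T separate per-step spike tensors instead of one sequence tensor. The dependency-free toy below imitates that pattern; `ToyLIF` and `record_after_each_call` are hypothetical stand-ins, not SpikingJelly APIs.

```python
# Hypothetical toy model; ToyLIF and record_after_each_call are NOT SpikingJelly APIs.


class ToyLIF:
    """LIF-like node that updates `v` on every call but only fills `v_seq` in multi-step mode."""

    def __init__(self, tau=2.0, v_threshold=1.0, step_mode="s"):
        self.tau = tau
        self.v_threshold = v_threshold
        self.step_mode = step_mode
        self.v = 0.0
        self.v_seq = None  # mimics store_v_seq=True: the attribute exists but starts as None

    def _step(self, x):
        self.v = self.v + (x - self.v) / self.tau  # leaky charge with v_reset = 0
        spike = 1.0 if self.v >= self.v_threshold else 0.0
        if spike:
            self.v = 0.0  # hard reset
        return spike

    def __call__(self, x):
        if self.step_mode == "m":  # x is a whole sequence
            spikes, vs = [], []
            for x_t in x:
                spikes.append(self._step(x_t))
                vs.append(self.v)
            self.v_seq = vs  # only the multi-step path writes v_seq
            return spikes
        return self._step(x)  # single-step path never touches v_seq


def record_after_each_call(node, x_seq, attr, step_by_step):
    """Read getattr(node, attr) after every call, like a post-forward attribute hook."""
    records = []
    # A recurrent container calls its sub-module once per time step.
    calls = x_seq if step_by_step else [x_seq]
    for x in calls:
        node(x)
        records.append(getattr(node, attr))
    return records


x_seq = [1.5, 1.5, 1.5, 1.5]
multi = record_after_each_call(ToyLIF(step_mode="m"), x_seq, "v_seq", step_by_step=False)
single = record_after_each_call(ToyLIF(step_mode="s"), x_seq, "v_seq", step_by_step=True)
single_v = record_after_each_call(ToyLIF(step_mode="s"), x_seq, "v", step_by_step=True)
print(multi)     # → [[0.75, 0.0, 0.75, 0.0]]  one record holding the whole sequence
print(single)    # → [None, None, None, None]  T records, all None
print(single_v)  # → [0.75, 0.0, 0.75, 0.0]  per-step values are still available via v
```

In the toy, the per-step `v` values remain readable even when `v_seq` is `None`, which is why monitoring `v` on the inner node looks like a reasonable thing to try.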

I have provided two code snippets to reproduce the error: one with the normal implementation and one that triggers the bug.

Code snippet for the normal implementation

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron, monitor

net = nn.Sequential(
    nn.Linear(in_features=8, out_features=4),
    neuron.LIFNode(),

    nn.Linear(in_features=4, out_features=2),
    neuron.LIFNode()
)
functional.set_step_mode(net, step_mode='m')

spike_seq_monitor = monitor.OutputMonitor(net, neuron.LIFNode)
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.LIFNode)
for m in net.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True
T = 4
B = 1
N = 8
x_seq = torch.rand([T, B, N]) + 1.0

with torch.no_grad():
    net(x_seq)

print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')

Output: (screenshot of the printed records; not reproduced here)

Code snippet that reproduces the bug

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron, monitor

net = nn.Sequential(
    nn.Linear(in_features=8, out_features=4),
    neuron.LIFNode(),

    # Last layer wrapped in LinearRecurrentContainer to make it recurrent (RLIF)
    layer.LinearRecurrentContainer(
        nn.Sequential(nn.Linear(in_features=4, out_features=2), neuron.LIFNode()),
        in_features=4, out_features=2,
    ),
)
functional.set_step_mode(net, step_mode='m')

spike_seq_monitor = monitor.OutputMonitor(net, neuron.LIFNode)
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.LIFNode)
for m in net.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True
T = 4
B = 1
N = 8
x_seq = torch.rand([T, B, N]) + 1.0

with torch.no_grad():
    net(x_seq)

print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')

Output: (screenshot of the printed records; not reproduced here)

I hope the code is enough to reproduce the issue. Please help me solve this. Thank you!

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
