
Cannot use T > 1 with multi-step modules: LIF and IF neurons report that cupy is not supported #632

@SM1991CODES


Read before creating a new issue

  • Users who want to use SpikingJelly should first be familiar with PyTorch.
  • If you do not know much about PyTorch, we recommend working through the basic PyTorch tutorials first.
  • Do not ask for help with basic PyTorch or machine learning concepts that are unrelated to SpikingJelly. For such questions, please refer to Google or the PyTorch Forums.

For faster response

You can @-mention the developers responsible for your issue's area. Here is the division of responsibilities:

Features                            Developers
Neurons and Surrogate Functions     fangwei123456, Yanqi-Chen
CUDA Acceleration                   fangwei123456, Yanqi-Chen
Reinforcement Learning              lucifer2859
ANN to SNN Conversion               DingJianhao, Lyu6PosHao
Biological Learning (e.g., STDP)    AllenYolk
Others                              Grasshlw, lucifer2859, AllenYolk, Lyu6PosHao, DingJianhao, Yanqi-Chen, fangwei123456

We are glad to add volunteers who want to help solve issues to the table above.

Issue type

  • Bug Report
  • Feature Request
  • Help wanted
  • Other

@fangwei123456
SpikingJelly version

0.0.0.0.2

Description
I get the following error when using T > 1 with multi-step forward modules:

WARNING:root:LIFNode(
  v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
  (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
) does not supports for backend=cupy. It will still use backend=torch.
WARNING:root:LIFNode(
  v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
  (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
) does not supports for backend=cupy. It will still use backend=torch.
Setting step mode to multi-step..
Traceback (most recent call last):
  File "c:\Users\Sambit\Documents\PHD\codeworks\snnworks\root\sjelly_models\blocks.py", line 57, in <module>
    net(x)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "c:\Users\Sambit\Documents\PHD\codeworks\snnworks\root\sjelly_models\blocks.py", line 49, in forward
    x_cam = self.act2_cam(self.conv2_cam(x_cam))
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\spikingjelly\activation_based\base.py", line 270, in forward
    return self.multi_step_forward(*args, **kwargs)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\spikingjelly\activation_based\neuron.py", line 933, in multi_step_forward
    return super().multi_step_forward(x_seq)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\spikingjelly\activation_based\neuron.py", line 250, in multi_step_forward
    y = self.single_step_forward(x_seq[t])
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\spikingjelly\activation_based\neuron.py", line 907, in single_step_forward
    return super().single_step_forward(x)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\spikingjelly\activation_based\neuron.py", line 241, in single_step_forward
    self.neuronal_reset(spike)
  File "C:\Users\Sambit\miniforge3\envs\snn-env\lib\site-packages\spikingjelly\activation_based\neuron.py", line 205, in neuronal_reset
    self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: nvrtc: error: failed to open nvrtc-builtins64_118.dll.
Make sure that nvrtc-builtins64_118.dll is installed correctly.
nvrtc compilation failed:

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

extern "C" __global__
void fused_neg_add_mul_mul_add(float* tspike_1, double vv_reset_2, float* tv_1, float* aten_add_1, float* aten_add) {
{
  float tspike_1_1 = __ldg(tspike_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
  aten_add[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = (0.f - tspike_1_1) + 1.f;
  float v = __ldg(tv_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
  aten_add_1[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = ((0.f - tspike_1_1) + 1.f) * v + tspike_1_1 * (float)(vv_reset_2);
}
}
...
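For context, the kernel that failed to compile, `fused_neg_add_mul_mul_add`, appears to be TorchScript's fused form of the hard reset that `neuronal_reset` calls through `jit_hard_reset` (last frames of the traceback): reading the generated code, it computes `1 - spike` and `(1 - spike) * v + spike * v_reset`. The traceback also shows that with the torch backend, `multi_step_forward` loops over the T time steps and calls `single_step_forward(x_seq[t])` on each, so every step goes through this reset. Below is a minimal scalar, pure-Python sketch of that loop. It assumes the LIF settings printed in the log (`tau=2.0`, `v_threshold=1.0`, `v_reset=0.0`) plus SpikingJelly's default `decay_input=True` charge rule; `lif_multi_step` is a hypothetical helper, not SpikingJelly's API, and it ignores surrogate gradients entirely.

```python
from typing import List, Tuple


def lif_multi_step(x_seq: List[float], tau: float = 2.0, v_threshold: float = 1.0,
                   v_reset: float = 0.0) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: a scalar LIF neuron run over T time steps.

    Mirrors the torch-backend structure seen in the traceback: a multi-step
    pass is a loop of single-step updates (charge -> fire -> hard reset).
    """
    v = v_reset  # membrane potential starts at v_reset
    spikes, v_seq = [], []
    for x in x_seq:  # one single-step update per time step t
        # charge (assumed decay_input=True rule): v <- v + (x - (v - v_reset)) / tau
        v = v + (x - (v - v_reset)) / tau
        # fire: Heaviside step, spike = 1 when v >= v_threshold
        spike = 1.0 if v - v_threshold >= 0.0 else 0.0
        # hard reset, same formula as the fused kernel above:
        # v <- (1 - spike) * v + spike * v_reset
        v = (1.0 - spike) * v + spike * v_reset
        spikes.append(spike)
        v_seq.append(v)
    return spikes, v_seq


spikes, v_seq = lif_multi_step([1.5, 1.5, 1.5])
print(spikes)  # [0.0, 1.0, 0.0]
print(v_seq)   # [0.75, 0.0, 0.75]
```

With a constant input of 1.5, the potential reaches 0.75 after the first step, crosses the threshold at the second step (1.125), fires, and is reset to 0.0.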

Minimal code to reproduce the error/bug

import torch
import torch.nn as nn
import cupy as cp
from spikingjelly.activation_based import neuron, layer, surrogate, functional


filter_dict = {1: 32, 2: 64, 3: 128, 4: 256, 5: 512, 6: 256, 7: 128}  # number of filters to use in each level of blocks
N_CHANNELS_IN_HEADS = 128  # NOTE: increased to collect more low level features


class SPK_CAM(nn.Module):
    """
    Implementation of the Context Aggregation Module from Squeezeseg v2
    """

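    def __init__(self, in_ch, v_th=0.8, v_reset=0.05, backend="cupy"):
        """
        Default constructor
        Args:
            in_ch (int): number of input channels to this CAM block
            v_th (float): intended firing threshold (currently not passed to the LIF nodes)
            v_reset (float): intended reset potential (currently not passed to the LIF nodes)
            backend (str): if "cupy", request the CuPy backend and switch the module to multi-step mode
        """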

        super(SPK_CAM, self).__init__()

        self.N_C_OUT = in_ch

        self.pool_cam = layer.AvgPool2d(kernel_size=(7, 7), padding=(3, 3), stride=(1, 1))
        self.conv1_cam = layer.Conv2d(self.N_C_OUT, self.N_C_OUT // 4, kernel_size=(1, 1))
        self.act1_cam = neuron.LIFNode()

        self.conv2_cam = layer.Conv2d(self.N_C_OUT // 4, self.N_C_OUT, kernel_size=(1, 1))
        self.act2_cam = neuron.LIFNode()

        if backend == "cupy":
            # Request the CuPy backend for the LIF nodes, then switch every layer to multi-step mode
            functional.set_backend(self, backend='cupy')
            functional.set_step_mode(self, 'm')
            print("Setting step mode to multi-step..")

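    def forward(self, x):
        """
        Moves data through the network
        Args:
            x (tensor): input; shape [T, N, C, H, W] in multi-step mode
        Returns:
            tensor: x multiplied element-wise by the spiking CAM output x_cam
        """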

        x_cam = self.pool_cam(x)
        x_cam = self.act1_cam(self.conv1_cam(x_cam))
        x_cam = self.act2_cam(self.conv2_cam(x_cam))
        x = x * x_cam
        return x
    
if __name__ == "__main__":

    net = SPK_CAM(32).cuda()
    # input shape [T, N, C, H, W] = [2, 2, 32, 416, 416]: T = 2 time steps, batch size 2
    x = torch.randn((2, 2, 32, 416, 416), device="cuda:0")
    net(x)
    functional.reset_net(net)


# ...
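The warning at the top of the log prints `step_mode=s, backend=torch`, and `SPK_CAM.__init__` above calls `functional.set_backend(self, backend='cupy')` before `functional.set_step_mode(self, 'm')`. If, as the warning suggests, SpikingJelly only accepts `cupy` for neurons that are already in multi-step mode, the request is rejected at that point and the LIF nodes stay on `torch`, which is the backend whose TorchScript reset appears in the traceback. The pure-Python toy below models only that ordering assumption; `ToyNode` and `toy_set_backend` are hypothetical stand-ins, not SpikingJelly classes or functions.

```python
import logging
from typing import Tuple


class ToyNode:
    """Hypothetical stand-in for a spiking neuron (not SpikingJelly's class).

    Assumption: cupy is only accepted while the node is in multi-step mode.
    """

    def __init__(self) -> None:
        self.step_mode = "s"
        self.backend = "torch"

    @property
    def supported_backends(self) -> Tuple[str, ...]:
        return ("torch",) if self.step_mode == "s" else ("torch", "cupy")


def toy_set_backend(node: ToyNode, backend: str) -> None:
    # Assumed logic: switch only if the backend is supported in the node's
    # *current* step mode; otherwise warn and keep the old backend.
    if backend in node.supported_backends:
        node.backend = backend
    else:
        logging.warning("%s does not support backend=%s; keeping backend=%s",
                        type(node).__name__, backend, node.backend)


# Order used in the repro: backend first, step mode second.
a = ToyNode()
toy_set_backend(a, "cupy")  # rejected: still single-step
a.step_mode = "m"
print(a.backend)  # torch

# Reversed order: step mode first, backend second.
b = ToyNode()
b.step_mode = "m"
toy_set_backend(b, "cupy")  # accepted
print(b.backend)  # cupy
```

Under this assumption, the corresponding change in the repro would be to call `functional.set_step_mode(self, 'm')` before `functional.set_backend(self, backend='cupy')`. That would likely address the warning only; the missing `nvrtc-builtins64_118.dll` is reported from inside TorchScript and points to a separate NVRTC installation problem.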
