diff --git a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile
index ca1428de..20cf559a 100644
--- a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile
+++ b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile
@@ -1,6 +1,6 @@
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04
-RUN apt update && apt install -y git wget python3 python3-pip python3-dev cmake build-essential
+RUN apt update && apt install -y git wget python3 python3-pip python3-dev cmake build-essential pkg-config libhdf5-serial-dev
WORKDIR /app
@@ -12,11 +12,13 @@ COPY . .
# 3. install dependencies ¯\_(ツ)_/¯
RUN pip3 install git+https://github.com/fangwei123456/spikingjelly.git
RUN pip3 install poetry
+
RUN git clone https://github.com/lava-nc/lava-dl.git && \
cd lava-dl && \
poetry config virtualenvs.in-project true && \
- poetry install && \
- source /app/lava-dl/.venv/bin/activate && \
+ poetry install
+
+RUN . /app/lava-dl/.venv/bin/activate && \
pip3 install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
pip3 install torch torchvision torchaudio && \
pip3 install -r ./requirements.txt
diff --git a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmark.py b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmark.py
index 7558833b..0eeba4d0 100644
--- a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmark.py
+++ b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmark.py
@@ -5,6 +5,214 @@
import os
+def rockpool_jax_lif_jit_cpu():
+ import rockpool
+ from rockpool.nn.modules import LIFJax, LinearJax
+ from rockpool.nn.combinators import Sequential
+ import numpy as np
+ import jax
+
+ def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
+ model = Sequential(
+ LinearJax(shape=(n_neurons, n_neurons)),
+ LIFJax(n_neurons),
+ )
+ input_static = jax.numpy.array(np.random.rand(batch_size, n_steps, n_neurons))
+
+ def apply(mod, input):
+ output = mod(input)
+ return output
+
+ apply = jax.jit(apply, backend="cpu")
+
+ def loss(mod, input):
+ out, _, _ = mod(input)
+ return out.sum()
+
+ grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend="cpu")
+
+ # - Force compilation
+ apply(model, input_static)
+ grad_loss(model, input_static)
+
+ return dict(
+ model=model,
+ jit_fwd=apply,
+ jit_bwd=grad_loss,
+ n_neurons=n_neurons,
+ input=input_static,
+ )
+
+ def forward_fn(bench_dict):
+ model, apply, input = (
+ bench_dict["model"],
+ bench_dict["jit_fwd"],
+ bench_dict["input"],
+ )
+ bench_dict["output"] = apply(model, input)[0]
+ return bench_dict
+
+ def backward_fn(bench_dict):
+ model, loss, input = (
+ bench_dict["model"],
+ bench_dict["jit_bwd"],
+ bench_dict["input"],
+ )
+ loss(model, input)
+
+    benchmark_title = f"Rockpool jax CPU accel.\nv{rockpool.__version__}"
+
+ return prepare_fn, forward_fn, backward_fn, benchmark_title
+
+
+def rockpool_jax_lif_jit_gpu():
+ import rockpool
+ from rockpool.nn.modules import LIFJax, LinearJax
+ from rockpool.nn.combinators import Sequential
+ import numpy as np
+ import jax
+
+ def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
+ model = Sequential(
+ LinearJax(shape=(n_neurons, n_neurons)),
+ LIFJax(n_neurons),
+ )
+ input_static = np.random.rand(batch_size, n_steps, n_neurons)
+
+ def apply(mod, input):
+ return mod(input)
+
+ apply = jax.jit(apply, backend="gpu")
+
+ def loss(mod, input):
+ out, _, _ = mod(input)
+ return out.sum()
+
+ grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend="gpu")
+
+ # - Force compilation
+ apply(model, input_static)
+ grad_loss(model, input_static)
+
+ return dict(
+ model=model,
+ jit_fwd=apply,
+ jit_bwd=grad_loss,
+ n_neurons=n_neurons,
+ input=input_static,
+ )
+
+ def forward_fn(bench_dict):
+ model, apply, input = (
+ bench_dict["model"],
+ bench_dict["jit_fwd"],
+ bench_dict["input"],
+ )
+ bench_dict["output"] = apply(model, input)[0]
+ return bench_dict
+
+ def backward_fn(bench_dict):
+ model, loss, input = (
+ bench_dict["model"],
+ bench_dict["jit_bwd"],
+ bench_dict["input"],
+ )
+ loss(model, input)
+
+    benchmark_title = f"Rockpool jax GPU accel.\nv{rockpool.__version__}"
+
+ return prepare_fn, forward_fn, backward_fn, benchmark_title
+
+
+def rockpool_jax_lif_jit_tpu():
+ import rockpool
+ from rockpool.nn.modules import LIFJax, LinearJax
+ from rockpool.nn.combinators import Sequential
+ import numpy as np
+ import jax
+
+ def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
+ model = Sequential(
+ LinearJax(shape=(n_neurons, n_neurons)),
+ LIFJax(n_neurons),
+ )
+ input_static = np.random.rand(batch_size, n_steps, n_neurons)
+
+ def apply(mod, input):
+ return mod(input)
+
+ apply = jax.jit(apply, backend="tpu")
+
+ def loss(mod, input):
+ out, _, _ = mod(input)
+ return out.sum()
+
+ grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend="tpu")
+
+ # - Force compilation
+ apply(model, input_static)
+ grad_loss(model, input_static)
+
+ return dict(
+ model=model,
+ jit_fwd=apply,
+ jit_bwd=grad_loss,
+ n_neurons=n_neurons,
+ input=input_static,
+ )
+
+ def forward_fn(bench_dict):
+ model, apply, input = (
+ bench_dict["model"],
+ bench_dict["jit_fwd"],
+ bench_dict["input"],
+ )
+ bench_dict["output"] = apply(model, input)[0]
+ return bench_dict
+
+ def backward_fn(bench_dict):
+ model, loss, input = (
+ bench_dict["model"],
+ bench_dict["jit_bwd"],
+ bench_dict["input"],
+ )
+ loss(model, input)
+
+    benchmark_title = f"Rockpool jax TPU accel.\nv{rockpool.__version__}"
+
+ return prepare_fn, forward_fn, backward_fn, benchmark_title
+
+
+def rockpool_native():
+ import rockpool
+ from rockpool.nn.modules import LIF, Linear
+ from rockpool.nn.combinators import Sequential
+ import numpy as np
+
+ def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
+ model = Sequential(
+ Linear(shape=(n_neurons, n_neurons)),
+ LIF(n_neurons),
+ )
+ input_static = np.random.rand(batch_size, n_steps, n_neurons)
+
+ model(input_static)
+
+ return dict(model=model, n_neurons=n_neurons, input=input_static)
+
+ def forward_fn(bench_dict):
+ mod, input_static = bench_dict["model"], bench_dict["input"]
+ bench_dict["output"] = mod(input_static)[0]
+ return bench_dict
+
+ def backward_fn(bench_dict):
+ pass
+
+    benchmark_title = f"Rockpool numpy\nv{rockpool.__version__}"
+
+ return prepare_fn, forward_fn, backward_fn, benchmark_title
+
+
def rockpool_torch():
import torch
@@ -13,7 +221,7 @@ def rockpool_torch():
from rockpool.nn.combinators import Sequential
import rockpool
-    benchmark_title = f"Rockpool\nv{rockpool.__version__}"
+    benchmark_title = f"Rockpool torch\nv{rockpool.__version__}"
def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
model = Sequential(
@@ -73,6 +281,83 @@ def backward_fn(bench_dict):
return prepare_fn, forward_fn, backward_fn, benchmark_title
+def rockpool_torch_cuda_graph():
+ from rockpool.nn.modules import LIFTorch
+ import torch
+
+ class StepPWL(torch.autograd.Function):
+ """
+ Heaviside step function with piece-wise linear surrogate to use as spike-generation surrogate
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ threshold=torch.tensor(1.0),
+ window=torch.tensor(0.5),
+ max_spikes_per_dt=torch.tensor(2.0**16),
+ ):
+ ctx.save_for_backward(x, threshold)
+ ctx.window = window
+ nr_spikes = ((x >= threshold) * torch.floor(x / threshold)).float()
+ # nr_spikes[nr_spikes > max_spikes_per_dt] = max_spikes_per_dt.float()
+ clamp_bool = (nr_spikes > max_spikes_per_dt).float()
+ nr_spikes -= (nr_spikes - max_spikes_per_dt.float()) * clamp_bool
+ return nr_spikes
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, threshold = ctx.saved_tensors
+ grad_x = grad_threshold = grad_window = grad_max_spikes_per_dt = None
+
+ mask = x >= (threshold - ctx.window)
+ if ctx.needs_input_grad[0]:
+ grad_x = grad_output / threshold * mask
+
+ if ctx.needs_input_grad[1]:
+ grad_threshold = -x * grad_output / (threshold**2) * mask
+
+ return grad_x, grad_threshold, grad_window, grad_max_spikes_per_dt
+
+ def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
+ mod = LIFTorch(
+ n_neurons, spike_generation_fn=StepPWL, max_spikes_per_dt=2.0**16
+ ).cuda()
+ input_static = torch.randn(batch_size, n_steps, n_neurons, device=device)
+
+ # - Warm up the CUDA stream
+ s = torch.cuda.Stream()
+ s.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(s):
+ for i in range(3):
+ y_pred, _, _ = mod(input_static)
+
+ # - Capture the graph
+ g = torch.cuda.CUDAGraph()
+
+ with torch.cuda.graph(g):
+ static_y_pred, _, _ = mod(input_static)
+
+ return dict(
+ model=g, input=input_static, n_neurons=n_neurons, output=static_y_pred
+ )
+
+ def forward_fn(bench_dict):
+ model = bench_dict["model"]
+ model.replay()
+ return bench_dict
+
+ def backward_fn(bench_dict):
+ output = bench_dict["output"]
+ loss = output.sum()
+ loss.backward(retain_graph=True)
+
+ benchmark_title = f"LIFTorch using CUDA graph replay acceleration"
+
+ return prepare_fn, forward_fn, backward_fn, benchmark_title
+
+
def sinabs():
import torch
from torch import nn
@@ -493,6 +778,11 @@ def backward_fn(bench_dict):
benchmarks = {
"rockpool_torch": rockpool_torch,
"rockpool_exodus": rockpool_exodus,
+ "rockpool_native": rockpool_native,
+ "rockpool_jax_cpu": rockpool_jax_lif_jit_cpu,
+ "rockpool_jax_gpu": rockpool_jax_lif_jit_gpu,
+ "rockpool_jax_tpu": rockpool_jax_lif_jit_tpu,
+ "rockpool_torch_cuda_graph": rockpool_torch_cuda_graph,
"sinabs": sinabs,
"sinabs_exodus": sinabs_exodus,
"norse": norse,
diff --git a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmarks.sh b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmarks.sh
index 94ce8882..d0766fa5 100755
--- a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmarks.sh
+++ b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/run_benchmarks.sh
@@ -6,7 +6,7 @@ echo "running benchmarks with batch size $1"
echo "framework,neurons,forward,backward,memory" >data.csv
-for benchmark in "rockpool_torch" "rockpool_exodus" "sinabs" "sinabs_exodus" "norse" "snntorch" "spikingjelly" "spikingjelly_cupy" "lava" "spyx_full" "spyx_half"; do
+for benchmark in "rockpool_jax_cpu" "rockpool_native" "rockpool_torch" "rockpool_exodus" "rockpool_jax_tpu" "rockpool_torch_cuda_graph" "rockpool_jax_gpu" "sinabs" "sinabs_exodus" "norse" "snntorch" "spikingjelly" "spikingjelly_cupy" "lava" "spyx_full" "spyx_half"; do
python3 ./run_benchmark.py $benchmark $1
done
diff --git a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py
index 17555fc6..87f9b12f 100644
--- a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py
+++ b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py
@@ -4,8 +4,9 @@
from time import time
+
def log_result(framework, neurons, forward, backward, memory):
- with open('data.csv', 'a') as f:
+ with open("data.csv", "a") as f:
w = writer(f)
w.writerow([framework, neurons, forward, backward, memory])
@@ -57,26 +58,28 @@ def benchmark_framework(
try:
# - Prepare benchmark
bench_dict = prepare_fn(
- batch_size=batch_size,
- n_steps=n_steps,
+ batch_size=batch_size,
+ n_steps=n_steps,
n_neurons=n_neurons,
- n_layers=n_layers,
- device=device
+ n_layers=n_layers,
+ device=device,
)
# - Forward pass
forward_times.append(timeit(lambda: forward_fn(bench_dict)))
bench_dict = forward_fn(bench_dict)
- assert bench_dict["output"].shape == bench_dict["input"].shape
+ shape_out = bench_dict["output"].shape
+ shape_in = bench_dict["input"].shape
+ assert (
+ shape_out == shape_in
+ ), f"Output shape {shape_out} != input shape {shape_in}"
# - Backward pass
backward_times.append(timeit(lambda: backward_fn(bench_dict)))
except Exception as e:
# - Fail nicely with a warning if a benchmark dies
- warnings.warn(
- f"Benchmark {benchmark_desc} failed with error {str(e)}."
- )
+ warnings.warn(f"Benchmark {benchmark_desc} failed with error {str(e)}.")
# - No results for this run
forward_times.append([])