diff --git a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile index ca1428de..20cf559a 100644 --- a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile +++ b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/Dockerfile @@ -1,6 +1,6 @@ FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 -RUN apt update && apt install -y git wget python3 python3-pip python3-dev cmake build-essential +RUN apt update && apt install -y git wget python3 python3-pip python3-dev cmake build-essential pkg-config libhdf5-serial-dev WORKDIR /app @@ -12,11 +12,13 @@ COPY . . # 3. install dependencies ¯\_(ツ)_/¯ RUN pip3 install git+https://github.com/fangwei123456/spikingjelly.git RUN pip3 install poetry + RUN git clone https://github.com/lava-nc/lava-dl.git && \ cd lava-dl && \ poetry config virtualenvs.in-project true && \ - poetry install && \ - source /app/lava-dl/.venv/bin/activate && \ + poetry install + +RUN . 
def _rockpool_jax_lif_jit(backend):
    """Build a Rockpool LIFJax benchmark JIT-compiled for the given jax backend.

    Shared implementation for the CPU/GPU/TPU wrappers below, which previously
    triplicated this code and differed only in the ``jax.jit(backend=...)``
    argument and the benchmark title.

    Parameters
    ----------
    backend : str
        jax backend name: ``"cpu"``, ``"gpu"`` or ``"tpu"``.

    Returns
    -------
    tuple
        ``(prepare_fn, forward_fn, backward_fn, benchmark_title)`` as expected
        by the benchmark harness.
    """
    import rockpool
    from rockpool.nn.modules import LIFJax, LinearJax
    from rockpool.nn.combinators import Sequential
    import numpy as np
    import jax

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        # Single Linear -> LIF pair; `n_layers` and `device` are accepted for
        # harness-interface compatibility but unused, as in the original.
        model = Sequential(
            LinearJax(shape=(n_neurons, n_neurons)),
            LIFJax(n_neurons),
        )
        # jax.jit places/converts the input on first traced call, so wrapping
        # in jax.numpy.array here is equivalent for all backends.
        input_static = jax.numpy.array(
            np.random.rand(batch_size, n_steps, n_neurons)
        )

        def apply(mod, input):
            return mod(input)

        apply = jax.jit(apply, backend=backend)

        def loss(mod, input):
            out, _, _ = mod(input)
            return out.sum()

        grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend=backend)

        # - Force compilation so the timed runs measure execution, not tracing
        apply(model, input_static)
        grad_loss(model, input_static)

        return dict(
            model=model,
            jit_fwd=apply,
            jit_bwd=grad_loss,
            n_neurons=n_neurons,
            input=input_static,
        )

    def forward_fn(bench_dict):
        model, apply, input = (
            bench_dict["model"],
            bench_dict["jit_fwd"],
            bench_dict["input"],
        )
        bench_dict["output"] = apply(model, input)[0]
        return bench_dict

    def backward_fn(bench_dict):
        model, grad_loss, input = (
            bench_dict["model"],
            bench_dict["jit_bwd"],
            bench_dict["input"],
        )
        grad_loss(model, input)

    # Two-line title (framework on the first line, version on the second),
    # matching the style of the other benchmarks' plot labels.
    benchmark_title = (
        f"Rockpool jax {backend.upper()} accel.\nv{rockpool.__version__}"
    )

    return prepare_fn, forward_fn, backward_fn, benchmark_title


def rockpool_jax_lif_jit_cpu():
    """Rockpool LIFJax benchmark, JIT-compiled for the jax CPU backend."""
    return _rockpool_jax_lif_jit("cpu")


def rockpool_jax_lif_jit_gpu():
    """Rockpool LIFJax benchmark, JIT-compiled for the jax GPU backend."""
    return _rockpool_jax_lif_jit("gpu")


def rockpool_jax_lif_jit_tpu():
    """Rockpool LIFJax benchmark, JIT-compiled for the jax TPU backend."""
    return _rockpool_jax_lif_jit("tpu")


def rockpool_native():
    """Rockpool benchmark using the pure-numpy LIF/Linear modules.

    No accelerator and no autograd: the backward pass is a no-op, so only the
    forward timing of this variant is meaningful.
    """
    import rockpool
    from rockpool.nn.modules import LIF, Linear
    from rockpool.nn.combinators import Sequential
    import numpy as np

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        model = Sequential(
            Linear(shape=(n_neurons, n_neurons)),
            LIF(n_neurons),
        )
        input_static = np.random.rand(batch_size, n_steps, n_neurons)

        # Warm-up call so any lazy initialisation is excluded from the timing
        model(input_static)

        return dict(model=model, n_neurons=n_neurons, input=input_static)

    def forward_fn(bench_dict):
        mod, input_static = bench_dict["model"], bench_dict["input"]
        bench_dict["output"] = mod(input_static)[0]
        return bench_dict

    def backward_fn(bench_dict):
        # The numpy backend has no autograd; backward is intentionally a no-op.
        pass

    benchmark_title = f"Rockpool numpy\nv{rockpool.__version__}"

    return prepare_fn, forward_fn, backward_fn, benchmark_title
v{rockpool.__version__}" + benchmark_title = f"Rockpool torch
def rockpool_torch_cuda_graph():
    """Rockpool LIFTorch benchmark accelerated by CUDA graph capture/replay.

    The forward pass is captured once into a ``torch.cuda.CUDAGraph`` and then
    replayed, eliminating per-step Python/launch overhead. Requires a CUDA
    device.

    Returns
    -------
    tuple
        ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``.
    """
    from rockpool.nn.modules import LIFTorch
    import torch

    class StepPWL(torch.autograd.Function):
        """
        Heaviside step function with piece-wise linear surrogate to use as
        spike-generation surrogate.
        """

        @staticmethod
        def forward(
            ctx,
            x,
            threshold=torch.tensor(1.0),
            window=torch.tensor(0.5),
            max_spikes_per_dt=torch.tensor(2.0**16),
        ):
            ctx.save_for_backward(x, threshold)
            ctx.window = window
            nr_spikes = ((x >= threshold) * torch.floor(x / threshold)).float()
            # Clamp to max_spikes_per_dt arithmetically instead of with boolean
            # indexing (`nr_spikes[nr_spikes > max] = max`), which cannot be
            # captured inside a CUDA graph.
            clamp_bool = (nr_spikes > max_spikes_per_dt).float()
            nr_spikes -= (nr_spikes - max_spikes_per_dt.float()) * clamp_bool
            return nr_spikes

        @staticmethod
        def backward(ctx, grad_output):
            x, threshold = ctx.saved_tensors
            grad_x = grad_threshold = grad_window = grad_max_spikes_per_dt = None

            # Surrogate gradient is non-zero within `window` below threshold
            mask = x >= (threshold - ctx.window)
            if ctx.needs_input_grad[0]:
                grad_x = grad_output / threshold * mask

            if ctx.needs_input_grad[1]:
                grad_threshold = -x * grad_output / (threshold**2) * mask

            return grad_x, grad_threshold, grad_window, grad_max_spikes_per_dt

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        # FIX: place the model on `device` rather than hard-coded .cuda(), so
        # model and input always share a device (identical behavior when
        # device == "cuda", which CUDA graphs require anyway).
        mod = LIFTorch(
            n_neurons, spike_generation_fn=StepPWL, max_spikes_per_dt=2.0**16
        ).to(device)
        input_static = torch.randn(batch_size, n_steps, n_neurons, device=device)

        # - Warm up on a side stream (per the PyTorch CUDA-graph recipe)
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                y_pred, _, _ = mod(input_static)
        # FIX: re-join the capture stream to the warm-up stream before
        # capturing; the original omitted this synchronization, which the
        # torch.cuda.graph documentation requires to avoid racing the warm-up.
        torch.cuda.current_stream().wait_stream(s)

        # - Capture a single forward pass. replay() re-executes exactly this
        #   work, reading `input_static` and writing `static_y_pred` in place.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_y_pred, _, _ = mod(input_static)

        return dict(
            model=g, input=input_static, n_neurons=n_neurons, output=static_y_pred
        )

    def forward_fn(bench_dict):
        # NOTE: "model" here is the captured CUDAGraph, not the module.
        bench_dict["model"].replay()
        return bench_dict

    def backward_fn(bench_dict):
        # NOTE(review): the backward pass is NOT graph-captured; it runs
        # eagerly on the captured output. retain_graph=True keeps the autograd
        # graph alive across repeated timing iterations; gradients accumulate,
        # which is acceptable for a pure speed benchmark — confirm if results
        # are ever inspected.
        output = bench_dict["output"]
        loss = output.sum()
        loss.backward(retain_graph=True)

    benchmark_title = "LIFTorch using CUDA graph replay acceleration"

    return prepare_fn, forward_fn, backward_fn, benchmark_title
a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py index 17555fc6..87f9b12f 100644 --- a/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py +++ b/content/english/blog/spiking-neural-network-framework-benchmarking/docker/utils.py @@ -4,8 +4,9 @@ from time import time + def log_result(framework, neurons, forward, backward, memory): - with open('data.csv', 'a') as f: + with open("data.csv", "a") as f: w = writer(f) w.writerow([framework, neurons, forward, backward, memory]) @@ -57,26 +58,28 @@ def benchmark_framework( try: # - Prepare benchmark bench_dict = prepare_fn( - batch_size=batch_size, - n_steps=n_steps, + batch_size=batch_size, + n_steps=n_steps, n_neurons=n_neurons, - n_layers=n_layers, - device=device + n_layers=n_layers, + device=device, ) # - Forward pass forward_times.append(timeit(lambda: forward_fn(bench_dict))) bench_dict = forward_fn(bench_dict) - assert bench_dict["output"].shape == bench_dict["input"].shape + shape_out = bench_dict["output"].shape + shape_in = bench_dict["input"].shape + assert ( + shape_out == shape_in + ), f"Output shape {shape_out} != input shape {shape_in}" # - Backward pass backward_times.append(timeit(lambda: backward_fn(bench_dict))) except Exception as e: # - Fail nicely with a warning if a benchmark dies - warnings.warn( - f"Benchmark {benchmark_desc} failed with error {str(e)}." - ) + warnings.warn(f"Benchmark {benchmark_desc} failed with error {str(e)}.") # - No results for this run forward_times.append([])