Updating benchmarks to add Rockpool JAX-accelerated benchmarks #185

Open · wants to merge 2 commits into base: main
Dockerfile changes:
@@ -1,6 +1,6 @@
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04

-RUN apt update && apt install -y git wget python3 python3-pip python3-dev cmake build-essential
+RUN apt update && apt install -y git wget python3 python3-pip python3-dev cmake build-essential pkg-config libhdf5-serial-dev

WORKDIR /app

@@ -12,11 +12,13 @@ COPY . .
# 3. install dependencies ¯\_(ツ)_/¯
RUN pip3 install git+https://github.com/fangwei123456/spikingjelly.git
RUN pip3 install poetry

RUN git clone https://github.com/lava-nc/lava-dl.git && \
cd lava-dl && \
poetry config virtualenvs.in-project true && \
poetry install && \
. /app/lava-dl/.venv/bin/activate && \
poetry install

RUN . /app/lava-dl/.venv/bin/activate && \
pip3 install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
pip3 install torch torchvision torchaudio && \
pip3 install -r ./requirements.txt
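# An optional sanity check could go here to confirm the venv sees an accelerated
# jax build (illustrative, not part of this change):
# RUN . /app/lava-dl/.venv/bin/activate && python3 -c "import jax; print(jax.devices())"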
Benchmark definitions (Python):
@@ -5,6 +5,214 @@
import os


def rockpool_jax_lif_jit_cpu():
import rockpool
from rockpool.nn.modules import LIFJax, LinearJax
from rockpool.nn.combinators import Sequential
import numpy as np
import jax

def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
model = Sequential(
LinearJax(shape=(n_neurons, n_neurons)),
LIFJax(n_neurons),
)
input_static = jax.numpy.array(np.random.rand(batch_size, n_steps, n_neurons))

def apply(mod, input):
output = mod(input)
return output

apply = jax.jit(apply, backend="cpu")

def loss(mod, input):
out, _, _ = mod(input)
return out.sum()

grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend="cpu")

# - Force compilation
apply(model, input_static)
grad_loss(model, input_static)

return dict(
model=model,
jit_fwd=apply,
jit_bwd=grad_loss,
n_neurons=n_neurons,
input=input_static,
)

def forward_fn(bench_dict):
model, apply, input = (
bench_dict["model"],
bench_dict["jit_fwd"],
bench_dict["input"],
)
bench_dict["output"] = apply(model, input)[0]
return bench_dict

def backward_fn(bench_dict):
    model, grad_loss, input = (
        bench_dict["model"],
        bench_dict["jit_bwd"],
        bench_dict["input"],
    )
    grad_loss(model, input)

benchmark_title = f"Rockpool jax CPU accel.<br>v{rockpool.__version__}"

return prepare_fn, forward_fn, backward_fn, benchmark_title
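# A minimal sketch of how a harness can consume one of these factories, following
# the (prepare_fn, forward_fn, backward_fn, title) convention used in this file.
# The sizes and repeat count are illustrative, not taken from run_benchmark.py.
# Since jax dispatches asynchronously, block on the output before stopping the clock:
#
#   import timeit
#   prepare_fn, forward_fn, backward_fn, title = rockpool_jax_lif_jit_cpu()
#   bench_dict = prepare_fn(
#       batch_size=16, n_steps=50, n_neurons=512, n_layers=1, device="cpu"
#   )
#   elapsed = timeit.timeit(
#       lambda: forward_fn(bench_dict)["output"].block_until_ready(), number=100
#   )
#   print(f"{title}: {elapsed / 100:.6f} s per forward call")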


def rockpool_jax_lif_jit_gpu():
import rockpool
from rockpool.nn.modules import LIFJax, LinearJax
from rockpool.nn.combinators import Sequential
import numpy as np
import jax

def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
model = Sequential(
LinearJax(shape=(n_neurons, n_neurons)),
LIFJax(n_neurons),
)
input_static = jax.numpy.array(np.random.rand(batch_size, n_steps, n_neurons))

def apply(mod, input):
return mod(input)

apply = jax.jit(apply, backend="gpu")

def loss(mod, input):
out, _, _ = mod(input)
return out.sum()

grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend="gpu")

# - Force compilation
apply(model, input_static)
grad_loss(model, input_static)

return dict(
model=model,
jit_fwd=apply,
jit_bwd=grad_loss,
n_neurons=n_neurons,
input=input_static,
)

def forward_fn(bench_dict):
model, apply, input = (
bench_dict["model"],
bench_dict["jit_fwd"],
bench_dict["input"],
)
bench_dict["output"] = apply(model, input)[0]
return bench_dict

def backward_fn(bench_dict):
    model, grad_loss, input = (
        bench_dict["model"],
        bench_dict["jit_bwd"],
        bench_dict["input"],
    )
    grad_loss(model, input)

benchmark_title = f"Rockpool jax GPU accel.<br>v{rockpool.__version__}"

return prepare_fn, forward_fn, backward_fn, benchmark_title


def rockpool_jax_lif_jit_tpu():
import rockpool
from rockpool.nn.modules import LIFJax, LinearJax
from rockpool.nn.combinators import Sequential
import numpy as np
import jax

def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
model = Sequential(
LinearJax(shape=(n_neurons, n_neurons)),
LIFJax(n_neurons),
)
input_static = jax.numpy.array(np.random.rand(batch_size, n_steps, n_neurons))

def apply(mod, input):
return mod(input)

apply = jax.jit(apply, backend="tpu")

def loss(mod, input):
out, _, _ = mod(input)
return out.sum()

grad_loss = jax.jit(jax.grad(loss, allow_int=True), backend="tpu")

# - Force compilation
apply(model, input_static)
grad_loss(model, input_static)

return dict(
model=model,
jit_fwd=apply,
jit_bwd=grad_loss,
n_neurons=n_neurons,
input=input_static,
)

def forward_fn(bench_dict):
model, apply, input = (
bench_dict["model"],
bench_dict["jit_fwd"],
bench_dict["input"],
)
bench_dict["output"] = apply(model, input)[0]
return bench_dict

def backward_fn(bench_dict):
    model, grad_loss, input = (
        bench_dict["model"],
        bench_dict["jit_bwd"],
        bench_dict["input"],
    )
    grad_loss(model, input)

benchmark_title = f"Rockpool jax TPU accel.<br>v{rockpool.__version__}"

return prepare_fn, forward_fn, backward_fn, benchmark_title


def rockpool_native():
import rockpool
from rockpool.nn.modules import LIF, Linear
from rockpool.nn.combinators import Sequential
import numpy as np

def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
model = Sequential(
Linear(shape=(n_neurons, n_neurons)),
LIF(n_neurons),
)
input_static = np.random.rand(batch_size, n_steps, n_neurons)

model(input_static)

return dict(model=model, n_neurons=n_neurons, input=input_static)

def forward_fn(bench_dict):
mod, input_static = bench_dict["model"], bench_dict["input"]
bench_dict["output"] = mod(input_static)[0]
return bench_dict

def backward_fn(bench_dict):
pass

benchmark_title = f"Rockpool numpy<br>v{rockpool.__version__}"

return prepare_fn, forward_fn, backward_fn, benchmark_title


def rockpool_torch():
import torch

@@ -13,7 +13,7 @@ def rockpool_torch():
from rockpool.nn.combinators import Sequential
import rockpool

-benchmark_title = f"Rockpool<br>v{rockpool.__version__}"
+benchmark_title = f"Rockpool torch<br>v{rockpool.__version__}"

def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
model = Sequential(
@@ -73,6 +281,83 @@ def backward_fn(bench_dict):
return prepare_fn, forward_fn, backward_fn, benchmark_title


def rockpool_torch_cuda_graph():
from rockpool.nn.modules import LIFTorch
import torch

class StepPWL(torch.autograd.Function):
"""
Heaviside step function with piece-wise linear surrogate to use as spike-generation surrogate
"""

@staticmethod
def forward(
ctx,
x,
threshold=torch.tensor(1.0),
window=torch.tensor(0.5),
max_spikes_per_dt=torch.tensor(2.0**16),
):
ctx.save_for_backward(x, threshold)
ctx.window = window
nr_spikes = ((x >= threshold) * torch.floor(x / threshold)).float()
# nr_spikes[nr_spikes > max_spikes_per_dt] = max_spikes_per_dt.float()
clamp_bool = (nr_spikes > max_spikes_per_dt).float()
nr_spikes -= (nr_spikes - max_spikes_per_dt.float()) * clamp_bool
return nr_spikes

@staticmethod
def backward(ctx, grad_output):
x, threshold = ctx.saved_tensors
grad_x = grad_threshold = grad_window = grad_max_spikes_per_dt = None

mask = x >= (threshold - ctx.window)
if ctx.needs_input_grad[0]:
grad_x = grad_output / threshold * mask

if ctx.needs_input_grad[1]:
grad_threshold = -x * grad_output / (threshold**2) * mask

return grad_x, grad_threshold, grad_window, grad_max_spikes_per_dt
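# A worked example of the surrogate above (values are illustrative): the forward
# pass emits floor(x / threshold) spikes wherever x >= threshold, while the
# backward pass lets gradient through wherever x >= threshold - window:
#
#   x = torch.tensor([0.2, 0.7, 2.3], requires_grad=True)
#   y = StepPWL.apply(x)   # tensor([0., 0., 2.])
#   y.sum().backward()
#   x.grad                 # tensor([0., 1., 1.]) -- window opens at 1.0 - 0.5 = 0.5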

def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
mod = LIFTorch(
n_neurons, spike_generation_fn=StepPWL, max_spikes_per_dt=2.0**16
).to(device)  # match the device used for input_static below
input_static = torch.randn(batch_size, n_steps, n_neurons, device=device)

# - Warm up the CUDA stream
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        y_pred, _, _ = mod(input_static)
# - Rejoin the default stream before capturing the graph
torch.cuda.current_stream().wait_stream(s)

# - Capture the graph
g = torch.cuda.CUDAGraph()

with torch.cuda.graph(g):
static_y_pred, _, _ = mod(input_static)

return dict(
model=g, input=input_static, n_neurons=n_neurons, output=static_y_pred
)

def forward_fn(bench_dict):
model = bench_dict["model"]
model.replay()
return bench_dict

def backward_fn(bench_dict):
output = bench_dict["output"]
loss = output.sum()
loss.backward(retain_graph=True)

benchmark_title = "LIFTorch using CUDA graph replay acceleration"

return prepare_fn, forward_fn, backward_fn, benchmark_title
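# A captured CUDA graph replays fixed device addresses, so new data must be copied
# into the captured tensors rather than passed as fresh arguments. A minimal sketch
# against the dict returned by prepare_fn above (`new_batch` is illustrative):
#
#   new_batch = torch.randn_like(bench_dict["input"])
#   bench_dict["input"].copy_(new_batch)  # refill the static input in place
#   bench_dict["model"].replay()          # re-run the captured kernels
#   result = bench_dict["output"]         # the static output tensor now holds fresh results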


def sinabs():
import torch
from torch import nn
@@ -493,6 +778,11 @@ def backward_fn(bench_dict):
benchmarks = {
"rockpool_torch": rockpool_torch,
"rockpool_exodus": rockpool_exodus,
"rockpool_native": rockpool_native,
"rockpool_jax_cpu": rockpool_jax_lif_jit_cpu,
"rockpool_jax_gpu": rockpool_jax_lif_jit_gpu,
"rockpool_jax_tpu": rockpool_jax_lif_jit_tpu,
"rockpool_torch_cuda_graph": rockpool_torch_cuda_graph,
"sinabs": sinabs,
"sinabs_exodus": sinabs_exodus,
"norse": norse,
Benchmark runner script (shell):
@@ -6,7 +6,7 @@ echo "running benchmarks with batch size $1"

echo "framework,neurons,forward,backward,memory" >data.csv

-for benchmark in "rockpool_torch" "rockpool_exodus" "sinabs" "sinabs_exodus" "norse" "snntorch" "spikingjelly" "spikingjelly_cupy" "lava" "spyx_full" "spyx_half"; do
+for benchmark in "rockpool_jax_cpu" "rockpool_native" "rockpool_torch" "rockpool_exodus" "rockpool_jax_tpu" "rockpool_torch_cuda_graph" "rockpool_jax_gpu" "sinabs" "sinabs_exodus" "norse" "snntorch" "spikingjelly" "spikingjelly_cupy" "lava" "spyx_full" "spyx_half"; do
python3 ./run_benchmark.py $benchmark $1
done
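# Illustrative invocation (the script's filename is not shown in this diff):
#   ./run_benchmarks.sh 32    # run every benchmark at batch size 32; results collect in data.csv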
