Issue type
- Bug Report
- Feature Request
- Help wanted
- Other
SpikingJelly version
0.0.0.0.14
Description
torch.utils.checkpoint seems to expect the checkpointed function to be pure, i.e. to return the same outputs given the same inputs. But SNN models like LIF neurons carry state tensors such as the membrane potential `v`. Since torch.utils.checkpoint runs the forward pass twice (once normally, and once more during backward to recompute the activations), `v` is updated twice in a single step.
Does anyone know how to correctly use torch.utils.checkpoint with SpikingJelly?
Minimal code to reproduce the error/bug
```python
# adapted from https://github.com/pytorch/pytorch/issues/96136
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Test checkpoint with variable buffers."""
from typing import Optional

import torch
from torch.nn import Linear, Sequential
from torch.optim import SGD
import torch.utils.checkpoint


class RNNCellInternalBuffer(torch.nn.Module):
    """A toy stateful cell: like an SNN neuron, it accumulates into `self.v`."""

    def __init__(self, size):
        super().__init__()
        # self.register_buffer("v", torch.zeros(size))
        self.v = torch.zeros(size)

    def forward(self, x):
        # Stateful update: every call to forward() mutates `self.v`,
        # so recomputing the forward pass changes the result.
        self.v = self.v.to(x.device) + x
        return self.v


class CheckpointedSequential(Sequential):
    def __init__(self, checkpoint: bool, *args):
        super().__init__(*args)
        self.checkpoint = checkpoint

    def forward(self, input):
        if self.checkpoint:
            return torch.utils.checkpoint.checkpoint(super().forward, input, use_reentrant=False)
        else:
            return super().forward(input)


def get_model(checkpointed, sizes):
    assert checkpointed in [True, False], checkpointed
    model = CheckpointedSequential(checkpointed, RNNCellInternalBuffer(sizes), Linear(3, 2))
    return model


# copied from https://github.com/facebookresearch/fairscale
def objects_are_equal(
    a,
    b,
    raise_exception: bool = False,
    dict_key: Optional[str] = None,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
) -> bool:
    """
    Test that two objects are equal. Tensors are compared to ensure matching
    size, dtype, device and values.
    """
    if type(a) is not type(b):
        if raise_exception:
            raise ValueError(f"type mismatch {type(a)} vs. {type(b)}")
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            if raise_exception:
                raise ValueError(f"keys mismatch {a.keys()} vs. {b.keys()}")
            return False
        for k in a.keys():
            if not objects_are_equal(a[k], b[k], raise_exception, k):
                return False
        return True
    elif isinstance(a, (list, tuple, set)):
        if len(a) != len(b):
            if raise_exception:
                raise ValueError(f"length mismatch {len(a)} vs. {len(b)}")
            return False
        return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
    elif torch.is_tensor(a):
        try:
            # assert_close doesn't strictly test shape, dtype and device
            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
            if not shape_dtype_device_match:
                if raise_exception:
                    msg = f"sizes: {a.size()} vs. {b.size()}, "
                    msg += f"types: {a.dtype} vs. {b.dtype}, "
                    msg += f"device: {a.device} vs. {b.device}"
                    raise AssertionError(msg)
                else:
                    return False
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
            return True
        except (AssertionError, RuntimeError) as e:
            if raise_exception:
                if dict_key and isinstance(e, AssertionError):
                    # Add dict key to the assertion error.
                    msg = e.args[0]
                    new_msg = f"For dict key '{dict_key}': {msg}"
                    raise AssertionError(new_msg) from None
                else:
                    raise e
            else:
                return False
    else:
        return a == b


def test_checkpointed_variable_buffer(device):
    # Get input, ref, checkpoint models and make them equal.
    sizes = (2, 2, 3, 3)
    in_data = torch.rand(*sizes).to(device)
    # # these match
    # m_ref = get_model(True, sizes).to(device)
    # m_cpt = get_model(True, sizes).to(device)
    m_ref = get_model(False, sizes).to(device)
    m_cpt = get_model(True, sizes).to(device)
    m_cpt.load_state_dict(m_ref.state_dict())
    assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())

    # Needed due to checkpointing.
    in_data.requires_grad = True
    for model in (m_ref, m_cpt):
        optim = SGD(model.parameters(), lr=0.1)
        out = model(in_data)
        out.sum().backward()
        optim.step()

    # Fails with checkpointing: the recomputed forward sees a twice-updated
    # `v`, so the gradients (and hence the weights after the step) diverge.
    assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())


test_checkpointed_variable_buffer("cuda" if torch.cuda.is_available() else "cpu")
```