Skip to content

Commit 3284ee3

Browse files
committed
run
1 parent bf28032 commit 3284ee3

File tree

1 file changed

+43
-39
lines changed

1 file changed

+43
-39
lines changed

bindsnet/network/network.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from bindsnet.network.monitors import AbstractMonitor
88
from bindsnet.network.nodes import CSRMNodes, Nodes
99
from bindsnet.network.topology import AbstractConnection
10+
from bindsnet.utils import stream
1011

1112

1213
def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "Network":
@@ -384,49 +385,52 @@ def run(
384385
current_inputs.update(self._get_inputs())
385386

386387
for l in self.layers:
387-
# Update each layer of nodes.
388-
if l in inputs:
389-
if l in current_inputs:
390-
current_inputs[l] += inputs[l][t]
391-
else:
392-
current_inputs[l] = inputs[l][t]
393-
394-
if one_step:
395-
# Get input to this layer (one-step mode).
396-
current_inputs.update(self._get_inputs(layers=[l]))
388+
with stream():
389+
# Update each layer of nodes.
390+
if l in inputs:
391+
if l in current_inputs:
392+
current_inputs[l] += inputs[l][t]
393+
else:
394+
current_inputs[l] = inputs[l][t]
395+
396+
if one_step:
397+
# Get input to this layer (one-step mode).
398+
current_inputs.update(self._get_inputs(layers=[l]))
399+
400+
# Inject voltage to neurons.
401+
inject_v = injects_v.get(l, None)
402+
if inject_v is not None:
403+
if inject_v.ndimension() == 1:
404+
self.layers[l].v += inject_v
405+
else:
406+
self.layers[l].v += inject_v[t]
397407

398-
# Inject voltage to neurons.
399-
inject_v = injects_v.get(l, None)
400-
if inject_v is not None:
401-
if inject_v.ndimension() == 1:
402-
self.layers[l].v += inject_v
408+
if l in current_inputs:
409+
self.layers[l].forward(x=current_inputs[l])
403410
else:
404-
self.layers[l].v += inject_v[t]
405-
406-
if l in current_inputs:
407-
self.layers[l].forward(x=current_inputs[l])
408-
else:
409-
self.layers[l].forward(
410-
x=torch.zeros(
411-
self.layers[l].s.shape, device=self.layers[l].s.device
411+
self.layers[l].forward(
412+
x=torch.zeros(
413+
self.layers[l].s.shape, device=self.layers[l].s.device
414+
)
412415
)
413-
)
414-
415-
# Clamp neurons to spike.
416-
clamp = clamps.get(l, None)
417-
if clamp is not None:
418-
if clamp.ndimension() == 1:
419-
self.layers[l].s[:, clamp] = 1
420-
else:
421-
self.layers[l].s[:, clamp[t]] = 1
422416

423-
# Clamp neurons not to spike.
424-
unclamp = unclamps.get(l, None)
425-
if unclamp is not None:
426-
if unclamp.ndimension() == 1:
427-
self.layers[l].s[:, unclamp] = 0
428-
else:
429-
self.layers[l].s[:, unclamp[t]] = 0
417+
# Clamp neurons to spike.
418+
clamp = clamps.get(l, None)
419+
if clamp is not None:
420+
if clamp.ndimension() == 1:
421+
self.layers[l].s[:, clamp] = 1
422+
else:
423+
self.layers[l].s[:, clamp[t]] = 1
424+
425+
# Clamp neurons not to spike.
426+
unclamp = unclamps.get(l, None)
427+
if unclamp is not None:
428+
if unclamp.ndimension() == 1:
429+
self.layers[l].s[:, unclamp] = 0
430+
else:
431+
self.layers[l].s[:, unclamp[t]] = 0
432+
433+
torch.cuda.synchronize()
430434

431435
for c in self.connections:
432436
flad_m = False

0 commit comments

Comments
 (0)