Skip to content

Commit fdf9bcb

Browse files
committed
Init
1 parent 96479e6 commit fdf9bcb

File tree

2 files changed

+409
-388
lines changed

2 files changed

+409
-388
lines changed

bindsnet/network/network.py

Lines changed: 61 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from bindsnet.network.monitors import AbstractMonitor
88
from bindsnet.network.nodes import CSRMNodes, Nodes
99
from bindsnet.network.topology import AbstractConnection
10+
from torch.multiprocessing.spawn import spawn
1011

1112

1213
def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "Network":
@@ -28,6 +29,52 @@ def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "N
2829
return network
2930

3031

32+
def update_layer(_, self, l, t):
33+
# Update each layer of nodes.
34+
if l in self.inputs:
35+
if l in self.current_inputs:
36+
self.current_inputs[l] += self.inputs[l][t]
37+
else:
38+
self.current_inputs[l] = self.inputs[l][t]
39+
40+
if self.one_step:
41+
# Get input to this layer (one-step mode).
42+
self.current_inputs.update(self._get_inputs(layers=[l]))
43+
44+
# Inject voltage to neurons.
45+
inject_v = self.injects_v.get(l, None)
46+
if inject_v is not None:
47+
if inject_v.ndimension() == 1:
48+
self.layers[l].v += inject_v
49+
else:
50+
self.layers[l].v += inject_v[t]
51+
52+
if l in self.current_inputs:
53+
self.layers[l].forward(x=self.current_inputs[l])
54+
else:
55+
self.layers[l].forward(
56+
x=torch.zeros(
57+
self.layers[l].s.shape, device=self.layers[l].s.device
58+
)
59+
)
60+
61+
# Clamp neurons to spike.
62+
clamp = self.clamps.get(l, None)
63+
if clamp is not None:
64+
if clamp.ndimension() == 1:
65+
self.layers[l].s[:, clamp] = 1
66+
else:
67+
self.layers[l].s[:, clamp[t]] = 1
68+
69+
# Clamp neurons not to spike.
70+
unclamp = self.unclamps.get(l, None)
71+
if unclamp is not None:
72+
if unclamp.ndimension() == 1:
73+
self.layers[l].s[:, unclamp] = 0
74+
else:
75+
self.layers[l].s[:, unclamp[t]] = 0
76+
77+
3178
class Network(torch.nn.Module):
3279
# language=rst
3380
"""
@@ -383,50 +430,23 @@ def run(
383430
if not one_step:
384431
current_inputs.update(self._get_inputs())
385432

386-
for l in self.layers:
387-
# Update each layer of nodes.
388-
if l in inputs:
389-
if l in current_inputs:
390-
current_inputs[l] += inputs[l][t]
391-
else:
392-
current_inputs[l] = inputs[l][t]
393-
394-
if one_step:
395-
# Get input to this layer (one-step mode).
396-
current_inputs.update(self._get_inputs(layers=[l]))
433+
processes = []
434+
self.inputs = inputs
435+
self.current_inputs = current_inputs
436+
self.one_step = one_step
437+
self.injects_v = injects_v
438+
self.unclamps = unclamps
439+
self.clamps = clamps
397440

398-
# Inject voltage to neurons.
399-
inject_v = injects_v.get(l, None)
400-
if inject_v is not None:
401-
if inject_v.ndimension() == 1:
402-
self.layers[l].v += inject_v
403-
else:
404-
self.layers[l].v += inject_v[t]
405-
406-
if l in current_inputs:
407-
self.layers[l].forward(x=current_inputs[l])
408-
else:
409-
self.layers[l].forward(
410-
x=torch.zeros(
411-
self.layers[l].s.shape, device=self.layers[l].s.device
412-
)
413-
)
441+
for l in self.layers:
442+
processes.append(
443+
spawn(update_layer, args=(self, l, t), join=False)
444+
)
414445

415-
# Clamp neurons to spike.
416-
clamp = clamps.get(l, None)
417-
if clamp is not None:
418-
if clamp.ndimension() == 1:
419-
self.layers[l].s[:, clamp] = 1
420-
else:
421-
self.layers[l].s[:, clamp[t]] = 1
446+
for p in processes:
447+
p.join()
422448

423-
# Clamp neurons not to spike.
424-
unclamp = unclamps.get(l, None)
425-
if unclamp is not None:
426-
if unclamp.ndimension() == 1:
427-
self.layers[l].s[:, unclamp] = 0
428-
else:
429-
self.layers[l].s[:, unclamp[t]] = 0
449+
print(t)
430450

431451
for c in self.connections:
432452
flad_m = False

0 commit comments

Comments
 (0)