7
7
from bindsnet .network .monitors import AbstractMonitor
8
8
from bindsnet .network .nodes import CSRMNodes , Nodes
9
9
from bindsnet .network .topology import AbstractConnection
10
+ from torch .multiprocessing .spawn import spawn
10
11
11
12
12
13
def load (file_name : str , map_location : str = "cpu" , learning : bool = None ) -> "Network" :
@@ -28,6 +29,52 @@ def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "N
28
29
return network
29
30
30
31
32
+ def update_layer (_ , self , l , t ):
33
+ # Update each layer of nodes.
34
+ if l in self .inputs :
35
+ if l in self .current_inputs :
36
+ self .current_inputs [l ] += self .inputs [l ][t ]
37
+ else :
38
+ self .current_inputs [l ] = self .inputs [l ][t ]
39
+
40
+ if self .one_step :
41
+ # Get input to this layer (one-step mode).
42
+ self .current_inputs .update (self ._get_inputs (layers = [l ]))
43
+
44
+ # Inject voltage to neurons.
45
+ inject_v = self .injects_v .get (l , None )
46
+ if inject_v is not None :
47
+ if inject_v .ndimension () == 1 :
48
+ self .layers [l ].v += inject_v
49
+ else :
50
+ self .layers [l ].v += inject_v [t ]
51
+
52
+ if l in self .current_inputs :
53
+ self .layers [l ].forward (x = self .current_inputs [l ])
54
+ else :
55
+ self .layers [l ].forward (
56
+ x = torch .zeros (
57
+ self .layers [l ].s .shape , device = self .layers [l ].s .device
58
+ )
59
+ )
60
+
61
+ # Clamp neurons to spike.
62
+ clamp = self .clamps .get (l , None )
63
+ if clamp is not None :
64
+ if clamp .ndimension () == 1 :
65
+ self .layers [l ].s [:, clamp ] = 1
66
+ else :
67
+ self .layers [l ].s [:, clamp [t ]] = 1
68
+
69
+ # Clamp neurons not to spike.
70
+ unclamp = self .unclamps .get (l , None )
71
+ if unclamp is not None :
72
+ if unclamp .ndimension () == 1 :
73
+ self .layers [l ].s [:, unclamp ] = 0
74
+ else :
75
+ self .layers [l ].s [:, unclamp [t ]] = 0
76
+
77
+
31
78
class Network (torch .nn .Module ):
32
79
# language=rst
33
80
"""
@@ -383,50 +430,23 @@ def run(
383
430
if not one_step :
384
431
current_inputs .update (self ._get_inputs ())
385
432
386
- for l in self .layers :
387
- # Update each layer of nodes.
388
- if l in inputs :
389
- if l in current_inputs :
390
- current_inputs [l ] += inputs [l ][t ]
391
- else :
392
- current_inputs [l ] = inputs [l ][t ]
393
-
394
- if one_step :
395
- # Get input to this layer (one-step mode).
396
- current_inputs .update (self ._get_inputs (layers = [l ]))
433
+ processes = []
434
+ self .inputs = inputs
435
+ self .current_inputs = current_inputs
436
+ self .one_step = one_step
437
+ self .injects_v = injects_v
438
+ self .unclamps = unclamps
439
+ self .clamps = clamps
397
440
398
- # Inject voltage to neurons.
399
- inject_v = injects_v .get (l , None )
400
- if inject_v is not None :
401
- if inject_v .ndimension () == 1 :
402
- self .layers [l ].v += inject_v
403
- else :
404
- self .layers [l ].v += inject_v [t ]
405
-
406
- if l in current_inputs :
407
- self .layers [l ].forward (x = current_inputs [l ])
408
- else :
409
- self .layers [l ].forward (
410
- x = torch .zeros (
411
- self .layers [l ].s .shape , device = self .layers [l ].s .device
412
- )
413
- )
441
+ for l in self .layers :
442
+ processes .append (
443
+ spawn (update_layer , args = (self , l , t ), join = False )
444
+ )
414
445
415
- # Clamp neurons to spike.
416
- clamp = clamps .get (l , None )
417
- if clamp is not None :
418
- if clamp .ndimension () == 1 :
419
- self .layers [l ].s [:, clamp ] = 1
420
- else :
421
- self .layers [l ].s [:, clamp [t ]] = 1
446
+ for p in processes :
447
+ p .join ()
422
448
423
- # Clamp neurons not to spike.
424
- unclamp = unclamps .get (l , None )
425
- if unclamp is not None :
426
- if unclamp .ndimension () == 1 :
427
- self .layers [l ].s [:, unclamp ] = 0
428
- else :
429
- self .layers [l ].s [:, unclamp [t ]] = 0
449
+ print (t )
430
450
431
451
for c in self .connections :
432
452
flad_m = False
0 commit comments