Skip to content

Commit 7e7c456

Browse files
committed
connections
1 parent 3284ee3 commit 7e7c456

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

bindsnet/network/network.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -433,29 +433,30 @@ def run(
433433
torch.cuda.synchronize()
434434

435435
for c in self.connections:
436-
flad_m = False
437-
if A_Minus != None and ((isinstance(A_Minus, float)) or (c in A_Minus)):
438-
if A_MD:
439-
kwargs["a_minus"] = A_Minus[c]
440-
else:
441-
kwargs["a_minus"] = A_Minus
442-
flad_m = True
436+
with stream():
437+
flad_m = False
438+
if A_Minus != None and ((isinstance(A_Minus, float)) or (c in A_Minus)):
439+
if A_MD:
440+
kwargs["a_minus"] = A_Minus[c]
441+
else:
442+
kwargs["a_minus"] = A_Minus
443+
flad_m = True
443444

444-
flad_p = False
445-
if A_Plus != None and ((isinstance(A_Plus, float)) or (c in A_Plus)):
446-
if A_PD:
447-
kwargs["a_plus"] = A_Plus[c]
448-
else:
449-
kwargs["a_plus"] = A_Plus
450-
flad_p = True
451-
452-
self.connections[c].update(
453-
mask=masks.get(c, None), learning=self.learning, **kwargs
454-
)
455-
if flad_m:
456-
kwargs.pop("a_minus")
457-
if flad_p:
458-
kwargs.pop("a_plus")
445+
flad_p = False
446+
if A_Plus != None and ((isinstance(A_Plus, float)) or (c in A_Plus)):
447+
if A_PD:
448+
kwargs["a_plus"] = A_Plus[c]
449+
else:
450+
kwargs["a_plus"] = A_Plus
451+
flad_p = True
452+
453+
self.connections[c].update(
454+
mask=masks.get(c, None), learning=self.learning, **kwargs
455+
)
456+
if flad_m:
457+
kwargs.pop("a_minus")
458+
if flad_p:
459+
kwargs.pop("a_plus")
459460

460461
# # Get input to all layers.
461462
# current_inputs.update(self._get_inputs())

0 commit comments

Comments
 (0)