File tree 1 file changed +23
-22
lines changed 1 file changed +23
-22
lines changed Original file line number Diff line number Diff line change @@ -433,29 +433,30 @@ def run(
433
433
torch .cuda .synchronize ()
434
434
435
435
for c in self .connections :
436
- flad_m = False
437
- if A_Minus != None and ((isinstance (A_Minus , float )) or (c in A_Minus )):
438
- if A_MD :
439
- kwargs ["a_minus" ] = A_Minus [c ]
440
- else :
441
- kwargs ["a_minus" ] = A_Minus
442
- flad_m = True
436
+ with stream ():
437
+ flad_m = False
438
+ if A_Minus != None and ((isinstance (A_Minus , float )) or (c in A_Minus )):
439
+ if A_MD :
440
+ kwargs ["a_minus" ] = A_Minus [c ]
441
+ else :
442
+ kwargs ["a_minus" ] = A_Minus
443
+ flad_m = True
443
444
444
- flad_p = False
445
- if A_Plus != None and ((isinstance (A_Plus , float )) or (c in A_Plus )):
446
- if A_PD :
447
- kwargs ["a_plus" ] = A_Plus [c ]
448
- else :
449
- kwargs ["a_plus" ] = A_Plus
450
- flad_p = True
451
-
452
- self .connections [c ].update (
453
- mask = masks .get (c , None ), learning = self .learning , ** kwargs
454
- )
455
- if flad_m :
456
- kwargs .pop ("a_minus" )
457
- if flad_p :
458
- kwargs .pop ("a_plus" )
445
+ flad_p = False
446
+ if A_Plus != None and ((isinstance (A_Plus , float )) or (c in A_Plus )):
447
+ if A_PD :
448
+ kwargs ["a_plus" ] = A_Plus [c ]
449
+ else :
450
+ kwargs ["a_plus" ] = A_Plus
451
+ flad_p = True
452
+
453
+ self .connections [c ].update (
454
+ mask = masks .get (c , None ), learning = self .learning , ** kwargs
455
+ )
456
+ if flad_m :
457
+ kwargs .pop ("a_minus" )
458
+ if flad_p :
459
+ kwargs .pop ("a_plus" )
459
460
460
461
# # Get input to all layers.
461
462
# current_inputs.update(self._get_inputs())
You can’t perform that action at this time.
0 commit comments