In "08.rainbow.ipynb", in the DQNAgent.update_model() function:
# PER: importance sampling before average
loss = torch.mean(elementwise_loss * weights)
elementwise_loss: shape (128,)
weights: shape (128, 1)
elementwise_loss * weights: shape (128, 128), because broadcasting treats the (128,) tensor as (1, 128)
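A minimal reproduction of the shape problem, using random stand-in tensors (batch size 128 taken from the shapes above; the real values come from the replay buffer). It also shows why this matters: the mean of the (128, 128) product equals mean(weights) * mean(elementwise_loss), so each sample's loss is no longer paired with its own importance-sampling weight.

```python
import torch

torch.manual_seed(0)
batch_size = 128  # batch size from the shapes reported above

# Random stand-ins for the tensors in update_model()
elementwise_loss = torch.rand(batch_size)  # shape (128,)
weights = torch.rand(batch_size, 1)        # shape (128, 1)

# Broadcasting aligns trailing dims: (128,) -> (1, 128),
# so (1, 128) * (128, 1) -> (128, 128): every loss times every weight
product = elementwise_loss * weights
print(product.shape)  # torch.Size([128, 128])

# The mean over the full grid collapses to a product of two means,
# so the per-sample importance-sampling correction is lost
print(torch.allclose(product.mean(), weights.mean() * elementwise_loss.mean()))  # True
```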
I think the expected shape of "elementwise_loss * weights" should be (128,), so that each sample's loss is scaled only by its own importance-sampling weight.
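One possible fix, sketched with the same random stand-in tensors: make the two shapes agree before multiplying, either by dropping the trailing singleton dimension of weights or by adding one to elementwise_loss. Which option fits best depends on how the rest of update_model() uses these tensors, which is an assumption here.

```python
import torch

torch.manual_seed(0)
batch_size = 128

elementwise_loss = torch.rand(batch_size)  # shape (128,)
weights = torch.rand(batch_size, 1)        # shape (128, 1)

# Option A: squeeze weights to (128,) so the product is elementwise
weighted = elementwise_loss * weights.squeeze(1)
loss = torch.mean(weighted)

# Option B: unsqueeze the loss to (128, 1) so both tensors match
loss_b = torch.mean(elementwise_loss.unsqueeze(1) * weights)

print(weighted.shape)                # torch.Size([128])
print(torch.allclose(loss, loss_b))  # True
```

Both options compute the same weighted mean; Option A keeps the per-sample result as a flat (128,) tensor, which is the shape expected above.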
Looking forward to your answer.