Why return torch.mean(loss)? #16

@edshkim98

@yaringal Hi, I have a question about your multi-task loss function.
Below, you return the loss as torch.mean(loss), but if I understand this function correctly, loss is just a single tensor value and not a list, so torch.mean(loss) will be the same as loss. What was your motivation for using torch.mean(loss)?
Thank you!

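import torch

def criterion(y_pred, y_true, log_vars):
  loss = 0
  for i in range(len(y_pred)):
    # Per-task precision, exp(-log_var).
    precision = torch.exp(-log_vars[i])
    diff = (y_pred[i]-y_true[i])**2.
    # Sum over the last dimension only; any leading dimensions are kept.
    loss += torch.sum(precision * diff + log_vars[i], -1)
  return torch.mean(loss)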
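
One way to check the question empirically is to look at the shape of loss just before the final torch.mean. The sketch below is not the author's code: per_sample_loss is a hypothetical helper that repeats the loop from criterion without the last reduction, and the input shapes (two tasks, a batch of 4, one output per task) are assumptions, since the issue does not say what shapes are passed in. Under those assumptions, torch.sum(..., -1) removes only the last dimension, so loss holds one value per sample and torch.mean does change it. With unbatched 1-D inputs, loss is already a scalar and torch.mean is a no-op.

```python
import torch

def per_sample_loss(y_pred, y_true, log_vars):
    # Hypothetical helper: same loop as criterion, minus the final torch.mean.
    loss = 0
    for i in range(len(y_pred)):
        precision = torch.exp(-log_vars[i])
        diff = (y_pred[i] - y_true[i]) ** 2.
        loss = loss + torch.sum(precision * diff + log_vars[i], -1)
    return loss

# Assumed shapes: 2 tasks, batch of 4 samples, 1 output per task.
torch.manual_seed(0)
y_pred = [torch.randn(4, 1), torch.randn(4, 1)]
y_true = [torch.randn(4, 1), torch.randn(4, 1)]
log_vars = [torch.zeros(1), torch.zeros(1)]

# Batched case: the sum over dim -1 leaves the batch dimension in place.
batched = per_sample_loss(y_pred, y_true, log_vars)
print(batched.shape)              # one entry per sample
print(torch.mean(batched).shape)  # reduced to a scalar

# Unbatched case: 1-D inputs, so the sum already yields a scalar.
unbatched = per_sample_loss(
    [p[0] for p in y_pred], [t[0] for t in y_true], log_vars
)
print(unbatched.shape)
print(torch.equal(torch.mean(unbatched), unbatched))
```

So whether torch.mean(loss) is redundant depends on the shapes passed to criterion, which only the calling code can confirm.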
