-
Notifications
You must be signed in to change notification settings - Fork 170
Open
Description
The error occurs when testing on the HIV dataset (and on my own custom dataset, loaded with self-written loader code).
The error does not occur with the ZINC dataset.
Line 70 of `pretrain_masking.py` in the traceback below corresponds to line 52 in the original (unmodified) code; my local copy has extra lines above it.
The PyTorch library code itself is unchanged.
c:\Users\user\Desktop\pretrain-gnns-master\chem\pretrain_masking.py in train(args, model_list, loader, optimizer_list, device)
68 ## loss for nodes
69 pred_node = linear_pred_atoms(node_rep[batch.masked_atom_indices])
---> 70 loss = criterion(pred_node.double(), batch.mask_node_label[:,0])
71
72 acc_node = compute_accuracy(pred_node, batch.mask_node_label[:,0])
c:\Anaconda\envs\pretrain_gnn\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
487 result = self._slow_forward(*input, **kwargs)
488 else:
--> 489 result = self.forward(*input, **kwargs)
490 for hook in self._forward_hooks.values():
491 hook_result = hook(self, input, result)
c:\Anaconda\envs\pretrain_gnn\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
902 def forward(self, input, target):
903 return F.cross_entropy(input, target, weight=self.weight,
--> 904 ignore_index=self.ignore_index, reduction=self.reduction)
905
906
c:\Anaconda\envs\pretrain_gnn\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
1968 if size_average is not None or reduce is not None:
1969 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1970 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
1971
1972
c:\Anaconda\envs\pretrain_gnn\lib\site-packages\torch\nn\functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
1788 .format(input.size(0), target.size(0)))
1789 if dim == 2:
-> 1790 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
1791 elif dim == 4:
1792 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at c:\a\w\1\s\tmp_conda_3.7_070403\conda\conda-bld\pytorch-cpu_1550387224787\work\aten\src\thnn\generic/ClassNLLCriterion.c:93
Metadata
Assignees
Labels
No labels