Explaining graph regression on a homogeneous graph #9035
Replies: 5 comments, 3 replies
-
It might be that
-
Hello, I am having the same issue trying to explain a heterogeneous graph model for binary classification in the new version of torch_geometric:

index = torch.tensor([mapping_gene['TUBGCP2'], mapping_phen['HP:0012758']])  # Index of the edge whose classification I want to explain
edge_label_index = data["phenotype", "related_to", "gene"].edge_index
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='binary_classification',
        task_level='edge',
        return_type='probs',
    ),
    node_mask_type='attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type='topk',
        value=20,
    ),
)
explanation = explainer(
    data.x_dict,
    data.edge_index_dict,
    index=index,
    edge_label_index=edge_label_index,
)

The output for this is:

---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[12], line 20
1 index = torch.tensor([mapping_gene['TUBGCP2'], mapping_phen['HP:0012758']])
3 explainer = Explainer(
4 model=model,
5 algorithm=CaptumExplainer('IntegratedGradients'),
(...)
17 ),
18 )
---> 20 explanation = explainer(
21 data.x_dict,
22 data.edge_index_dict,
23 index = index,
24 edge_label_index = edge_label_index
25 )
26 explanation.visualize_feature_importance(top_k=10, feat_labels = feat_labels)
File ~/torch/lib/python3.11/site-packages/torch_geometric/explain/explainer.py:196, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
192 if target is not None:
193 warnings.warn(
194 f"The 'target' should not be provided for the explanation "
195 f"type '{self.explanation_type.value}'")
--> 196 prediction = self.get_prediction(x, edge_index, **kwargs)
197 target = self.get_target(prediction)
199 if isinstance(index, int):
File ~/torch/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/torch/lib/python3.11/site-packages/torch_geometric/explain/explainer.py:115, in Explainer.get_prediction(self, *args, **kwargs)
112 self.model.eval()
114 with torch.no_grad():
--> 115 out = self.model(*args, **kwargs)
117 self.model.train(training)
119 return out
File ~/torch/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/torch/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/phenolinker/src/analysis/gnn.py:48, in HeteroGNN.forward(self, x_dict, edge_index_dict, edge_label_index)
47 def forward(self, x_dict, edge_index_dict, edge_label_index):
---> 48 x_dict = self.encode(x_dict, edge_index_dict)
49 pred = self.decode(x_dict, edge_label_index)
50 return pred
File ~/phenolinker/src/analysis/gnn.py:34, in HeteroGNN.encode(self, x_dict, edge_index_dict)
32 def encode(self, x_dict, edge_index_dict):
33 for conv in self.convs:
---> 34 x_dict = conv(x_dict, edge_index_dict)
35 x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
36 x_dict = self.final(x_dict, edge_index_dict)
File ~/torch/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/torch/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/torch/lib/python3.11/site-packages/torch_geometric/nn/conv/hetero_conv.py:158, in HeteroConv.forward(self, *args_dict, **kwargs_dict)
155 if not has_edge_level_arg:
156 continue
--> 158 out = conv(*args, **kwargs)
160 if dst not in out_dict:
161 out_dict[dst] = [out]
File ~/torch/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/torch/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/torch/lib/python3.11/site-packages/torch_geometric/nn/conv/sage_conv.py:134, in SAGEConv.forward(self, x, edge_index, size)
131 x = (self.lin(x[0]).relu(), x[1])
133 # propagate_type: (x: OptPairTensor)
--> 134 out = self.propagate(edge_index, x=x, size=size)
135 out = self.lin_l(out)
137 x_r = x[1]
File ~/torch/lib/python3.11/site-packages/torch_geometric/nn/conv/message_passing.py:565, in MessagePassing.propagate(self, edge_index, size, **kwargs)
562 if self.explain:
563 explain_msg_kwargs = self.inspector.collect_param_data(
564 'explain_message', coll_dict)
--> 565 out = self.explain_message(out, **explain_msg_kwargs)
567 aggr_kwargs = self.inspector.collect_param_data(
568 'aggregate', coll_dict)
569 for hook in self._aggregate_forward_pre_hooks.values():
File ~/torch/lib/python3.11/site-packages/torch_geometric/nn/conv/message_passing.py:779, in MessagePassing.explain_message(self, inputs, dim_size)
776 edge_mask = self._edge_mask
778 if edge_mask is None:
--> 779 raise ValueError("Could not find a pre-defined 'edge_mask' "
780 "to explain. Did you forget to initialize it?")
782 if self._apply_sigmoid:
783 edge_mask = edge_mask.sigmoid()
ValueError: Could not find a pre-defined 'edge_mask' to explain. Did you forget to initialize it?

The same code worked for the
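A side note on the `index` argument in the code above (this does not address the `edge_mask` error itself): the PyG `Explainer` documents `index` as the position(s) in the first dimension of the model output to explain. Assuming the model returns one score per column of `edge_label_index`, `torch.tensor([gene_id, phenotype_id])` selects two predictions by position rather than the gene-phenotype pair. A minimal sketch of looking up the column of one pair follows; the helper name and the ids are hypothetical placeholders for `mapping_phen[...]` and `mapping_gene[...]`, and row 0 is assumed to hold phenotype indices because the edge type is `("phenotype", "related_to", "gene")`.

```python
import torch


def edge_output_position(edge_label_index: torch.Tensor, src: int, dst: int) -> int:
    """Hypothetical helper: return the column of (src, dst) in a [2, num_edges] tensor."""
    hits = (edge_label_index[0] == src) & (edge_label_index[1] == dst)
    positions = hits.nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        raise KeyError(f"edge ({src}, {dst}) not found in edge_label_index")
    return int(positions[0])  # first matching column


# Hypothetical stand-ins for mapping_phen['HP:0012758'] and mapping_gene['TUBGCP2'].
phen_id, gene_id = 3, 7
# Row 0: phenotype indices, row 1: gene indices (assumed from the edge type).
edge_label_index = torch.tensor([[0, 3, 5, 3],
                                 [2, 1, 7, 7]])
index = edge_output_position(edge_label_index, phen_id, gene_id)
print(index)  # 3
```

An alternative under the same assumption is to pass a one-edge `edge_label_index=torch.tensor([[phen_id], [gene_id]])` together with `index=0`, so the model scores only that edge.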
-
Mh, that looks like a regression :( Is there an easy way for you to come up with a reproducible example? Alternatively, would it be possible for you to debug this on your end? It looks like
-
Yeah, you are right, it would be a regression, as I am retrieving a single score for the probability of the edge being in the graph. I previously had regression in the explainer but thought I was doing it wrong. I changed it back to:

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='regression',
        task_level='edge',
        return_type='raw',
    ),
    node_mask_type='attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type='topk',
        value=20,
    ),
)

but now it gives me a new error (when I tried this code last year, it worked):

---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[2], line 20
1 index = torch.tensor([mapping_gene['TUBGCP2'], mapping_phen['HP:0012758']])
3 explainer = Explainer(
4 model=model,
5 algorithm=CaptumExplainer('IntegratedGradients'),
(...)
17 ),
18 )
---> 20 explanation = explainer(
21 data.x_dict,
22 data.edge_index_dict,
23 index = index,
24 edge_label_index = edge_label_index
25 )
26 explanation.visualize_feature_importance(top_k=10, feat_labels = feat_labels)
...
File ~/torch/lib/python3.11/site-packages/torch_geometric/nn/conv/message_passing.py:756, in MessagePassing.explain(self, explain)
751 self.inspector.inspect_signature(self.explain_message, exclude=[0])
752 self._user_args = self.inspector.get_flat_param_names(
753 funcs=['message', 'explain_message', 'aggregate', 'update'],
754 exclude=self.special_args,
755 )
--> 756 self.propagate = self.__class__._orig_propagate.__get__(
757 self, MessagePassing)
758 else:
759 self._user_args = self.inspector.get_flat_param_names(
760 funcs=['message', 'aggregate', 'update'],
761 exclude=self.special_args,
762 )
AttributeError: type object 'SAGEConv' has no attribute '_orig_propagate'

If you need it, I can post the whole code of the model, but it just uses HeteroConv layers with a SAGEConv for each edge type (there are 2 node types), and computes a probability by applying a sigmoid function to the product of the node embeddings.
-
Used torch_geometric version: 2.5.0
Dear PyG aficionados,
I've trained a simple GCN for a node regression task on a homogeneous graph, following the suggestions in #3794 (changing the loss to torch.nn.functional.mse_loss). To explain the model, I am following the steps from this tutorial: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/explain.html#explaining-graph-regression-on-a-homogeneous-graph
Unfortunately, I have been unable to train an explainer model for a week. I have tried many things; here are the problems I am facing:
Data cannot be loaded from the loader
Training the explainer model for one epoch
I've tried to train the explainer model on the full dataset without a mini-batch loader; however, I get an uninitialized edge_mask error
Training the explainer model on CPU hits an AttributeError
By the way, moving the model and the dataset to CPU seems to trigger another error
I would really appreciate any help here!
Best,
Asan