RuntimeError when using CaptumExplainer after GNNExplainer #9653

Open
he-jesse opened this issue Sep 11, 2024 · 0 comments
🐛 Describe the bug

Running CaptumExplainer after GNNExplainer on the same model raises a RuntimeError, while running CaptumExplainer before GNNExplainer works fine. (A similar issue occurs with GraphMaskExplainer.) The expected behavior is that both GNNExplainer and CaptumExplainer successfully return explanations regardless of the order in which they are called.

Below is a minimal working example (MWE):

from torch_geometric.nn import GCN
from torch_geometric.explain import Explainer, GNNExplainer, CaptumExplainer
from torch_geometric.datasets import FakeDataset

dataset = FakeDataset()
data = dataset[0]
model = GCN(64, 16, 2, 1)  # in_channels=64, hidden_channels=16, num_layers=2, out_channels=1

gnnexplainer = Explainer(
    model=model,
    algorithm=GNNExplainer(),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    )
)
captumexplainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    )
)
gnnexplainer(data.x, data.edge_index, index=0)     # succeeds
captumexplainer(data.x, data.edge_index, index=0)  # raises RuntimeError

Here is the full traceback:

Traceback (most recent call last):
  File "c:\Users\jesse\Documents\gnn-project\bug_mwe.py", line 32, in <module>
    captumexplainer(data.x, data.edge_index, index=0)
  File "C:\Users\jesse\miniconda3\Lib\site-packages\torch_geometric\explain\explainer.py", line 205, in __call__
    explanation = self.algorithm(
                  ^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\torch_geometric\explain\algorithm\captum_explainer.py", line 170, in forward
    attributions = self.attribution_method_instance.attribute(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\captum\log\__init__.py", line 42, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\captum\attr\_core\integrated_gradients.py", line 274, in attribute
    attributions = _batch_attribution(
                   ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\captum\attr\_utils\batching.py", line 78, in _batch_attribution
    current_attr = attr_method._attribute(
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\captum\attr\_core\integrated_gradients.py", line 351, in _attribute
    grads = self.gradient_func(
            ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\captum\_utils\gradient.py", line 119, in compute_gradients
    grads = torch.autograd.grad(torch.unbind(outputs), inputs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\miniconda3\Lib\site-packages\torch\autograd\__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Versions

Collecting environment information...
PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.3 | packaged by Anaconda, Inc. | (main, May  6 2024, 19:42:21) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-11-10.0.22631-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 528.97
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=4001
DeviceID=CPU0
Family=107
L2CacheSize=8192
L2CacheSpeed=
Manufacturer=AuthenticAMD
MaxClockSpeed=4001
Name=AMD Ryzen 9 7940HS w/ Radeon 780M Graphics
ProcessorType=3
Revision=29697

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] torch_geometric==2.5.2
[pip3] torchaudio==2.2.2
[pip3] torchvision==0.17.2
[conda] blas                      1.0                         mkl
[conda] mkl                       2023.1.0         h6b88ed4_46358
[conda] mkl-service               2.4.0           py312h2bbff1b_1
[conda] mkl_fft                   1.3.8           py312h2bbff1b_0
[conda] mkl_random                1.2.4           py312h59b6b97_0
[conda] numpy                     1.26.4          py312hfd52020_0
[conda] numpy-base                1.26.4          py312h4dde369_0
[conda] pyg                       2.5.2           py312_torch_2.2.0_cu121    pyg
[conda] pytorch                   2.2.2           py3.12_cuda12.1_cudnn8_0    pytorch
[conda] pytorch-cuda              12.1                 hde6ce7c_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.2.2                    pypi_0    pypi
[conda] torchvision               0.17.2                   pypi_0    pypi
@he-jesse he-jesse added the bug label Sep 11, 2024