Bug with Instance Norm #9647

Open
davidireland3 opened this issue Sep 7, 2024 · 0 comments

🐛 Describe the bug

I am having an issue where the embedding of the same graph comes out different depending on which other graphs are in the batch when InstanceNorm is used inside the DeepGCNLayer wrapper. The problem depends on track_running_stats: the mismatch occurs when it is set to False, and does not seem to occur when it is set to True.

I'm not sure whether I have misunderstood how instance normalisation should work on graphs, but I thought that for a single graph, each feature would be normalised using the mean and variance of that feature over all nodes in that graph. If so, the result should be unaffected by whatever else is in the batch. Either there is a bug in the code or my expectation of how this should work is wrong.

Here's a script that should reproduce the error; you can switch track_running_stats via the model argument to see the different results.

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import random
import numpy as np

import torch.nn as nn
from torch_geometric.nn.norm import InstanceNorm
from torch_geometric.nn.pool import global_mean_pool
from torch_geometric.nn.conv import RGCNConv
from torch_geometric.nn.models import DeepGCNLayer

random.seed(10)
np.random.seed(10)
torch.random.manual_seed(10)


class SimpleDeepGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_relations, num_layers=3,
                 track_running_stats=True):
        super().__init__()

        self.node_encoder = nn.Linear(in_channels, hidden_channels)

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            conv = RGCNConv(hidden_channels, hidden_channels, num_relations)
            norm = InstanceNorm(in_channels=hidden_channels, track_running_stats=track_running_stats, affine=True)
            act = nn.ReLU()
            layer = DeepGCNLayer(conv, norm, act, block='res')
            self.layers.append(layer)

        self.out_layer = nn.Linear(hidden_channels, out_channels)

    def forward(self, batch):
        x, edge_index, edge_type = batch.x, batch.edge_index, batch.edge_type
        x = self.node_encoder(x)

        for layer in self.layers:
            x = layer(x, edge_index, edge_type)

        x = self.out_layer(x)
        return global_mean_pool(x, batch.batch)


def create_relational_graph(num_nodes, num_edges, num_node_features, num_edge_types):
    x = torch.randn(num_nodes, num_node_features)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_type = torch.randint(0, num_edge_types, (num_edges,))
    return Data(x=x, edge_index=edge_index, edge_type=edge_type, binary_hash=f"{np.random.randint(0, 10000)}")


def create_batch_graphs(num_graphs, num_node_features, num_edge_types):
    num_nodes = np.random.randint(50, 100)
    num_edges = np.random.randint(100, 200)
    return [create_relational_graph(num_nodes, num_edges, num_node_features, num_edge_types)
            for _ in range(num_graphs)]


# Create a fixed graph for testing
fixed_graph = create_relational_graph(num_nodes=20, num_edges=40, num_node_features=256, num_edge_types=3)

# Create two batches of random graphs
batch1 = create_batch_graphs(num_graphs=5, num_node_features=256, num_edge_types=3)
batch2 = create_batch_graphs(num_graphs=8, num_node_features=256, num_edge_types=3)

# Append fixed graph to each batch
batch1.append(fixed_graph)
batch2.append(fixed_graph)

# Create data loaders
loader1 = DataLoader(batch1, batch_size=len(batch1), shuffle=False)
loader2 = DataLoader(batch2, batch_size=len(batch2), shuffle=False)

# Set track_running_stats=False to reproduce the mismatch; with True the embeddings match.
model = SimpleDeepGCN(256, 1000, 128, 3, track_running_stats=True)


def compare_embeddings(emb1, emb2, rtol=1e-5, atol=1e-8):
    """
    Compare two numpy arrays of embeddings.

    Args:
    emb1, emb2: numpy arrays of the same shape
    rtol: relative tolerance parameter
    atol: absolute tolerance parameter

    Returns:
    bool: True if the arrays are equal within the given tolerance, False otherwise
    """
    return np.allclose(emb1, emb2, rtol=rtol, atol=atol)


# Test the embeddings
model.eval()
with torch.no_grad():
    for data1, data2 in zip(loader1, loader2):
        # Forward pass for batch1
        output1 = model(data1).cpu().numpy()
        fixed_graph_emb1 = output1[-1]  # The fixed graph was appended last, so it is the last row

        # Forward pass for batch2
        output2 = model(data2).cpu().numpy()
        fixed_graph_emb2 = output2[-1]  # The fixed graph was appended last, so it is the last row

        # Compare the embeddings
        are_embeddings_equal = compare_embeddings(fixed_graph_emb1, fixed_graph_emb2)

        print(f"Are the embeddings of the fixed graph equal in both batches? {are_embeddings_equal}")
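
The behaviour described above can be mimicked without PyG in a minimal pure-Python sketch. It rests on two assumptions that are not verified here: (1) DeepGCNLayer calls its norm as `norm(x)` without a batch vector, so InstanceNorm would pool statistics over every node in the mini-batch; (2) with track_running_stats=True in eval mode, the stored running mean and variance are used instead (hypothetically 0 and 1 here, PyTorch's usual initial values, since the script never trains). The helpers `instance_norm`, `fixed_graph_output` and `same` are hypothetical, and no affine weight/bias is applied.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple

Matrix = List[List[float]]  # rows = nodes, columns = features


def instance_norm(
    x: Matrix,
    batch: Optional[Sequence[int]] = None,
    running: Optional[Tuple[Sequence[float], Sequence[float]]] = None,
    eps: float = 1e-5,
) -> Matrix:
    """Toy instance norm (hypothetical helper, no affine weight/bias).

    running=(mean, var): fixed per-feature stats (assumed eval + tracked stats).
    batch given: statistics computed separately for each graph id.
    batch None: all nodes treated as one graph (pooled statistics).
    """
    num_features = len(x[0])
    out = [row[:] for row in x]

    if running is not None:
        mean, var = running
        # Each node is normalised on its own, so other graphs cannot affect it.
        for n in range(len(x)):
            for f in range(num_features):
                out[n][f] = (x[n][f] - mean[f]) / math.sqrt(var[f] + eps)
        return out

    if batch is None:
        batch = [0] * len(x)  # the whole mini-batch becomes one "instance"

    # Group node indices by graph id.
    groups: Dict[int, List[int]] = {}
    for node, graph in enumerate(batch):
        groups.setdefault(graph, []).append(node)

    for nodes in groups.values():
        for f in range(num_features):
            vals = [x[n][f] for n in nodes]
            mean_f = sum(vals) / len(vals)
            var_f = sum((v - mean_f) ** 2 for v in vals) / len(vals)  # biased variance
            for n in nodes:
                out[n][f] = (x[n][f] - mean_f) / math.sqrt(var_f + eps)
    return out


# Toy data: one fixed graph plus two different sets of "other" nodes.
fixed = [[1.0, 2.0], [3.0, 5.0], [4.0, 0.0]]
others_a = [[10.0, -1.0], [12.0, 7.0]]
others_b = [[-5.0, 3.0], [0.0, 0.0], [2.0, 8.0], [9.0, 1.0]]


def fixed_graph_output(others: Matrix, mode: str) -> Matrix:
    """Normalise `others + fixed` and return only the fixed graph's rows."""
    x = others + fixed  # fixed graph appended last, as in the script above
    if mode == "per_graph":
        batch = [0] * len(others) + [1] * len(fixed)
        out = instance_norm(x, batch=batch)
    elif mode == "pooled":
        out = instance_norm(x, batch=None)
    elif mode == "running":
        out = instance_norm(x, running=([0.0, 0.0], [1.0, 1.0]))
    else:
        raise ValueError(f"unknown mode: {mode}")
    return out[-len(fixed):]


def same(a: Matrix, b: Matrix) -> bool:
    """Element-wise closeness check for two small matrices."""
    return all(
        math.isclose(u, v, rel_tol=1e-9, abs_tol=1e-12)
        for ra, rb in zip(a, b)
        for u, v in zip(ra, rb)
    )


for mode in ("per_graph", "pooled", "running"):
    equal = same(fixed_graph_output(others_a, mode), fixed_graph_output(others_b, mode))
    print(mode, equal)
# per_graph True
# pooled False
# running True
```

Under these assumptions, only the pooled mode (the analogue of calling the norm without a batch vector) lets the other graphs leak into the fixed graph's output, which would match the mismatch seen with track_running_stats=False.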

Versions

Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A

Python version: 3.10.14 (main, May  6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] pytorch-metric-learning==2.4.1
[pip3] torch==2.1.0
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch_spline_conv==1.2.2
[pip3] torchaudio==2.1.0
[pip3] torchmetrics==1.3.0
[conda] numpy                     1.26.1                   pypi_0    pypi
[conda] pytorch-metric-learning   2.4.1                    pypi_0    pypi
[conda] torch                     2.1.0                    pypi_0    pypi
[conda] torch-cluster             1.6.3                    pypi_0    pypi
[conda] torch-geometric           2.5.3                    pypi_0    pypi
[conda] torch-scatter             2.1.2                    pypi_0    pypi
[conda] torch-sparse              0.6.18                   pypi_0    pypi
[conda] torch-spline-conv         1.2.2                    pypi_0    pypi
[conda] torchaudio                2.1.0                    pypi_0    pypi
[conda] torchmetrics              1.3.0                    pypi_0    pypi