torch-geometric version has to be up to date #3

@sfluegel05

I had some trouble running chebai-graph with torch_geometric version 2.3.1.

When loading a checkpoint trained on 2.6.1, I got a mismatch in the dimensions of the ResGatedGraphConv layers.

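Apparently, in version 2.3.1, the input dimension of the lin_key, lin_query and lin_value weights depends only on in_channels, while in later versions it is the sum of in_channels and the edge_attr dimension (158 + 7 = 165 for the first layer, 256 + 7 = 263 for the others).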

Here is the output for in_channels=158, edge_attr=7:

RuntimeError: Error(s) in loading state_dict for ResGatedGraphConvNetGraphPred:
    size mismatch for gnn.convs.0.lin_key.weight: copying a param with shape torch.Size([256, 165]) from checkpoint, the shape in current model is torch.Size([256, 158]).
    size mismatch for gnn.convs.0.lin_query.weight: copying a param with shape torch.Size([256, 165]) from checkpoint, the shape in current model is torch.Size([256, 158]).
    size mismatch for gnn.convs.0.lin_value.weight: copying a param with shape torch.Size([256, 165]) from checkpoint, the shape in current model is torch.Size([256, 158]).
    size mismatch for gnn.convs.1.lin_key.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.1.lin_query.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.1.lin_value.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.2.lin_key.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.2.lin_query.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.2.lin_value.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.3.lin_key.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.3.lin_query.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.convs.3.lin_value.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
    size mismatch for gnn.final_conv.lin_key.weight: copying a param with shape torch.Size([512, 263]) from checkpoint, the shape in current model is torch.Size([512, 256]).
    size mismatch for gnn.final_conv.lin_query.weight: copying a param with shape torch.Size([512, 263]) from checkpoint, the shape in current model is torch.Size([512, 256]).
    size mismatch for gnn.final_conv.lin_value.weight: copying a param with shape torch.Size([512, 263]) from checkpoint, the shape in current model is torch.Size([512, 256]).
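
The shapes in the log above can be reproduced with plain arithmetic. The following is a minimal sketch, not code from chebai-graph or torch_geometric: the helper names (`lin_input_width`, `expected_shapes`) are made up for illustration, and the layer list is read off the error message. It assumes, as described above, that newer versions add the edge_attr dimension to the input width of the linear layers and that 2.3.1 does not.

```python
# Hypothetical helpers for illustration only; they do not call torch_geometric.

EDGE_DIM = 7  # edge_attr dimension from the issue

# (layer name, out_channels, in_channels), taken from the error log
LAYERS = [
    ("gnn.convs.0", 256, 158),
    ("gnn.convs.1", 256, 256),
    ("gnn.convs.2", 256, 256),
    ("gnn.convs.3", 256, 256),
    ("gnn.final_conv", 512, 256),
]


def lin_input_width(in_channels: int, edge_dim: int, adds_edge_dim: bool) -> int:
    """Input width of lin_key / lin_query / lin_value under the assumption above."""
    return in_channels + edge_dim if adds_edge_dim else in_channels


def expected_shapes(adds_edge_dim: bool) -> dict:
    """Map each layer name to the (out, in) shape of its linear weights."""
    return {
        name: (out_c, lin_input_width(in_c, EDGE_DIM, adds_edge_dim))
        for name, out_c, in_c in LAYERS
    }


checkpoint_shapes = expected_shapes(adds_edge_dim=True)   # trained on 2.6.1
current_shapes = expected_shapes(adds_edge_dim=False)     # model built on 2.3.1

for name in checkpoint_shapes:
    print(name, "checkpoint:", checkpoint_shapes[name], "current:", current_shapes[name])
```

Every pair differs by exactly 7 in the second dimension, which matches all 15 mismatches reported by load_state_dict.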

Todo

  • Require torch_geometric>=2.4
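
Besides pinning the dependency, a runtime guard could fail early with a clear message instead of a size mismatch. This is only a sketch under assumptions: `parse_release`, `meets_minimum` and `check_torch_geometric` are hypothetical names, not part of chebai-graph, and the simple parser only handles plain release numbers (the `packaging` library would be more robust).

```python
from importlib import metadata

MIN_VERSION = (2, 4)  # minimum from the Todo item above


def parse_release(version: str) -> tuple:
    """Turn '2.6.1' into (2, 6, 1); stops at the first non-numeric component."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version: str, minimum: tuple = MIN_VERSION) -> bool:
    """True if the version is at least the minimum (missing components count as 0)."""
    release = parse_release(version)
    release += (0,) * max(0, len(minimum) - len(release))
    return release >= minimum


def check_torch_geometric():
    """Raise if an installed torch_geometric is too old; return None if not installed."""
    try:
        # Assumes the distribution is registered under this name.
        installed = metadata.version("torch_geometric")
    except metadata.PackageNotFoundError:
        return None
    if not meets_minimum(installed):
        raise RuntimeError(
            f"torch_geometric {installed} found, but >=2.4 is required "
            "to load checkpoints trained with newer versions"
        )
    return installed
```

Calling `check_torch_geometric()` before loading a checkpoint would turn the error above into a single, explicit version message.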
