# Graph Classification¶

Graph classification is a downstream classification task conducted at the graph level. Once node representations are learned by a GNN, one can pool them into a graph-level representation and then perform graph-level classification. To facilitate this task, we provide implementations of commonly used graph classification prediction modules.

## FeedForwardNN¶

This is a high-level graph classification prediction module consisting of a graph pooling component and a multilayer perceptron (MLP).
Users can specify important hyperparameters such as `input_size`, `num_class`, and `hidden_size` (i.e., a list of hidden sizes, one per dense layer).
The `FeedForwardNN` class calls the `FeedForwardNNLayer` API, which implements the MLP.

```
class FeedForwardNN(GraphClassifierBase):
    def __init__(self, input_size, num_class, hidden_size, activation=None, graph_pool_type='max_pool', **kwargs):
        super(FeedForwardNN, self).__init__()
        if not activation:
            activation = nn.ReLU()
        # Select the graph pooling component.
        if graph_pool_type == 'avg_pool':
            self.graph_pool = AvgPooling()
        elif graph_pool_type == 'max_pool':
            self.graph_pool = MaxPooling(**kwargs)
        else:
            raise RuntimeError('Unknown graph pooling type: {}'.format(graph_pool_type))
        # MLP that maps the pooled graph embedding to class scores.
        self.classifier = FeedForwardNNLayer(input_size, num_class, hidden_size, activation)
```
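To make the pool-then-classify pattern concrete, here is a minimal self-contained sketch in plain PyTorch: node embeddings are max-pooled into one vector per graph, then passed through an MLP with one dense layer per entry in `hidden_size`. All names here (`GraphClassifierSketch`, `node_feats_per_graph`) are illustrative, not the library's API.

```python
import torch
import torch.nn as nn

class GraphClassifierSketch(nn.Module):
    """Illustrative stand-in for a FeedForwardNN-style prediction head."""

    def __init__(self, input_size, num_class, hidden_size):
        super().__init__()
        layers, prev = [], input_size
        for h in hidden_size:                       # one dense layer per hidden size
            layers += [nn.Linear(prev, h), nn.ReLU()]
            prev = h
        layers.append(nn.Linear(prev, num_class))   # final projection to class logits
        self.mlp = nn.Sequential(*layers)

    def forward(self, node_feats_per_graph):
        # Max-pool each graph's (num_nodes, input_size) features into one vector.
        pooled = torch.stack([f.max(dim=0)[0] for f in node_feats_per_graph])
        return self.mlp(pooled)

model = GraphClassifierSketch(input_size=8, num_class=3, hidden_size=[16, 16])
feats = [torch.randn(4, 8), torch.randn(6, 8)]      # two graphs, 4 and 6 nodes
logits = model(feats)
print(logits.shape)  # torch.Size([2, 3])
```

Note that the pooled vector has size `input_size`, so the MLP's first layer must take `input_size` inputs regardless of how many nodes each graph has.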

## AvgPooling¶

This is the average pooling module, which applies average pooling over the nodes in each graph.
It takes a batched `GraphData` as input and returns a feature tensor containing one vector per graph in the batch.

```
class AvgPooling(PoolingBase):
    def __init__(self):
        super(AvgPooling, self).__init__()

    def forward(self, graph, feat):
        # Split the batched graph back into individual graphs.
        graph_list = from_batch(graph)
        output_feat = []
        for g in graph_list:
            # Average the node features of each graph.
            output_feat.append(g.node_features[feat].mean(dim=0))
        output_feat = torch.stack(output_feat, 0)
        return output_feat
```
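Stripped of the batching machinery, the pooling itself reduces to a mean over the node dimension. A minimal sketch, assuming each graph's node features arrive as a `(num_nodes, feat_dim)` tensor (the function name is illustrative):

```python
import torch

def avg_pool(node_feats_per_graph):
    # Mean over the node dimension yields one vector per graph;
    # graphs may have different numbers of nodes.
    return torch.stack([f.mean(dim=0) for f in node_feats_per_graph], dim=0)

# Two graphs with 3 and 5 nodes, feature dim 4.
feats = [torch.randn(3, 4), torch.randn(5, 4)]
pooled = avg_pool(feats)
print(pooled.shape)  # torch.Size([2, 4])
```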

## MaxPooling¶

This is the max pooling module, which applies max pooling over the nodes in each graph.
It takes a batched `GraphData` as input and returns a feature tensor containing one vector per graph in the batch.
An optional linear projection can be applied to the node embeddings before max pooling.

```
class MaxPooling(PoolingBase):
    def __init__(self, dim=None, use_linear_proj=False):
        super(MaxPooling, self).__init__()
        if use_linear_proj:
            assert dim is not None, "dim should be specified when use_linear_proj is set to True"
            self.linear = nn.Linear(dim, dim, bias=False)
        else:
            self.linear = None

    def forward(self, graph, feat):
        # Split the batched graph back into individual graphs.
        graph_list = from_batch(graph)
        output_feat = []
        for g in graph_list:
            feat_tensor = g.node_features[feat]
            # Optionally project node embeddings before pooling.
            if self.linear is not None:
                feat_tensor = self.linear(feat_tensor)
            # Element-wise max over the node dimension.
            output_feat.append(torch.max(feat_tensor, dim=0)[0])
        output_feat = torch.stack(output_feat, 0)
        return output_feat
```
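The optional projection can be sketched in isolation, without the batched-graph plumbing. This is a hedged, self-contained approximation of the class above (`SimpleMaxPool` and `node_feats_per_graph` are illustrative names, not the library's API):

```python
import torch
import torch.nn as nn

class SimpleMaxPool(nn.Module):
    """Illustrative max pooling with an optional bias-free linear projection."""

    def __init__(self, dim=None, use_linear_proj=False):
        super().__init__()
        # The projection preserves the feature dimension (dim -> dim).
        self.linear = nn.Linear(dim, dim, bias=False) if use_linear_proj else None

    def forward(self, node_feats_per_graph):
        out = []
        for f in node_feats_per_graph:
            if self.linear is not None:
                f = self.linear(f)          # project before pooling
            out.append(f.max(dim=0)[0])     # element-wise max over nodes
        return torch.stack(out, dim=0)

pool = SimpleMaxPool(dim=4, use_linear_proj=True)
feats = [torch.randn(3, 4), torch.randn(5, 4)]
pooled = pool(feats)
print(pooled.shape)  # torch.Size([2, 4])
```

Note that `torch.max(x, dim=0)` returns a `(values, indices)` pair, hence the `[0]` to keep only the pooled values.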