Dynamic Graph Construction

Unlike static graph construction, which is performed during preprocessing, dynamic graph construction jointly learns the graph structure and the graph representation on the fly. The ultimate goal is to learn graph structures and representations that are optimized for a given downstream prediction task. As shown in the figure below, given a set of data points, which can stand for various NLP elements such as words, sentences and documents, we first apply graph similarity metric learning, which captures the pair-wise node similarity and returns a fully-connected weighted graph. We can then optionally apply graph sparsification to obtain a sparse graph. When an initial graph topology is available, we can choose to combine it with the implicitly learned graph topology to obtain a better graph topology for the downstream task.

[Figure: overall pipeline of dynamic graph construction — graph similarity metric learning, graph sparsification, and optional fusion with the initial graph topology.]

DynamicGraphConstructionBase

Before we introduce the two built-in dynamic graph construction classes, let's first talk about DynamicGraphConstructionBase, which is the base class for dynamic graph construction. This base class implements several important components shared by various dynamic graph construction approaches. We introduce each of these components next.
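
For orientation, here is a conceptual skeleton of the base class. The method bodies are elided and the signatures are simplified for illustration; the exact signatures in the library may differ.

class DynamicGraphConstructionBase(GraphConstructionBase):
    def embedding(self, *args, **kwargs):
        """Compute initial node embeddings."""
        ...

    def compute_similarity_metric(self, node_emb, node_mask=None):
        """Compute a pair-wise node similarity (weighted adjacency) matrix."""
        ...

    def sparsify_graph(self, adj):
        """Sparsify the fully-connected weighted graph."""
        ...

    def compute_graph_regularization(self, adj, node_feat):
        """Compute regularization losses on the learned graph topology."""
        ...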

The embedding method computes the initial node embeddings, which will later be used for dynamic graph construction. This method calls the EmbeddingConstruction instance which is initialized in GraphConstructionBase, the base class of all graph construction classes, both static and dynamic.
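
A minimal sketch of this delegation, assuming the EmbeddingConstruction instance is stored as self.embedding_layer (the attribute name is an assumption for illustration):

def embedding(self, graph):
    # Delegate to the EmbeddingConstruction module set up in
    # GraphConstructionBase to produce the initial node embeddings.
    return self.embedding_layer(graph)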

The compute_similarity_metric method computes pair-wise node similarity in the node embedding space, where the node embeddings are created by the embedding method above. The output of this method is a weighted adjacency matrix corresponding to a fully-connected graph. Various similarity metric functions such as weighted_cosine, attention and rbf_kernel are supported. Below is the implementation of this method.

def compute_similarity_metric(self, node_emb, node_mask=None):
    if self.sim_metric_type == 'attention':
        # Multi-head dot-product attention: average the similarity scores
        # computed by each head's linear projection.
        attention = 0
        for i in range(len(self.linear_sims)):
            node_vec_t = torch.relu(self.linear_sims[i](node_emb))
            attention += torch.matmul(node_vec_t, node_vec_t.transpose(-1, -2))

        attention /= len(self.linear_sims)
    elif self.sim_metric_type == 'weighted_cosine':
        # Cosine similarity with learnable per-dimension weights, averaged
        # over multiple heads (the first dimension of self.weight).
        expand_weight_tensor = self.weight.unsqueeze(1)
        if len(node_emb.shape) == 3:
            expand_weight_tensor = expand_weight_tensor.unsqueeze(1)

        node_vec_t = node_emb.unsqueeze(0) * expand_weight_tensor
        node_vec_norm = F.normalize(node_vec_t, p=2, dim=-1)
        attention = torch.matmul(node_vec_norm, node_vec_norm.transpose(-1, -2)).mean(0)
    elif self.sim_metric_type == 'gat_attention':
        # GAT-style additive attention, averaged over multiple heads.
        attention = []
        for i in range(len(self.linear_sims1)):
            a_input1 = self.linear_sims1[i](node_emb)
            a_input2 = self.linear_sims2[i](node_emb)
            attention.append(self.leakyrelu(a_input1 + a_input2.transpose(-1, -2)))

        attention = torch.mean(torch.stack(attention, 0), 0)
    elif self.sim_metric_type == 'rbf_kernel':
        # RBF kernel over a learned distance metric.
        dist_weight = torch.mm(self.weight, self.weight.transpose(-1, -2))
        attention = self._compute_distance_matrix(node_emb, dist_weight)
        attention = torch.exp(-0.5 * attention * (self.precision_inv_dis ** 2))
    elif self.sim_metric_type == 'cosine':
        # Plain cosine similarity; detached so no gradient flows through it.
        node_vec_norm = node_emb.div(torch.norm(node_emb, p=2, dim=-1, keepdim=True))
        attention = torch.mm(node_vec_norm, node_vec_norm.transpose(-1, -2)).detach()
    else:
        raise ValueError('Unknown similarity metric type: {}'.format(self.sim_metric_type))

    if node_mask is not None:
        # Mask out entries for padding nodes. Note that comparing version
        # strings lexicographically (e.g., torch.__version__ < '1.3.0') is
        # unreliable, so we simply use Tensor.bool(), which is available in
        # all modern PyTorch versions.
        attention = attention.masked_fill(~node_mask.bool(), self.mask_off_val)

    return attention
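
As a concrete example, the weighted_cosine metric with $m$ heads implements (matching the normalize-then-matmul-then-mean code path above):

$$s_{ij} = \frac{1}{m} \sum_{k=1}^{m} \cos\left(\mathbf{w}_k \odot \mathbf{v}_i, \; \mathbf{w}_k \odot \mathbf{v}_j\right)$$

where $\mathbf{v}_i$ is the embedding of node $i$, $\mathbf{w}_k$ is the learnable weight vector of the $k$-th head, and $\odot$ denotes element-wise multiplication.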

The sparsify_graph method derives a sparse graph from the above fully-connected graph. Various graph sparsification options such as kNN sparsification and epsilon-neighborhood sparsification are supported. Below is the implementation of this method.

def sparsify_graph(self, adj):
    # Epsilon-neighborhood sparsification: drop edges whose weight
    # is below the threshold epsilon_neigh.
    if self.epsilon_neigh is not None:
        adj = self._build_epsilon_neighbourhood(adj, self.epsilon_neigh)

    # kNN sparsification: for each node, keep only its top_k_neigh
    # strongest edges.
    if self.top_k_neigh is not None:
        adj = self._build_knn_neighbourhood(adj, self.top_k_neigh)

    return adj
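
The two helper methods called above are not shown in the snippet. The following is a minimal sketch of what they might look like; the actual implementations in the library may differ (e.g., in how masked-off entries are represented):

import torch

def build_epsilon_neighbourhood(adj, epsilon):
    # Keep edges whose weight exceeds the threshold; zero out the rest.
    mask = (adj > epsilon).float()
    return adj * mask

def build_knn_neighbourhood(adj, top_k):
    # For each node, keep its top_k strongest edges and drop the others.
    top_k = min(top_k, adj.shape[-1])
    knn_val, knn_ind = torch.topk(adj, top_k, dim=-1)
    return torch.zeros_like(adj).scatter_(-1, knn_ind, knn_val)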

The compute_graph_regularization method computes regularization terms for the learned graph topology. Various graph regularization losses such as smoothness, connectivity and sparsity are supported. Below is the implementation of this method.

def compute_graph_regularization(self, adj, node_feat):
    graph_reg = 0
    if self.smoothness_ratio not in (0, None):
        # Smoothness: penalize large feature differences between strongly
        # connected nodes via the Laplacian quadratic form tr(X^T L X).
        for i in range(adj.shape[0]):
            L = torch.diagflat(torch.sum(adj[i], -1)) - adj[i]
            graph_reg += self.smoothness_ratio \
                * torch.trace(torch.mm(node_feat[i].transpose(-1, -2), torch.mm(L, node_feat[i]))) \
                / int(np.prod(adj.shape))

    if self.connectivity_ratio not in (0, None):
        # Connectivity: encourage each node to have sufficient total degree
        # by penalizing near-zero row sums of the adjacency matrix.
        ones_vec = torch.ones(adj.shape[:-1]).to(adj.device)
        graph_reg += -self.connectivity_ratio \
            * torch.matmul(ones_vec.unsqueeze(1),
                           torch.log(torch.matmul(adj, ones_vec.unsqueeze(-1)) + VERY_SMALL_NUMBER)).sum() \
            / adj.shape[0] / adj.shape[-1]

    if self.sparsity_ratio not in (0, None):
        # Sparsity: penalize the squared Frobenius norm of the adjacency matrix.
        graph_reg += self.sparsity_ratio * torch.sum(torch.pow(adj, 2)) / int(np.prod(adj.shape))

    return graph_reg
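
For reference, up to the batch-size normalization in the code, the three terms for a single graph with $n$ nodes, adjacency matrix $A$, node feature matrix $X$ and Laplacian $L = D - A$ amount to:

$$\mathcal{R}(A, X) = \frac{\lambda_1}{n^2} \mathrm{tr}\left(X^{\top} L X\right) - \frac{\lambda_2}{n} \mathbf{1}^{\top} \log\left(A \mathbf{1} + \epsilon\right) + \frac{\lambda_3}{n^2} \lVert A \rVert_F^2$$

where $\lambda_1$, $\lambda_2$ and $\lambda_3$ correspond to smoothness_ratio, connectivity_ratio and sparsity_ratio, respectively, and $\epsilon$ is VERY_SMALL_NUMBER.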

Node Embedding Based Dynamic Graph Construction

For node embedding based dynamic graph construction, we aim to learn the graph structure from a set of node embeddings. The NodeEmbeddingBasedGraphConstruction class inherits from the DynamicGraphConstructionBase base class, which implements several of the aforementioned components (e.g., compute_similarity_metric, sparsify_graph). The topology method in NodeEmbeddingBasedGraphConstruction implements the logic of learning a graph topology from initial node embeddings, as shown below:

def topology(self, graph):
    node_emb = graph.batch_node_features["node_feat"]
    # Mask out padding nodes.
    node_mask = (graph.batch_node_features["token_id"] != Vocab.PAD)

    # Learn a fully-connected weighted graph, sparsify it, and compute
    # the graph regularization loss on the result.
    raw_adj = self.compute_similarity_metric(node_emb, node_mask)
    raw_adj = self.sparsify_graph(raw_adj)
    graph_reg = self.compute_graph_regularization(raw_adj, node_emb)

    if self.sim_metric_type in ('rbf_kernel', 'weighted_cosine'):
        # These metrics produce non-negative weights, so we can normalize
        # by row sums (adj) and column sums (reverse_adj) directly.
        assert raw_adj.min().item() >= 0, 'adjacency matrix must be non-negative!'
        adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-1, keepdim=True),
                                    min=torch.finfo(torch.float32).eps)
        reverse_adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-2, keepdim=True),
                                            min=torch.finfo(torch.float32).eps)
    elif self.sim_metric_type == 'cosine':
        # Binarize positive similarities, then apply symmetric normalization.
        raw_adj = (raw_adj > 0).float()
        adj = normalize_adj(raw_adj)
        reverse_adj = adj
    else:
        # Metrics that may produce negative scores are normalized via softmax.
        adj = torch.softmax(raw_adj, dim=-1)
        reverse_adj = torch.softmax(raw_adj, dim=-2)

    graph = convert_adj_to_graph(graph, adj, reverse_adj, 0)
    # Store the regularization loss so the trainer can add it to the task loss.
    graph.graph_attributes['graph_reg'] = graph_reg

    return graph
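
Note that the regularization loss is stored on the returned graph rather than applied directly. Below is a hedged sketch of how a training step might consume it; graph_constructor, gnn, classifier, task_loss and the data variables are assumed to exist and are illustrative only.

# Learn the graph topology on the fly, then run the downstream model on it.
new_graph = graph_constructor.topology(batch_graph)
logits = classifier(gnn(new_graph))

# Add the graph regularization term stored by the topology method
# to the downstream task loss before backpropagation.
loss = task_loss(logits, labels) + new_graph.graph_attributes['graph_reg']
loss.backward()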

Node Embedding Based Refined Dynamic Graph Construction

Unlike node embedding based metric learning, node embedding based refined graph metric learning additionally utilizes the intrinsic graph structure, which potentially still carries rich and useful information regarding the optimal graph structure for the downstream task. It computes a linear combination of the normalized adjacency matrix of the intrinsic graph and the normalized adjacency matrix of the learned implicit graph.
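
Concretely, with fusion weight $\lambda$ (alpha_fusion in the code below), the final adjacency matrix is

$$\tilde{A} = \lambda A^{(0)} + (1 - \lambda) A$$

where $A^{(0)}$ is the normalized adjacency matrix of the intrinsic graph (init_norm_adj) and $A$ is the normalized learned adjacency matrix.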

The NodeEmbeddingBasedRefinedGraphConstruction class also inherits from the DynamicGraphConstructionBase base class. The topology method in NodeEmbeddingBasedRefinedGraphConstruction implements the logic of combining the initial graph topology and the learned implicit graph topology, as shown below:

def topology(self, graph, init_norm_adj):
    node_emb = graph.batch_node_features["node_feat"]
    # Mask out padding nodes.
    node_mask = (graph.batch_node_features["token_id"] != Vocab.PAD)

    raw_adj = self.compute_similarity_metric(node_emb, node_mask)
    raw_adj = self.sparsify_graph(raw_adj)
    graph_reg = self.compute_graph_regularization(raw_adj, node_emb)

    # Normalize the learned adjacency matrix, exactly as in
    # NodeEmbeddingBasedGraphConstruction.topology.
    if self.sim_metric_type in ('rbf_kernel', 'weighted_cosine'):
        assert raw_adj.min().item() >= 0, 'adjacency matrix must be non-negative!'
        adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-1, keepdim=True),
                                    min=torch.finfo(torch.float32).eps)
        reverse_adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-2, keepdim=True),
                                            min=torch.finfo(torch.float32).eps)
    elif self.sim_metric_type == 'cosine':
        raw_adj = (raw_adj > 0).float()
        adj = normalize_adj(raw_adj)
        reverse_adj = adj
    else:
        adj = torch.softmax(raw_adj, dim=-1)
        reverse_adj = torch.softmax(raw_adj, dim=-2)

    if self.alpha_fusion is not None:
        # Linearly combine the learned adjacency with the initial normalized
        # adjacency; torch.sparse.FloatTensor.add also handles the case where
        # init_norm_adj is a sparse tensor.
        adj = torch.sparse.FloatTensor.add((1 - self.alpha_fusion) * adj,
                                           self.alpha_fusion * init_norm_adj)
        reverse_adj = torch.sparse.FloatTensor.add((1 - self.alpha_fusion) * reverse_adj,
                                                   self.alpha_fusion * init_norm_adj)

    graph = convert_adj_to_graph(graph, adj, reverse_adj, 0)
    graph.graph_attributes['graph_reg'] = graph_reg

    return graph
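
Below is a hedged sketch of how the refined construction might be driven, assuming the initial adjacency comes from a static construction step (e.g., a dependency graph) and that normalize_adj is the same helper used above; get_initial_adj is a hypothetical accessor for the intrinsic graph's adjacency matrix.

# Obtain the intrinsic graph topology (e.g., from dependency parsing) and
# normalize it before fusion. `get_initial_adj` is a hypothetical helper.
init_adj = get_initial_adj(batch_graph)
init_norm_adj = normalize_adj(init_adj)

# Fuse the initial topology with the learned one, then train as usual.
new_graph = graph_constructor.topology(batch_graph, init_norm_adj)
loss = task_loss(model(new_graph), labels) + new_graph.graph_attributes['graph_reg']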