Dynamic Graph Construction¶
Unlike static graph construction, which is performed during preprocessing, dynamic graph construction jointly learns the graph structure and the graph representation on the fly. The ultimate goal is to learn optimized graph structures and representations with respect to a given downstream prediction task. As shown in the figure below, given a set of data points, which can represent various NLP elements such as words, sentences and documents, we first apply graph similarity metric learning, which captures the pair-wise node similarity and returns a fully-connected weighted graph. Then, we can optionally apply graph sparsification to obtain a sparse graph. When an initial graph topology is available, we can choose to combine it with the implicitly learned graph topology to obtain a better graph topology for the downstream task.
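To make this pipeline concrete, below is a minimal standalone sketch of the three steps in plain PyTorch. It is only an illustration; the function and parameter names here are ours, not part of the library API.
import torch
import torch.nn.functional as F

def dynamic_graph_pipeline(node_emb, top_k=4, alpha=0.5, init_adj=None):
    """Illustrative sketch: metric learning -> sparsification -> optional fusion."""
    # 1. Graph similarity metric learning: pair-wise cosine similarity
    #    yields a fully-connected weighted graph.
    norm_emb = F.normalize(node_emb, p=2, dim=-1)
    raw_adj = torch.matmul(norm_emb, norm_emb.transpose(-1, -2))

    # 2. Graph sparsification: keep only each node's top-k strongest edges.
    values, indices = raw_adj.topk(top_k, dim=-1)
    sparse_adj = torch.zeros_like(raw_adj).scatter_(-1, indices, values)

    # 3. Optional fusion with an initial graph topology.
    if init_adj is not None:
        sparse_adj = alpha * init_adj + (1 - alpha) * sparse_adj
    return sparse_adj

node_emb = torch.randn(10, 64)  # 10 nodes (e.g., words), 64-dim embeddings
adj = dynamic_graph_pipeline(node_emb)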
DynamicGraphConstructionBase¶
Before we introduce the two built-in dynamic graph construction classes, let's first talk about DynamicGraphConstructionBase, the base class for dynamic graph construction. This base class implements several important components shared by various dynamic graph construction approaches. We will introduce each of the components defined in the base class next.
The embedding method computes the initial node embeddings which are later used for dynamic graph construction. This method calls the EmbeddingConstruction instance initialized in GraphConstructionBase, the base class of all graph construction classes, both static and dynamic.
The compute_similarity_metric method computes the pair-wise node similarity in the node embedding space, where the node embeddings are produced by the embedding method above. The output of this method is a weighted adjacency matrix corresponding to a fully-connected graph. Various similarity metric functions such as weighted_cosine, attention, gat_attention, rbf_kernel and cosine are supported. Below is the implementation of this method.
def compute_similarity_metric(self, node_emb, node_mask=None):
    if self.sim_metric_type == 'attention':
        # Dot-product attention: average the scores over multiple heads.
        attention = 0
        for i in range(len(self.linear_sims)):
            node_vec_t = torch.relu(self.linear_sims[i](node_emb))
            attention += torch.matmul(node_vec_t, node_vec_t.transpose(-1, -2))
        attention /= len(self.linear_sims)
    elif self.sim_metric_type == 'weighted_cosine':
        # Multi-perspective weighted cosine: each weight vector re-weights the
        # embedding dimensions before cosine similarity; scores are averaged
        # over perspectives.
        expand_weight_tensor = self.weight.unsqueeze(1)
        if len(node_emb.shape) == 3:
            expand_weight_tensor = expand_weight_tensor.unsqueeze(1)
        node_vec_t = node_emb.unsqueeze(0) * expand_weight_tensor
        node_vec_norm = F.normalize(node_vec_t, p=2, dim=-1)
        attention = torch.matmul(node_vec_norm, node_vec_norm.transpose(-1, -2)).mean(0)
    elif self.sim_metric_type == 'gat_attention':
        # GAT-style attention: additive scores from two linear projections.
        attention = []
        for i in range(len(self.linear_sims1)):
            a_input1 = self.linear_sims1[i](node_emb)
            a_input2 = self.linear_sims2[i](node_emb)
            attention.append(self.leakyrelu(a_input1 + a_input2.transpose(-1, -2)))
        attention = torch.mean(torch.stack(attention, 0), 0)
    elif self.sim_metric_type == 'rbf_kernel':
        # RBF kernel: Gaussian of the weighted pair-wise distance matrix.
        dist_weight = torch.mm(self.weight, self.weight.transpose(-1, -2))
        attention = self._compute_distance_matrix(node_emb, dist_weight)
        attention = torch.exp(-0.5 * attention * (self.precision_inv_dis**2))
    elif self.sim_metric_type == 'cosine':
        # Plain non-parametric cosine similarity, detached from the graph.
        node_vec_norm = node_emb.div(torch.norm(node_emb, p=2, dim=-1, keepdim=True))
        attention = torch.mm(node_vec_norm, node_vec_norm.transpose(-1, -2)).detach()

    if node_mask is not None:
        # Mask out padding nodes by setting their scores to mask_off_val.
        if torch.__version__ < '1.3.0':
            attention = attention.masked_fill_(~(node_mask == 1.), self.mask_off_val)
        else:
            attention = attention.masked_fill_(~node_mask.bool(), self.mask_off_val)

    return attention
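For intuition, here is a small standalone sketch of the weighted_cosine metric outside the class. The tensor shapes and the number of perspectives are illustrative.
import torch
import torch.nn.functional as F

num_pers, dim, num_nodes = 4, 16, 5
weight = torch.randn(num_pers, dim)   # one weight vector per perspective
node_emb = torch.randn(num_nodes, dim)

# Re-weight embedding dimensions per perspective, then L2-normalize.
weighted = node_emb.unsqueeze(0) * weight.unsqueeze(1)  # (num_pers, num_nodes, dim)
normed = F.normalize(weighted, p=2, dim=-1)

# Pair-wise similarity per perspective, averaged into one adjacency matrix.
adj = torch.matmul(normed, normed.transpose(-1, -2)).mean(0)  # (num_nodes, num_nodes)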
The sparsify_graph method derives a sparse graph from the above fully-connected graph. Various graph sparsification options such as kNN sparsification and epsilon-neighborhood sparsification are supported. Below is the implementation of this method.
def sparsify_graph(self, adj):
    # Epsilon-neighborhood sparsification: drop edges whose weight falls
    # below the threshold epsilon_neigh.
    if self.epsilon_neigh is not None:
        adj = self._build_epsilon_neighbourhood(adj, self.epsilon_neigh)
    # kNN sparsification: keep only each node's top_k_neigh strongest edges.
    if self.top_k_neigh is not None:
        adj = self._build_knn_neighbourhood(adj, self.top_k_neigh)
    return adj
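The helper methods _build_epsilon_neighbourhood and _build_knn_neighbourhood are internal and not shown here. The following is a minimal sketch of what the two strategies conceptually do, assuming masked-off edges are simply set to zero (the library may use a different mask-off value).
import torch

def epsilon_neighbourhood(adj, epsilon):
    # Keep edges whose weight is at least epsilon; zero out the rest.
    mask = (adj >= epsilon).float()
    return adj * mask

def knn_neighbourhood(adj, top_k):
    # Keep each node's top_k strongest edges; zero out the rest.
    values, indices = adj.topk(top_k, dim=-1)
    return torch.zeros_like(adj).scatter_(-1, indices, values)

adj = torch.rand(6, 6)
sparse_adj = knn_neighbourhood(epsilon_neighbourhood(adj, 0.3), top_k=2)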
The compute_graph_regularization method computes regularization terms for the learned graph topology. Various graph regularization losses such as smoothness, connectivity and sparsity are supported. Below is the implementation of this method.
def compute_graph_regularization(self, adj, node_feat):
    graph_reg = 0
    if self.smoothness_ratio not in (0, None):
        # Smoothness: the Dirichlet energy tr(X^T L X), where L is the graph
        # Laplacian, encourages connected nodes to have similar features.
        for i in range(adj.shape[0]):
            L = torch.diagflat(torch.sum(adj[i], -1)) - adj[i]
            graph_reg += self.smoothness_ratio * torch.trace(
                torch.mm(node_feat[i].transpose(-1, -2), torch.mm(L, node_feat[i]))
            ) / int(np.prod(adj.shape))
    if self.connectivity_ratio not in (0, None):
        # Connectivity: a log-barrier on node degrees that penalizes
        # disconnected (near-zero-degree) nodes.
        ones_vec = torch.ones(adj.shape[:-1]).to(adj.device)
        graph_reg += -self.connectivity_ratio * torch.matmul(
            ones_vec.unsqueeze(1),
            torch.log(torch.matmul(adj, ones_vec.unsqueeze(-1)) + VERY_SMALL_NUMBER)
        ).sum() / adj.shape[0] / adj.shape[-1]
    if self.sparsity_ratio not in (0, None):
        # Sparsity: the squared Frobenius norm discourages dense graphs.
        graph_reg += self.sparsity_ratio * torch.sum(torch.pow(adj, 2)) / int(np.prod(adj.shape))
    return graph_reg
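To see what the smoothness term measures, here is a small worked example that computes the Dirichlet energy tr(X^T L X) directly on a toy graph; the numbers are illustrative.
import torch

# A tiny 3-node graph: nodes 0 and 1 are strongly connected.
adj = torch.tensor([[0.0, 0.9, 0.1],
                    [0.9, 0.0, 0.1],
                    [0.1, 0.1, 0.0]])
X = torch.tensor([[1.0], [1.1], [5.0]])  # node features

# Graph Laplacian L = D - A, with D the diagonal degree matrix.
L = torch.diagflat(adj.sum(-1)) - adj

# Dirichlet energy tr(X^T L X): a heavily weighted edge between nodes with
# similar features (0 and 1 here) contributes little to the energy.
smoothness = torch.trace(X.t() @ L @ X)
print(smoothness)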
Node Embedding Based Dynamic Graph Construction¶
For node embedding based dynamic graph construction, we aim to learn the graph structure from a set of node embeddings.
The NodeEmbeddingBasedGraphConstruction class inherits from the DynamicGraphConstructionBase base class, which implements several of the aforementioned components (e.g., compute_similarity_metric, sparsify_graph). The topology method in NodeEmbeddingBasedGraphConstruction implements the logic of learning a graph topology from initial node embeddings, as shown below:
def topology(self, graph):
    node_emb = graph.batch_node_features["node_feat"]
    # Mask out padding tokens so they do not participate in metric learning.
    node_mask = (graph.batch_node_features["token_id"] != Vocab.PAD)

    raw_adj = self.compute_similarity_metric(node_emb, node_mask)
    raw_adj = self.sparsify_graph(raw_adj)
    graph_reg = self.compute_graph_regularization(raw_adj, node_emb)

    if self.sim_metric_type in ('rbf_kernel', 'weighted_cosine'):
        assert raw_adj.min().item() >= 0, 'adjacency matrix must be non-negative!'
        # Row/column normalization yields the forward and reverse adjacency.
        adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-1, keepdim=True),
                                    min=torch.finfo(torch.float32).eps)
        reverse_adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-2, keepdim=True),
                                            min=torch.finfo(torch.float32).eps)
    elif self.sim_metric_type == 'cosine':
        # Binarize the graph, then apply symmetric normalization.
        raw_adj = (raw_adj > 0).float()
        adj = normalize_adj(raw_adj)
        reverse_adj = adj
    else:
        # Attention-style metrics: normalize the scores with softmax.
        adj = torch.softmax(raw_adj, dim=-1)
        reverse_adj = torch.softmax(raw_adj, dim=-2)

    graph = convert_adj_to_graph(graph, adj, reverse_adj, 0)
    graph.graph_attributes['graph_reg'] = graph_reg

    return graph
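The row and column normalization above produces two directed views of the learned graph. Here is a minimal standalone sketch of that normalization step, using a random non-negative matrix as a stand-in for the learned adjacency.
import torch

raw_adj = torch.rand(4, 4)  # stand-in for a learned non-negative adjacency
eps = torch.finfo(torch.float32).eps

# Row-normalize: each node's outgoing edge weights sum to 1.
adj = raw_adj / torch.clamp(raw_adj.sum(dim=-1, keepdim=True), min=eps)
# Column-normalize: each node's incoming edge weights sum to 1,
# used as the adjacency of the reverse-direction graph.
reverse_adj = raw_adj / torch.clamp(raw_adj.sum(dim=-2, keepdim=True), min=eps)

assert torch.allclose(adj.sum(-1), torch.ones(4))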
Node Embedding Based Refined Dynamic Graph Construction¶
Unlike node embedding based metric learning, node embedding based refined graph metric learning additionally utilizes the intrinsic graph structure, which potentially still carries rich and useful information regarding the optimal graph structure for the downstream task. It computes a linear combination of the normalized adjacency matrix of the intrinsic graph and the normalized adjacency matrix of the learned implicit graph.
The NodeEmbeddingBasedRefinedGraphConstruction class also inherits from the DynamicGraphConstructionBase base class. The topology method in NodeEmbeddingBasedRefinedGraphConstruction implements the logic of combining the initial graph topology and the learned implicit graph topology, as shown below:
def topology(self, graph, init_norm_adj):
    node_emb = graph.batch_node_features["node_feat"]
    # Mask out padding tokens so they do not participate in metric learning.
    node_mask = (graph.batch_node_features["token_id"] != Vocab.PAD)

    raw_adj = self.compute_similarity_metric(node_emb, node_mask)
    raw_adj = self.sparsify_graph(raw_adj)
    graph_reg = self.compute_graph_regularization(raw_adj, node_emb)

    if self.sim_metric_type in ('rbf_kernel', 'weighted_cosine'):
        assert raw_adj.min().item() >= 0, 'adjacency matrix must be non-negative!'
        adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-1, keepdim=True),
                                    min=torch.finfo(torch.float32).eps)
        reverse_adj = raw_adj / torch.clamp(torch.sum(raw_adj, dim=-2, keepdim=True),
                                            min=torch.finfo(torch.float32).eps)
    elif self.sim_metric_type == 'cosine':
        raw_adj = (raw_adj > 0).float()
        adj = normalize_adj(raw_adj)
        reverse_adj = adj
    else:
        adj = torch.softmax(raw_adj, dim=-1)
        reverse_adj = torch.softmax(raw_adj, dim=-2)

    if self.alpha_fusion is not None:
        # Fuse the learned topology with the initial normalized topology;
        # torch.sparse.FloatTensor.add also handles a sparse init_norm_adj.
        adj = torch.sparse.FloatTensor.add(
            (1 - self.alpha_fusion) * adj, self.alpha_fusion * init_norm_adj)
        reverse_adj = torch.sparse.FloatTensor.add(
            (1 - self.alpha_fusion) * reverse_adj, self.alpha_fusion * init_norm_adj)

    graph = convert_adj_to_graph(graph, adj, reverse_adj, 0)
    graph.graph_attributes['graph_reg'] = graph_reg

    return graph
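To illustrate the fusion step in isolation, here is a minimal sketch with dense tensors; the value of alpha and the stand-in matrices are illustrative.
import torch

alpha = 0.8  # fusion weight for the initial topology (illustrative value)

learned_adj = torch.softmax(torch.randn(4, 4), dim=-1)  # learned implicit topology
init_norm_adj = torch.eye(4)  # stand-in for the normalized initial topology

# Convex combination of the two topologies: with alpha close to 1 the
# initial structure dominates, with alpha close to 0 the learned one does.
fused_adj = alpha * init_norm_adj + (1 - alpha) * learned_adj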