Knowledge Graph Completion
The purpose of Knowledge Graph Completion (KGC) is to predict new triples from existing ones and thereby extend a knowledge graph (KG). KGC is usually treated as a link prediction task. Formally, a knowledge graph is represented as \(\mathcal{G} = (\mathcal{V}, \mathcal{E}, \mathcal{R})\), where \(v_i \in \mathcal{V}\) are entities, \((v_s, r, v_o) \in \mathcal{E}\) are edges, and \(r \in \mathcal{R}\) is a relation type. The task assigns a score to each candidate fact (i.e. a triple \(\left \langle subject, relation, object \right \rangle\)) to determine how likely that edge is to belong to \(\mathcal{E}\).
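To make the notation concrete, the following minimal sketch stores a toy knowledge graph as a set of triples and enumerates the candidate edges that a link predictor would have to score. The entity and relation names are invented for illustration, and the helper `candidate_triples` is hypothetical and not part of the library.

```python
from itertools import product

# Toy edge set E: (subject, relation, object) triples. All names are made up.
edges = {
    ("paris", "capital_of", "france"),
    ("berlin", "capital_of", "germany"),
    ("france", "borders", "germany"),
}

# V and R are recovered from the triples themselves.
entities = sorted({s for s, _, _ in edges} | {o for _, _, o in edges})
relations = sorted({r for _, r, _ in edges})


def candidate_triples(edges, entities, relations):
    """Return every (s, r, o) combination that is not already in E."""
    return [
        (s, r, o)
        for s, r, o in product(entities, relations, entities)
        if (s, r, o) not in edges
    ]


candidates = candidate_triples(edges, entities, relations)
print(len(entities), len(relations), len(candidates))  # 4 2 29
```

A KGC model assigns a score to each of these candidates; high-scoring ones, such as a plausible ("germany", "borders", "france"), are proposed as new edges.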
KGC can be solved with an encoder-decoder framework. To encode the local neighborhood information of an entity, the encoder can be chosen from a variety of GNNs.
The decoder is a knowledge graph embedding model and can be regarded as a scoring function. The most common decoders for knowledge graph completion include translation-based models (TransE), tensor-factorization-based models (DistMult, ComplEx) and neural-network-based models (ConvE).
We implement DistMult and ComplEx in this library.
DistMult is a tensor-factorization-based model from the paper Embedding entities and relations for learning and inference in knowledge bases.
For DistMult, the scoring function is \(f(s, r, o) = e_s^T M_r e_o\), where \(e_s\) and \(e_o\) are the subject and object embeddings and every relation \(r\) is represented by a diagonal matrix \(M_r \in \mathbb{R}^{d \times d}\).
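Because \(M_r\) is diagonal, the DistMult score reduces to an element-wise product of three vectors followed by a sum. The sketch below checks this on made-up numbers with plain Python lists; the function names `distmult_score` and `bilinear_score` are illustrative only.

```python
def distmult_score(e_s, r_diag, e_o):
    """Score a triple as sum_k e_s[k] * r_diag[k] * e_o[k]."""
    return sum(s * r * o for s, r, o in zip(e_s, r_diag, e_o))


def bilinear_score(e_s, m_r, e_o):
    """Same score computed with the full d x d matrix M_r, for comparison."""
    d = len(e_s)
    return sum(e_s[i] * m_r[i][j] * e_o[j] for i in range(d) for j in range(d))


e_s = [1.0, 2.0, 3.0]
r_diag = [0.5, -1.0, 2.0]  # diagonal entries of M_r
e_o = [2.0, 1.0, 0.5]

# Build the explicit diagonal matrix M_r from its diagonal entries.
m_r = [[r_diag[i] if i == j else 0.0 for j in range(3)] for i in range(3)]

score = distmult_score(e_s, r_diag, e_o)
print(score)                            # 2.0
print(bilinear_score(e_s, m_r, e_o))    # 2.0
print(distmult_score(e_o, r_diag, e_s)) # 2.0, subject and object swapped
```

The last line shows a known property of DistMult: swapping subject and object leaves the score unchanged, so the model cannot distinguish the direction of a relation.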
In our implementation, the subject embedding, the relation embedding and all entity embeddings are passed as inputs to forward(...). We then compute either the score logits for all entity nodes, for use with a multi-class loss such as BCELoss(), or the prediction scores of positive/negative examples, for use with a pairwise loss function such as SoftplusLoss() or SigmoidLoss(). For more details about the KG completion losses, refer to graph4nlp.loss.KGLoss.
class DistMult(KGCompletionBase):
    def __init__(self,
                 input_dropout=0.0,
                 loss_name='BCELoss'):
        super(DistMult, self).__init__()
        self.loss_name = loss_name
        self.classifier = DistMultLayer(input_dropout, loss_name)

    def forward(self, input_graph: GraphData, e1_emb, rel_emb, all_node_emb, multi_label=None):
        if multi_label is None:
            # Only the scores over all entities are needed.
            input_graph.graph_attributes['logits'] = self.classifier(e1_emb,
                                                                     rel_emb,
                                                                     all_node_emb)  # [B, N]
        else:
            # Also split the scores into positive and negative examples.
            input_graph.graph_attributes['logits'], input_graph.graph_attributes['p_score'], \
                input_graph.graph_attributes['n_score'] = self.classifier(e1_emb,
                                                                          rel_emb,
                                                                          all_node_emb,
                                                                          multi_label)
            # input_graph.graph_attributes['p_score']: [L_p]
            # input_graph.graph_attributes['n_score']: [L_n]
            # L_p + L_n == B * N

        return input_graph
import torch
import torch.nn as nn


class DistMultLayer(KGCompletionLayerBase):
    def __init__(self,
                 input_dropout=0.0,
                 loss_name='BCELoss'):
        super(DistMultLayer, self).__init__()
        self.inp_drop = nn.Dropout(input_dropout)
        self.loss_name = loss_name

    def forward(self,
                e1_emb,
                rel_emb,
                all_node_emb,
                multi_label=None):
        # dropout
        e1_emb = self.inp_drop(e1_emb)
        rel_emb = self.inp_drop(rel_emb)

        # score every entity as a candidate object: [B, d] x [d, N] -> [B, N]
        logits = torch.mm(e1_emb * rel_emb,
                          all_node_emb.weight.transpose(1, 0))

        if self.loss_name in ['SoftMarginLoss']:
            # target labels are numbers selecting from -1 and 1.
            pred = torch.tanh(logits)
        else:
            # target labels are numbers selecting from 0 and 1.
            pred = torch.sigmoid(logits)

        if multi_label is not None:
            # gather the scores of positive (label 1) and negative (label 0) entries
            idxs_pos = torch.nonzero(multi_label == 1.)
            pred_pos = pred[idxs_pos[:, 0], idxs_pos[:, 1]]

            idxs_neg = torch.nonzero(multi_label == 0.)
            pred_neg = pred[idxs_neg[:, 0], idxs_neg[:, 1]]
            return pred, pred_pos, pred_neg
        else:
            return pred
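The last branch of DistMultLayer.forward splits the [B, N] prediction matrix into positive and negative scores according to multi_label. The sketch below reproduces that indexing with plain Python lists instead of tensors, on made-up numbers; `split_scores` is a hypothetical helper for illustration, not a library function.

```python
def split_scores(pred, multi_label):
    """Split a [B, N] score matrix into positive and negative scores.

    multi_label[b][n] is 1.0 when entity n is a correct object for query b
    and 0.0 otherwise, mirroring the 0/1 targets used by the layer.
    """
    pred_pos, pred_neg = [], []
    for row_pred, row_label in zip(pred, multi_label):
        for score, label in zip(row_pred, row_label):
            if label == 1.0:
                pred_pos.append(score)
            elif label == 0.0:
                pred_neg.append(score)
    return pred_pos, pred_neg


# B = 2 queries, N = 3 candidate entities (values are illustrative only).
pred = [[0.9, 0.2, 0.1],
        [0.3, 0.8, 0.7]]
multi_label = [[1.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]]

pred_pos, pred_neg = split_scores(pred, multi_label)
print(pred_pos)  # [0.9, 0.8, 0.7]
print(pred_neg)  # [0.2, 0.1, 0.3]
```

With 0/1 labels every entry lands in exactly one of the two lists, which is the L_p + L_n == B * N relation noted in the comments of DistMult.forward.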
ComplEx was proposed in the paper Complex Embeddings for Simple Link Prediction.
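For ComplEx, the scoring function is \(f(s, r, o) = Re(e_s^T M_r \bar{e}_o)\), where the embeddings \(e_s, e_o \in \mathbb{C}^{d}\) are complex-valued, \(M_r \in \mathbb{C}^{d \times d}\) is a diagonal matrix, \(\bar{e}_o\) is the complex conjugate of \(e_o\), and \(Re()\) denotes the real part.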
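As a worked illustration, the sketch below evaluates the standard ComplEx score from the paper, the real part of \(\sum_k e_{s,k} \, w_{r,k} \, \bar{e}_{o,k}\), using Python's built-in complex numbers. The embeddings are made-up values and `complex_score` is an illustrative name, not the library's implementation.

```python
def complex_score(e_s, w_r, e_o):
    """Re(sum_k e_s[k] * w_r[k] * conj(e_o[k])) for complex-valued embeddings."""
    total = sum(s * w * o.conjugate() for s, w, o in zip(e_s, w_r, e_o))
    return total.real


e_s = [1 + 1j, 2 + 0j]
w_r = [0 + 1j, 1 + 0j]  # diagonal entries of the relation matrix
e_o = [1 - 1j, 0 + 1j]

forward = complex_score(e_s, w_r, e_o)   # score of (s, r, o)
backward = complex_score(e_o, w_r, e_s)  # score of (o, r, s)
print(forward, backward)  # -2.0 2.0
```

Unlike DistMult, the two directions receive different scores here, because the conjugate on the object embedding breaks the symmetry whenever the relation embedding has a non-zero imaginary part.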
How to Combine KGC Decoder with GNN Encoder
The code below provides an end-to-end KGC model using GCN as the encoder and DistMult as the decoder:
import torch
from torch.nn.init import xavier_normal_

from graph4nlp.pytorch.modules.graph_embedding.gcn import GCN
from graph4nlp.pytorch.modules.prediction.classification.kg_completion import DistMult


class GCNDistMult(torch.nn.Module):
    def __init__(self, args, num_entities, num_relations, num_layers=2):
        super(GCNDistMult, self).__init__()
        self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=0)
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.num_layers = num_layers
        # GNN encoder: input, hidden and output sizes all equal embedding_dim
        self.gnn = GCN(self.num_layers, args.embedding_dim, args.embedding_dim, args.embedding_dim,
                       args.direction_option, feat_drop=args.input_drop)
        self.direction_option = args.direction_option
        # KGC decoder
        self.distmult = DistMult(args.input_drop, loss_name='BCELoss')
        self.loss = torch.nn.BCELoss()

    def init(self):
        # Xavier initialization of the entity and relation embedding tables
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel, kg_graph=None):
        # use the embeddings of all entities as initial node features
        X = torch.arange(self.num_entities, device=e1.device)
        kg_graph.node_features['node_feat'] = self.emb_e(X)
        # encode the graph with the GNN
        kg_graph = self.gnn(kg_graph)

        # look up the encoded subject entities and the relation embeddings
        e1_embedded = kg_graph.node_features['node_feat'][e1]
        rel_embedded = self.emb_rel(rel)
        e1_embedded = e1_embedded.squeeze()
        rel_embedded = rel_embedded.squeeze()

        # decode: score every entity as a candidate object
        kg_graph = self.distmult(kg_graph, e1_embedded, rel_embedded, self.emb_e)
        logits = kg_graph.graph_attributes['logits']

        return logits
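The model returns a [B, N] matrix of scores and is set up with BCELoss, which expects a target of the same shape. The text above does not show how that target is built. One common approach, assumed here rather than taken from the library, is a multi-hot row per (subject, relation) query that marks every known object. The helper `build_multi_hot` and the integer ids are hypothetical.

```python
from collections import defaultdict


def build_multi_hot(triples, num_entities):
    """Group triples by (subject, relation) and mark every known object with 1.0."""
    objects = defaultdict(set)
    for s, r, o in triples:
        objects[(s, r)].add(o)

    queries = sorted(objects)
    targets = [
        [1.0 if n in objects[q] else 0.0 for n in range(num_entities)]
        for q in queries
    ]
    return queries, targets


# Integer ids for entities and relations; the values are illustrative only.
triples = [(0, 0, 1), (0, 0, 2), (3, 1, 0)]
queries, targets = build_multi_hot(triples, num_entities=4)
print(queries)  # [(0, 0), (3, 1)]
print(targets)  # [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
```

Each row of `targets` lines up with one row of the [B, N] score matrix, so converting it to a float tensor would give a target that BCELoss can compare against the scores.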