Embedding Construction

The embedding construction module aims to learn the initial node/edge embeddings for the input graph before they are consumed by the subsequent GNN model.

EmbeddingConstruction

The EmbeddingConstruction class supports various strategies for initializing both single-token (i.e., containing a single token) and multi-token (i.e., containing multiple tokens) items (i.e., nodes/edges). As shown in the code snippet below, for both single-token and multi-token items, the supported embedding strategies include w2v, w2v_bilstm, w2v_bigru, bert, bert_bilstm, bert_bigru, w2v_bert, w2v_bert_bilstm and w2v_bert_bigru. A hypothetical usage sketch follows the snippet.

assert emb_strategy in ('w2v', 'w2v_bilstm', 'w2v_bigru', 'bert', 'bert_bilstm', 'bert_bigru',
    'w2v_bert', 'w2v_bert_bilstm', 'w2v_bert_bigru')

word_emb_type = set()
if single_token_item:
    node_edge_emb_strategy = None
    if 'w2v' in emb_strategy:
        word_emb_type.add('w2v')

    if 'bert' in emb_strategy:
        word_emb_type.add('seq_bert')

    if 'bilstm' in emb_strategy:
        seq_info_encode_strategy = 'bilstm'
    elif 'bigru' in emb_strategy:
        seq_info_encode_strategy = 'bigru'
    else:
        seq_info_encode_strategy = 'none'
else:
    seq_info_encode_strategy = 'none'
    if 'w2v' in emb_strategy:
        word_emb_type.add('w2v')

    if 'bert' in emb_strategy:
        word_emb_type.add('node_edge_bert')

    if 'bilstm' in emb_strategy:
        node_edge_emb_strategy = 'bilstm'
    elif 'bigru' in emb_strategy:
        node_edge_emb_strategy = 'bigru'
    else:
        node_edge_emb_strategy = 'mean'
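
For reference, a hypothetical instantiation might look like the sketch below; the constructor argument names are illustrative assumptions based on the snippet above and may differ from the actual EmbeddingConstruction signature.

# Hypothetical usage sketch (argument names are assumptions, not the verified API)
emb_constructor = EmbeddingConstruction(
    word_vocab=vocab.in_word_vocab,   # assumed vocabulary object
    single_token_item=True,           # each node holds exactly one token
    emb_strategy='w2v_bilstm',        # word2vec lookup + BiLSTM over the node sequence
    hidden_size=300,
    word_dropout=0.4,
    rnn_dropout=0.1,
)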

For instance, for a single-token item, the w2v_bilstm strategy means we first use word2vec embeddings to initialize each item, and then apply a BiLSTM encoder to encode the whole graph (assuming the node order preserves the sequential order of the raw text). Compared to w2v_bilstm, the w2v_bert_bilstm strategy additionally applies a BERT encoder to the whole graph (i.e., the sequential text); the concatenation of the BERT embedding and the word2vec embedding, instead of the word2vec embedding alone, is then fed into the BiLSTM encoder. A self-contained toy sketch of this pipeline follows the code snippet below.

# single-token item graph
feat = []
token_ids = batch_gd.batch_node_features["token_id"]
if 'w2v' in self.word_emb_layers:
    word_feat = self.word_emb_layers['w2v'](token_ids).squeeze(-2)
    word_feat = dropout_fn(word_feat, self.word_dropout, shared_axes=[-2], training=self.training)
    feat.append(word_feat)

new_feat = feat
# optionally run BERT over the whole graph, treating the node tokens as a sequence
if 'seq_bert' in self.word_emb_layers:
    gd_list = from_batch(batch_gd)
    raw_tokens = [[gd.node_attributes[i]['token'] for i in range(gd.get_node_num())] for gd in gd_list]
    bert_feat = self.word_emb_layers['seq_bert'](raw_tokens)
    bert_feat = dropout_fn(bert_feat, self.bert_dropout, shared_axes=[-2], training=self.training)
    new_feat.append(bert_feat)

new_feat = torch.cat(new_feat, -1)
# encode the node sequence of each graph with the BiLSTM/BiGRU encoder, if one was configured
if self.seq_info_encode_layer is None:
    batch_gd.batch_node_features["node_feat"] = new_feat
else:
    rnn_state = self.seq_info_encode_layer(new_feat, torch.LongTensor(batch_gd._batch_num_nodes).to(batch_gd.device))
    if isinstance(rnn_state, (tuple, list)):
        rnn_state = rnn_state[0]

    batch_gd.batch_node_features["node_feat"] = rnn_state
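
The following self-contained toy (using plain PyTorch modules, not the library's own classes) illustrates the single-token pipeline above: a word2vec-style lookup per node, a stand-in tensor for the seq_bert features, and a BiLSTM over the node sequence of one graph.

# Toy sketch of the single-token pipeline (illustrative only)
import torch
import torch.nn as nn

vocab_size, w2v_dim, bert_dim, hidden = 100, 32, 16, 24
num_nodes = 7                                               # nodes of one graph, in raw-text order

word_emb = nn.Embedding(vocab_size, w2v_dim)
bilstm = nn.LSTM(w2v_dim + bert_dim, hidden // 2, bidirectional=True, batch_first=True)

token_ids = torch.randint(0, vocab_size, (1, num_nodes))    # one graph in the batch
w2v_feat = word_emb(token_ids)                              # (1, num_nodes, w2v_dim)
bert_feat = torch.randn(1, num_nodes, bert_dim)             # stand-in for the seq_bert output
node_feat, _ = bilstm(torch.cat([w2v_feat, bert_feat], dim=-1))
print(node_feat.shape)                                      # torch.Size([1, 7, 24])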

For a multi-token item, the w2v_bilstm strategy means we first use word2vec embeddings to initialize each token in the item, and then apply a BiLSTM encoder to encode each item's text. Compared to w2v_bilstm, the w2v_bert_bilstm strategy additionally applies a BERT encoder to each item's text; the concatenation of the BERT embedding and the word2vec embedding, instead of the word2vec embedding alone, is then fed into the BiLSTM encoder. A toy sketch of this per-item encoding follows the code snippet below.

# multi-token item graph
feat = []
token_ids = batch_gd.node_features["token_id"]
if 'w2v' in self.word_emb_layers:
    word_feat = self.word_emb_layers['w2v'](token_ids)
    word_feat = dropout_fn(word_feat, self.word_dropout, shared_axes=[-2], training=self.training)
    feat.append(word_feat)

if 'node_edge_bert' in self.word_emb_layers:
    # run BERT on each item's raw token sequence
    input_data = [batch_gd.node_attributes[i]['token'].strip().split(' ') for i in range(batch_gd.get_node_num())]
    node_edge_bert_feat = self.word_emb_layers['node_edge_bert'](input_data)
    node_edge_bert_feat = dropout_fn(node_edge_bert_feat, self.bert_dropout, shared_axes=[-2], training=self.training)
    feat.append(node_edge_bert_feat)

if len(feat) > 0:
    feat = torch.cat(feat, dim=-1)
    node_token_lens = torch.clamp((token_ids != Vocab.PAD).sum(-1), min=1)
    # encode each item's token sequence with the configured BiLSTM/BiGRU or mean encoder
    feat = self.node_edge_emb_layer(feat, node_token_lens)
    if isinstance(feat, (tuple, list)):
        feat = feat[-1]

    feat = batch_gd.split_features(feat)

batch_gd.batch_node_features["node_feat"] = feat
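
The following toy (again plain PyTorch, not the library's classes) illustrates the multi-token pipeline above: embed every token of each node, run a BiLSTM over each node's token sequence, and use the concatenated final forward/backward states as the node embedding.

# Toy sketch of the multi-token pipeline (illustrative only)
import torch
import torch.nn as nn

vocab_size, w2v_dim, hidden = 100, 32, 24
num_nodes, max_tokens = 5, 4                    # each node holds up to 4 tokens

word_emb = nn.Embedding(vocab_size, w2v_dim, padding_idx=0)
bilstm = nn.LSTM(w2v_dim, hidden // 2, bidirectional=True, batch_first=True)

token_ids = torch.randint(1, vocab_size, (num_nodes, max_tokens))
token_feat = word_emb(token_ids)                # (num_nodes, max_tokens, w2v_dim)
_, (h_n, _) = bilstm(token_feat)                # h_n: (2, num_nodes, hidden // 2)
node_feat = torch.cat([h_n[0], h_n[1]], dim=-1) # (num_nodes, hidden)
print(node_feat.shape)                          # torch.Size([5, 24])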

Various embedding modules

Various embedding modules are provided in the library to support embedding construction. For instance, the WordEmbedding class converts an input word index sequence into a word embedding matrix. The MeanEmbedding class simply averages the word embeddings. The RNNEmbedding class applies an RNN (e.g., GRU, LSTM, BiGRU, BiLSTM) to a sequence of word embeddings.
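
As a concrete illustration of the mean-pooling idea, the sketch below averages token embeddings per item while ignoring padding positions; it reflects the assumed behavior and is not the MeanEmbedding implementation itself.

# Masked mean pooling sketch (assumed behavior, not the library's exact code)
import torch

def mean_pool(token_feat, token_ids, pad_idx=0):
    # token_feat: (num_items, max_tokens, dim); token_ids: (num_items, max_tokens)
    mask = (token_ids != pad_idx).float().unsqueeze(-1)     # 1 for real tokens, 0 for padding
    lengths = mask.sum(dim=1).clamp(min=1)                  # avoid division by zero
    return (token_feat * mask).sum(dim=1) / lengths         # (num_items, dim)

feat = mean_pool(torch.randn(5, 4, 32), torch.randint(0, 100, (5, 4)))
print(feat.shape)   # torch.Size([5, 32])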

We will introduce BertEmbedding in more detail next. The BertEmbedding class calls the Hugging Face Transformers APIs to compute BERT embeddings for the input text. Transformer-based models like BERT have a limit on the maximum sequence length. The BertEmbedding class can automatically cut a long input sequence into multiple small chunks, call the Transformers APIs on each chunk, and then automatically merge the chunk embeddings to obtain the embedding for the original long sequence. Below is a code snippet showing the BertEmbedding class API. Users can specify max_seq_len and doc_stride to indicate the maximum sequence length and the stride (i.e., similar to the stride idea in ConvNets) used when cutting long text into small chunks. In addition, instead of returning the last encoder layer as the output state, it returns a weighted average of all the encoder layer states, as we find this works better in practice. Note that the weights are learnable parameters; a sketch of this layer-weighting idea follows the snippet.

class BertEmbedding(nn.Module):
    def __init__(self, name='bert-base-uncased', max_seq_len=500, doc_stride=250, fix_emb=True, lower_case=True):
        super(BertEmbedding, self).__init__()
        self.bert_max_seq_len = max_seq_len
        self.bert_doc_stride = doc_stride
        self.fix_emb = fix_emb

        from transformers import BertModel
        from transformers import BertTokenizer
        print('[ Using pretrained BERT embeddings ]')
        self.bert_tokenizer = BertTokenizer.from_pretrained(name, do_lower_case=lower_case)
        self.bert_model = BertModel.from_pretrained(name)
        if fix_emb:
            print('[ Fix BERT layers ]')
            self.bert_model.eval()
            for param in self.bert_model.parameters():
                param.requires_grad = False
        else:
            print('[ Finetune BERT layers ]')
            self.bert_model.train()

        # compute weighted average over BERT layers
        self.logits_bert_layers = nn.Parameter(nn.init.xavier_uniform_(torch.Tensor(1, self.bert_model.config.num_hidden_layers)))
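
The layer-weighting idea can be sketched as follows. This is an illustrative function, not the class's actual forward method, and it assumes the per-layer hidden states were obtained from the Transformers model with output_hidden_states=True.

# Sketch of weighting the BERT encoder layers (illustrative only)
import torch
import torch.nn.functional as F

def weighted_layer_average(hidden_states, layer_logits):
    # hidden_states: tuple of (num_hidden_layers + 1) tensors of shape (batch, seq_len, hidden);
    #                index 0 is the embedding-layer output, the rest are the encoder layers
    # layer_logits:  learnable (1, num_hidden_layers) parameter, as in self.logits_bert_layers above
    encoder_layers = torch.stack(hidden_states[1:], dim=0)       # (num_layers, batch, seq, hidden)
    weights = F.softmax(layer_logits, dim=-1).view(-1, 1, 1, 1)  # normalize logits to a distribution
    return (weights * encoder_layers).sum(dim=0)                 # (batch, seq, hidden)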