Chapter 5.2 Standard Tree Decoder

The outputs of many NLP applications (e.g., semantic parsing, code generation, and math word problem solving) contain structural information. For example, the output of a math word problem is a mathematical equation, which can be naturally expressed as a tree. Tree decoders are widely adopted to model these kinds of outputs. Tree decoders can be divided into two main categories: DFS (depth-first search) based tree decoders and BFS (breadth-first search) based tree decoders. We mainly implement the BFS-based tree decoder here. Specifically, we give a simple example of how StdTreeDecoder is initialized:

import torch
import torch.nn as nn
from graph4nlp.pytorch.modules.config import get_basic_args
from graph4nlp.pytorch.modules.utils.config_utils import update_values, get_yaml_config

from graph4nlp.pytorch.datasets.jobs import JobsDatasetForTree
from graph4nlp.pytorch.modules.utils.tree_utils import Vocab
from graph4nlp.pytorch.modules.prediction.generation.TreeBasedDecoder import StdTreeDecoder

# assume you have obtained out_vocab (the output Vocab model), batch_graph, and tgt_tree_batch
dec_word_emb = nn.Embedding(out_vocab.embeddings.shape[0],
                             out_vocab.embeddings.shape[1],
                             padding_idx=0,
                             _weight=torch.from_numpy(out_vocab.embeddings).float())

decoder = StdTreeDecoder(attn_type="uniform", embeddings=dec_word_emb, enc_hidden_size=300,
                         dec_emb_size=out_vocab.embedding_dims, dec_hidden_size=300,
                         output_size=out_vocab.vocab_size,
                         criterion=nn.NLLLoss(ignore_index=0, reduction='none'),
                         teacher_force_ratio=1.0, use_copy=False, max_dec_seq_length=50,
                         max_dec_tree_depth=5, tgt_vocab=out_vocab)

predicted = decoder(batch_graph=batch_graph, tgt_tree_batch=tgt_tree_batch)
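
To make the tree-structured target concrete, here is a minimal sketch of how an equation such as x = 2 * ( 3 + 4 ) can be represented, assuming the Tree class from graph4nlp.pytorch.modules.utils.tree_utils exposes an add_child method; for readability the sketch stores raw token strings, whereas real target trees store vocabulary indices.

from graph4nlp.pytorch.modules.utils.tree_utils import Tree

# A node's children are either plain tokens or sub-trees; the bracketed
# sub-expression "( 3 + 4 )" becomes a sub-tree (non-terminal) node.
sub = Tree()
for tok in ['3', '+', '4']:
    sub.add_child(tok)        # token children (vocab indices in real data)

root = Tree()
for tok in ['x', '=', '2', '*']:
    root.add_child(tok)
root.add_child(sub)           # the sub-tree hangs off the root sequence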

Implementation details

What is the tree decoding process?

In the BFS-based tree decoding approach, we represent all sub-trees as non-terminal nodes. We then divide the whole tree structure into multiple "sequences" from top to bottom according to the non-terminal nodes, as shown in the code below:

# Tree and to_cuda are assumed to be imported from the library's tree utilities
# (e.g., graph4nlp.pytorch.modules.utils.tree_utils).
def get_dec_batch(dec_tree_batch, batch_size, device, form_manager):
    queue_tree = {}
    for i in range(1, batch_size+1):
        queue_tree[i] = []
        queue_tree[i].append({"tree": dec_tree_batch[i-1], "parent": 0, "child_index": 1})

    cur_index, max_index = 1, 1
    dec_batch = {}
    # max_index: the maximum number of sequences (one per non-terminal node)
    # to decode for this batch
    while (cur_index <= max_index):
        max_w_len = -1
        batch_w_list = []
        for i in range(1, batch_size+1):
            w_list = []
            if (cur_index <= len(queue_tree[i])):
                t = queue_tree[i][cur_index - 1]["tree"]
                for ic in range(t.num_children):
                    if isinstance(t.children[ic], Tree):
                        # a sub-tree child is represented by the non-terminal
                        # token (index 4 here) and queued for its own
                        # sequence decoding step
                        w_list.append(4)
                        queue_tree[i].append({"tree": t.children[ic], "parent": cur_index, "child_index": ic + 1})
                    else:
                        w_list.append(t.children[ic])
                if len(queue_tree[i]) > max_index:
                    max_index = len(queue_tree[i])
            if len(w_list) > max_w_len:
                max_w_len = len(w_list)
            batch_w_list.append(w_list)
        # max_w_len + 2 leaves room for the start and end tokens
        dec_batch[cur_index] = torch.zeros(
            (batch_size, max_w_len + 2), dtype=torch.long)
        for i in range(batch_size):
            w_list = batch_w_list[i]
            if len(w_list) > 0:
                for j in range(len(w_list)):
                    dec_batch[cur_index][i][j+1] = w_list[j]
                if cur_index == 1:
                    # the root sequence begins with the start token (index 1)
                    dec_batch[cur_index][i][0] = 1
                else:
                    # every sub-tree sequence begins with '('
                    dec_batch[cur_index][i][0] = form_manager.get_symbol_idx('(')
                # index 2 is the end-of-sequence token
                dec_batch[cur_index][i][len(w_list) + 1] = 2
        dec_batch[cur_index] = to_cuda(dec_batch[cur_index], device)
        cur_index += 1
    return dec_batch, queue_tree, max_index
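
As a rough illustration (token strings shown instead of vocabulary indices), the example tree for x = 2 * ( 3 + 4 ) sketched above would be split into the following sequences: <NT> denotes the non-terminal placeholder (index 4), the root sequence is framed by the start and end tokens, and every later sequence is prefixed with '(' instead of the start token.

dec_batch[1]:  <s>  x  =  2  *  <NT>  </s>
dec_batch[2]:  (    3  +  4  </s>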

We then use sequence decoding to generate these sequences in order. For each sequence decoding step, we feed the embeddings of the parent node and sibling nodes as auxiliary input (see the sketch after the loop snippet below).

The figure below shows an example of the tree decoding process, where nodes such as S1 and S2 stand for sub-tree nodes. Once a sub-tree node is generated, the decoder starts a new branch for a new descendant decoding process. The numbers indicate the order of the different branching decoding processes.

[Figure: the tree decoding process (../../_images/tree_decoding.jpg)]

Below is a rough code snippet showing how tree decoding proceeds.

cur_index = 1
while (cur_index <= max_index):
    if cur_index > max_dec_tree_depth:
        break
    ...
    # get parent and sibling embeddings.
    # do sequence decoding.
    ...

    cur_index = cur_index + 1

Here, max_index is the number of non-terminal nodes (i.e., the number of sequences to decode) and max_dec_tree_depth is the maximum number of non-terminal nodes allowed.
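
The body of the loop can be pictured with the following minimal sketch. It is not the library implementation: branch_states, parent_of, sibling_tokens, and rnn_cell are hypothetical names standing in for the decoder's internal bookkeeping.

# hypothetical bookkeeping (NOT the library API):
# branch_states[k]  - decoder state saved when sequence k's parent emitted the
#                     non-terminal token for it
# parent_of[k]      - index of sequence k's parent sequence
# sibling_tokens[k] - ids of tokens already generated to the left of sequence k
# rnn_cell          - a recurrent cell, e.g. an nn.GRUCell
parent_state = branch_states[parent_of[cur_index]]
sibling_emb = dec_word_emb(sibling_tokens[cur_index]).mean(dim=1)

rnn_state = parent_state                # the parent state initializes the RNN
for t in range(dec_batch[cur_index].size(1) - 1):
    tok_emb = dec_word_emb(dec_batch[cur_index][:, t])
    # the sibling embedding is fed as auxiliary input at every step
    rnn_state = rnn_cell(torch.cat([tok_emb, sibling_emb], dim=-1), rnn_state)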

Copy and separate attention

StdTreeDecoder employs the copy and separate attention mechanisms to enhance the overall Graph2Tree model. We show here how copy is used in StdTreeDecoder. Both StdTreeDecoder and StdRNNDecoder use the separate attention mechanism, which has already been explained in Chapter 5.1 Standard RNN Decoder, so we will not go into detail here.

The copy mechanism helps the model copy words directly from the input. The final output distribution is computed as \(p(w) = p_{gen}\, p_{softmax}(w) + (1 - p_{gen})\, p_{copy}(w)\), where \(p_{gen}\) is the probability of generating from the decoder vocabulary, \(p_{softmax}\) is the standard softmax distribution over that vocabulary, and \(p_{copy}\) is the attention distribution over the input tokens.
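
As a toy numerical check of this mixture (all sizes and values below are made up for illustration):

import torch

p_gen = torch.tensor([[0.7]])                     # probability of generating
p_softmax = torch.tensor([[0.5, 0.3, 0.2, 0.0]])  # softmax over the extended vocab (OOV slots padded with 0)
p_copy = torch.tensor([[0.0, 0.1, 0.0, 0.9]])     # attention mass on source token positions

p_w = p_gen * p_softmax + (1 - p_gen) * p_copy
print(p_w)  # tensor([[0.3500, 0.2400, 0.1400, 0.2700]]) -- still sums to 1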

Our implementation follows the pointer network. Technically, for a given mini-batch of graph data, we first extend the original vocabulary to a full vocabulary containing all words (including out-of-vocabulary (OOV) words) in the mini-batch:

import copy

# A hedged sketch: the wrapper name build_oov_dict is illustrative, but the
# body follows the logic shown in the library for extending the vocabulary.
def build_oov_dict(batch_graph, src_vocab):
    oov_dict = copy.deepcopy(src_vocab)
    token_matrix = []
    for n in batch_graph.node_attributes:
        node_token = n['token']
        # Pick out all out-of-vocabulary (OOV) words in the mini-batch graph data.
        if (n.get('type') is None or n.get('type') == 0) and oov_dict.get_symbol_idx(
                node_token) == oov_dict.get_symbol_idx(oov_dict.unk_token):
            # Add them to the OOV vocab model.
            oov_dict.add_symbol(node_token)
        # token_matrix collects the extended-vocabulary ids of the graph nodes.
        token_matrix.append(oov_dict.get_symbol_idx(node_token))
    return oov_dict
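
For example, an OOV source word then maps to a valid index in the extended part of the vocabulary. A hedged usage sketch ('quicksort' stands in for any OOV token that appears in the input graph, and build_oov_dict is the illustrative wrapper defined above):

oov_dict = build_oov_dict(batch_graph, src_vocab)
idx = oov_dict.get_symbol_idx('quicksort')   # an OOV word from the input graph
assert idx >= src_vocab.get_vocab_size()     # it lives beyond the base vocabulary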

After that, the decoder learns the conditional probability of an output sequence whose elements are either vocabulary words or discrete tokens corresponding to positions in the input sequence. The following code snippet shows how this works.

if self.use_copy:
    # features used to predict p_gen: decoder input embedding,
    # decoder hidden state, and the attention context vector
    pgen_collect = [dec_emb, hidden, attn_ptr]

    # the probability of copying a word from the source
    prob_ptr = torch.sigmoid(self.ptr(torch.cat(pgen_collect, -1)))

    # the probability of generating a word over the standard softmax on vocabulary model.
    prob_gen = 1 - prob_ptr
    gen_output = torch.softmax(decoder_output, dim=-1)

    ret = prob_gen * gen_output
    # pad the generation distribution with zeros so it matches the
    # extended (OOV) vocabulary size
    need_pad_length = oov_dict.get_vocab_size() - self.vocab.get_vocab_size()
    output = torch.cat((ret, ret.new_zeros((batch_size, need_pad_length))), dim=1)

    # scatter the copy probabilities onto the extended vocabulary
    # at the positions given by the source token ids (src_seq)
    ptr_output = dec_attn_scores
    output.scatter_add_(1, src_seq, prob_ptr * ptr_output)
    decoder_output = output
else:
    decoder_output = torch.softmax(decoder_output, dim=-1)

The returned decoder_output is a distribution over the extended dictionary oov_dict when copy is adopted. Users can set use_copy to True to enable this feature, and the OOV vocabulary must be passed in when using it.
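
Since the distribution covers the extended vocabulary, predicted ids at or beyond the base vocabulary size refer to words copied from the input. A hedged post-processing sketch, assuming the vocab model exposes get_idx_symbol as the inverse of the get_symbol_idx method used above:

pred_ids = decoder_output.argmax(dim=-1)    # greedy pick per decoding step
# ids >= the base vocab size resolve to OOV words copied from the input
tokens = [oov_dict.get_idx_symbol(int(i)) for i in pred_ids.view(-1)]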