Text Classification Tutorial

Introduction

In this tutorial, we will use the Graph4NLP library to build a GNN-based text classification model. The model consists of:

  • a graph construction module (e.g., dependency-based static graph)

  • a graph embedding module (e.g., Bi-Fuse GraphSAGE)

  • a prediction module (e.g., graph pooling + MLP classifier)

We will use the built-in module APIs to build the model and evaluate it on the TREC dataset. The full example can be downloaded from the text classification notebook.

Environment setup

Please follow the instructions here to set up the environment.

Build the text classifier

Let’s first build the GNN-based text classifier, which consists of three major components: the graph construction module, the graph embedding module, and the graph prediction module.

For the graph construction module, the Graph4NLP library provides built-in APIs that support both static graph construction methods (e.g., dependency graph, constituency graph, IE graph) and dynamic graph construction methods (e.g., node embedding based graph, node embedding based refined graph). When calling the graph construction API, users should also specify the embedding style (e.g., word2vec, BiLSTM, BERT) used to initialize the node/edge embeddings. Both single-token and multi-token node/edge graphs are supported.
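
For example, the embedding style used by the classifier in this tutorial is specified as a plain Python dictionary; the snippet below simply previews the configuration that appears in the classifier code below (the comments are our own reading of the fields):

embedding_style = {
    'single_token_item': True,               # one token per node (set False for IE graphs)
    'emb_strategy': 'w2v_bilstm',            # word2vec embeddings refined by a BiLSTM
    'num_rnn_layers': 1,
    'bert_model_name': 'bert-base-uncased',  # used when a BERT-based strategy is selected
    'bert_lower_case': True,
}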

For the graph embedding module, the Graph4NLP library provides built-in APIs that support both unidirectional and bidirectional versions of common GNNs such as GCN, GraphSAGE, GAT and GGNN.
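
For example, a Bi-Fuse GraphSAGE encoder (the graph embedding module used in this tutorial) can be created as follows. The call mirrors the classifier code below; the layer count, sizes, aggregator and dropout are placeholder values, not the tutorial’s official settings:

gnn = GraphSAGE(2,                            # number of GNN layers (placeholder)
                300, 300, 300,                # input / hidden / output sizes (placeholders)
                'lstm',                       # neighborhood aggregator type (placeholder)
                direction_option='bi_fuse',   # 'undirected', 'bi_sep' or 'bi_fuse'
                feat_drop=0.4,
                activation=F.relu)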

For the graph prediction module, the Graph4NLP library provides a high-level graph classification prediction module which consists of a graph pooling component (e.g., average pooling, max pooling) and a multilayer perceptron (MLP).
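
For example, the prediction module pools the node embeddings into a graph embedding and feeds it to an MLP classifier. The call below mirrors the classifier code that follows; the sizes are placeholders (TREC has 6 coarse classes):

clf = FeedForwardNN(300,                          # graph embedding size (placeholder)
                    6,                            # number of classes (6 coarse classes in TREC)
                    [300],                        # hidden sizes of the MLP (placeholder)
                    graph_pool_type='avg_pool')   # or 'max_pool'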

class TextClassifier(nn.Module):
    def __init__(self, vocab, label_model, config):
        super(TextClassifier, self).__init__()
        self.config = config
        self.vocab = vocab
        self.label_model = label_model

        # Specify embedding style to initialize node/edge embeddings
        embedding_style = {'single_token_item': True if config['graph_type'] != 'ie' else False,
                            'emb_strategy': config.get('emb_strategy', 'w2v_bilstm'),
                            'num_rnn_layers': 1,
                            'bert_model_name': config.get('bert_model_name', 'bert-base-uncased'),
                            'bert_lower_case': True
                           }

        assert not (config['graph_type'] in ('node_emb', 'node_emb_refined') and config['gnn'] == 'gat'), \
                                'dynamic graph construction does not support GAT'

        use_edge_weight = False


        # Set up graph construction module
        if config['graph_type'] == 'dependency':
            self.graph_topology = DependencyBasedGraphConstruction(embedding_style=embedding_style,
                                   vocab=vocab.in_word_vocab, hidden_size=config['num_hidden'],
                                   word_dropout=config['word_dropout'], rnn_dropout=config['rnn_dropout'],
                                   fix_word_emb=not config['no_fix_word_emb'], fix_bert_emb=not config.get('no_fix_bert_emb', False))
        elif config['graph_type'] == 'constituency':
            self.graph_topology = ConstituencyBasedGraphConstruction(embedding_style=embedding_style,
                                   vocab=vocab.in_word_vocab, hidden_size=config['num_hidden'],
                                   word_dropout=config['word_dropout'], rnn_dropout=config['rnn_dropout'],
                                   fix_word_emb=not config['no_fix_word_emb'], fix_bert_emb=not config.get('no_fix_bert_emb', False))
        elif config['graph_type'] == 'ie':
            self.graph_topology = IEBasedGraphConstruction(embedding_style=embedding_style,
                                   vocab=vocab.in_word_vocab, hidden_size=config['num_hidden'],
                                   word_dropout=config['word_dropout'], rnn_dropout=config['rnn_dropout'],
                                   fix_word_emb=not config['no_fix_word_emb'], fix_bert_emb=not config.get('no_fix_bert_emb', False))
        elif config['graph_type'] == 'node_emb':
            self.graph_topology = NodeEmbeddingBasedGraphConstruction(vocab.in_word_vocab,
                                   embedding_style, sim_metric_type=config['gl_metric_type'],
                                   num_heads=config['gl_num_heads'], top_k_neigh=config['gl_top_k'],
                                   epsilon_neigh=config['gl_epsilon'], smoothness_ratio=config['gl_smoothness_ratio'],
                                   connectivity_ratio=config['gl_connectivity_ratio'], sparsity_ratio=config['gl_sparsity_ratio'],
                                   input_size=config['num_hidden'], hidden_size=config['gl_num_hidden'],
                                   fix_word_emb=not config['no_fix_word_emb'], fix_bert_emb=not config.get('no_fix_bert_emb', False),
                                   word_dropout=config['word_dropout'], rnn_dropout=config['rnn_dropout'])
            use_edge_weight = True
        elif config['graph_type'] == 'node_emb_refined':
            self.graph_topology = NodeEmbeddingBasedRefinedGraphConstruction(vocab.in_word_vocab,
                                    embedding_style, config['init_adj_alpha'],
                                    sim_metric_type=config['gl_metric_type'], num_heads=config['gl_num_heads'],
                                    top_k_neigh=config['gl_top_k'], epsilon_neigh=config['gl_epsilon'],
                                    smoothness_ratio=config['gl_smoothness_ratio'], connectivity_ratio=config['gl_connectivity_ratio'],
                                    sparsity_ratio=config['gl_sparsity_ratio'], input_size=config['num_hidden'],
                                    hidden_size=config['gl_num_hidden'], fix_word_emb=not config['no_fix_word_emb'],
                                    fix_bert_emb=not config.get('no_fix_bert_emb', False),
                                    word_dropout=config['word_dropout'], rnn_dropout=config['rnn_dropout'])
            use_edge_weight = True
        else:
            raise RuntimeError('Unknown graph_type: {}'.format(config['graph_type']))

        if 'w2v' in self.graph_topology.embedding_layer.word_emb_layers:
            self.word_emb = self.graph_topology.embedding_layer.word_emb_layers['w2v'].word_emb_layer
        else:
            self.word_emb = WordEmbedding(self.vocab.in_word_vocab.embeddings.shape[0],
                            self.vocab.in_word_vocab.embeddings.shape[1], pretrained_word_emb=self.vocab.in_word_vocab.embeddings,
                            fix_emb=not config['no_fix_word_emb'], device=config['device']).word_emb_layer


        # Set up graph embedding module
        if config['gnn'] == 'gat':
            heads = [config['gat_num_heads']] * (config['gnn_num_layers'] - 1) + [config['gat_num_out_heads']]
            self.gnn = GAT(config['gnn_num_layers'], config['num_hidden'], config['num_hidden'], config['num_hidden'],
                        heads, direction_option=config['gnn_direction_option'], feat_drop=config['gnn_dropout'],
                        attn_drop=config['gat_attn_dropout'], negative_slope=config['gat_negative_slope'],
                        residual=config['gat_residual'], activation=F.elu)
        elif config['gnn'] == 'graphsage':
            self.gnn = GraphSAGE(config['gnn_num_layers'], config['num_hidden'], config['num_hidden'], config['num_hidden'],
                        config['graphsage_aggreagte_type'], direction_option=config['gnn_direction_option'], feat_drop=config['gnn_dropout'],
                        bias=True, norm=None, activation=F.relu, use_edge_weight=use_edge_weight)
        elif config['gnn'] == 'ggnn':
            self.gnn = GGNN(config['gnn_num_layers'], config['num_hidden'], config['num_hidden'], config['num_hidden'],
                        feat_drop=config['gnn_dropout'], direction_option=config['gnn_direction_option'], bias=True, use_edge_weight=use_edge_weight)
        else:
            raise RuntimeError('Unknown gnn type: {}'.format(config['gnn']))


        # Set up graph prediction module
        self.clf = FeedForwardNN(2 * config['num_hidden'] if config['gnn_direction_option'] == 'bi_sep' else config['num_hidden'],
                    config['num_classes'], [config['num_hidden']], graph_pool_type=config['graph_pooling'],
                    dim=config['num_hidden'], use_linear_proj=config['max_pool_linear_proj'])

        self.loss = GeneralLoss('CrossEntropy')


    def forward(self, graph_list, tgt=None, require_loss=True):
        # build graph topology
        batch_gd = self.graph_topology(graph_list)

        # run GNN encoder
        self.gnn(batch_gd)

        # run graph classifier
        self.clf(batch_gd)
        logits = batch_gd.graph_attributes['logits']

        if require_loss:
            loss = self.loss(logits, tgt)
            return logits, loss
        else:
            return logits

    @classmethod
    def load_checkpoint(cls, model_path):
        return torch.load(model_path)

Build the model handler

Next, let’s build a model handler which takes care of setting up the dataloader, model, optimizer, evaluation metrics, and the train/val/test loops.

When setting up the dataloader, users need to call the dataset API, which preprocesses the data (e.g., constructs the graphs, builds the vocabulary, and tensorizes the data). Users also need to specify the graph construction type when calling the dataset API.

Users can build their own customized dataset APIs by inheriting our low-level dataset APIs, which cover various scenarios (e.g., Text2Label, Sequence2Labeling, Text2Text, Text2Tree, DoubleText2Text), as sketched below.
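
For instance, a customized Text2Label-style dataset typically only needs to declare where its raw and processed files live and how to fetch them. The sketch below follows the pattern used by the built-in TrecDataset, assuming the low-level base class is Text2LabelDataset; the class name and file names are illustrative, not part of the library:

class MyClassificationDataset(Text2LabelDataset):
    @property
    def raw_file_names(self):
        # raw text files expected under <root_dir>/raw/
        return {'train': 'train.txt', 'test': 'test.txt'}

    @property
    def processed_file_names(self):
        # cached artifacts written under <root_dir>/processed/<topology_subdir>/
        return {'vocab': 'vocab.pt', 'data': 'data.pt', 'label': 'label.pt'}

    def download(self):
        # assume the raw files are placed manually; implement fetching here otherwise
        raise NotImplementedError('Please place the raw data files manually.')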

class ModelHandler:
    def __init__(self, config):
        super(ModelHandler, self).__init__()
        self.config = config
        self.logger = Logger(self.config['out_dir'], config={k:v for k, v in self.config.items() if k != 'device'}, overwrite=True)
        self.logger.write(self.config['out_dir'])
        self._build_device()
        self._build_dataloader()
        self._build_model()
        self._build_optimizer()
        self._build_evaluation()

    def _build_device(self):
        if not self.config['no_cuda'] and torch.cuda.is_available():
            print('[ Using CUDA ]')
            self.config['device'] = torch.device('cuda' if self.config['gpu'] < 0 else 'cuda:%d' % self.config['gpu'])
            torch.cuda.manual_seed(self.config['seed'])
            torch.cuda.manual_seed_all(self.config['seed'])
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            self.config['device'] = torch.device('cpu')

    def _build_dataloader(self):
        dynamic_init_topology_builder = None
        if self.config['graph_type'] == 'dependency':
            topology_builder = DependencyBasedGraphConstruction
            graph_type = 'static'
            merge_strategy = 'tailhead'
        elif self.config['graph_type'] == 'constituency':
            topology_builder = ConstituencyBasedGraphConstruction
            graph_type = 'static'
            merge_strategy = 'tailhead'
        elif self.config['graph_type'] == 'ie':
            topology_builder = IEBasedGraphConstruction
            graph_type = 'static'
            merge_strategy = 'global'
        elif self.config['graph_type'] == 'node_emb':
            topology_builder = NodeEmbeddingBasedGraphConstruction
            graph_type = 'dynamic'
            merge_strategy = None
        elif self.config['graph_type'] == 'node_emb_refined':
            topology_builder = NodeEmbeddingBasedRefinedGraphConstruction
            graph_type = 'dynamic'
            merge_strategy = 'tailhead'

            if self.config['init_graph_type'] == 'line':
                dynamic_init_topology_builder = None
            elif self.config['init_graph_type'] == 'dependency':
                dynamic_init_topology_builder = DependencyBasedGraphConstruction
            elif self.config['init_graph_type'] == 'constituency':
                dynamic_init_topology_builder = ConstituencyBasedGraphConstruction
            elif self.config['init_graph_type'] == 'ie':
                merge_strategy = 'global'
                dynamic_init_topology_builder = IEBasedGraphConstruction
            else:
                raise RuntimeError('Define your own dynamic_init_topology_builder')
        else:
            raise RuntimeError('Unknown graph_type: {}'.format(self.config['graph_type']))

        topology_subdir = '{}_graph'.format(self.config['graph_type'])
        if self.config['graph_type'] == 'node_emb_refined':
            topology_subdir += '_{}'.format(self.config['init_graph_type'])


        # Call the TREC dataset API
        dataset = TrecDataset(root_dir=self.config.get('root_dir', self.config['root_data_dir']),
                              pretrained_word_emb_name=self.config.get('pretrained_word_emb_name', "840B"),
                              merge_strategy=merge_strategy, seed=self.config['seed'], thread_number=4,
                              port=9000, timeout=15000, word_emb_size=300, graph_type=graph_type,
                              topology_builder=topology_builder, topology_subdir=topology_subdir,
                              dynamic_graph_type=self.config['graph_type'] if \
                                  self.config['graph_type'] in ('node_emb', 'node_emb_refined') else None,
                              dynamic_init_topology_builder=dynamic_init_topology_builder,
                              dynamic_init_topology_aux_args={'dummy_param': 0})

        self.train_dataloader = DataLoader(dataset.train, batch_size=self.config['batch_size'], shuffle=True,
                                           num_workers=self.config['num_workers'], collate_fn=dataset.collate_fn)
        if not hasattr(dataset, 'val'):
            dataset.val = dataset.test
        self.val_dataloader = DataLoader(dataset.val, batch_size=self.config['batch_size'], shuffle=False,
                                          num_workers=self.config['num_workers'], collate_fn=dataset.collate_fn)
        self.test_dataloader = DataLoader(dataset.test, batch_size=self.config['batch_size'], shuffle=False,
                                          num_workers=self.config['num_workers'], collate_fn=dataset.collate_fn)
        self.vocab = dataset.vocab_model
        self.label_model = dataset.label_model
        self.config['num_classes'] = self.label_model.num_classes
        self.num_train = len(dataset.train)
        self.num_val = len(dataset.val)
        self.num_test = len(dataset.test)
        print('Train size: {}, Val size: {}, Test size: {}'
            .format(self.num_train, self.num_val, self.num_test))
        self.logger.write('Train size: {}, Val size: {}, Test size: {}'
            .format(self.num_train, self.num_val, self.num_test))

    def _build_model(self):
        self.model = TextClassifier(self.vocab, self.label_model, self.config).to(self.config['device'])

    def _build_optimizer(self):
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(parameters, lr=self.config['lr'])
        self.stopper = EarlyStopping(os.path.join(self.config['out_dir'], Constants._SAVED_WEIGHTS_FILE), patience=self.config['patience'])
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_reduce_factor'], \
            patience=self.config['lr_patience'], verbose=True)

    def _build_evaluation(self):
        self.metric = Accuracy(['accuracy'])

    def train(self):
        dur = []
        for epoch in range(self.config['epochs']):
            self.model.train()
            train_loss = []
            train_acc = []
            t0 = time.time()
            for i, data in enumerate(self.train_dataloader):
                tgt = data['tgt_tensor'].to(self.config['device'])
                data['graph_data'] = data['graph_data'].to(self.config['device'])
                logits, loss = self.model(data['graph_data'], tgt, require_loss=True)

                # add graph regularization loss if available
                if data['graph_data'].graph_attributes.get('graph_reg', None) is not None:
                    loss = loss + data['graph_data'].graph_attributes['graph_reg']

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                train_loss.append(loss.item())

                pred = torch.max(logits, dim=-1)[1].cpu()
                train_acc.append(self.metric.calculate_scores(ground_truth=tgt.cpu(), predict=pred.cpu(), zero_division=0)[0])
                dur.append(time.time() - t0)

            val_acc = self.evaluate(self.val_dataloader)
            self.scheduler.step(val_acc)
            print('Epoch: [{} / {}] | Time: {:.2f}s | Loss: {:.4f} | Train Acc: {:.4f} | Val Acc: {:.4f}'.
              format(epoch + 1, self.config['epochs'], np.mean(dur), np.mean(train_loss), np.mean(train_acc), val_acc))
            self.logger.write('Epoch: [{} / {}] | Time: {:.2f}s | Loss: {:.4f} | Train Acc: {:.4f} | Val Acc: {:.4f}'.
                        format(epoch + 1, self.config['epochs'], np.mean(dur), np.mean(train_loss), np.mean(train_acc), val_acc))

            if self.stopper.step(val_acc, self.model):
                break

        return self.stopper.best_score

    def evaluate(self, dataloader):
        self.model.eval()
        with torch.no_grad():
            pred_collect = []
            gt_collect = []
            for i, data in enumerate(dataloader):
                tgt = data['tgt_tensor'].to(self.config['device'])
                data['graph_data'] = data['graph_data'].to(self.config["device"])
                logits = self.model(data['graph_data'], require_loss=False)
                pred_collect.append(logits)
                gt_collect.append(tgt)

            pred_collect = torch.max(torch.cat(pred_collect, 0), dim=-1)[1].cpu()
            gt_collect = torch.cat(gt_collect, 0).cpu()
            score = self.metric.calculate_scores(ground_truth=gt_collect, predict=pred_collect, zero_division=0)[0]

            return score

    def test(self):
        # restore the best saved model
        self.model = TextClassifier.load_checkpoint(self.stopper.save_model_path)

        t0 = time.time()
        acc = self.evaluate(self.test_dataloader)
        dur = time.time() - t0
        print('Test examples: {} | Time: {:.2f}s |  Test Acc: {:.4f}'.
          format(self.num_test, dur, acc))
        self.logger.write('Test examples: {} | Time: {:.2f}s |  Test Acc: {:.4f}'.
          format(self.num_test, dur, acc))

        return acc

Run the model
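
The handler is driven by a single config dictionary. The sketch below lists an illustrative subset of the keys referenced in the code above; the values are placeholders rather than the tutorial’s official configuration, and the remaining keys used in the code (e.g., the dynamic graph construction settings) must be supplied as well:

config = {
    'out_dir': 'out/trec/demo',              # placeholder output directory
    'root_data_dir': '../data/trec',
    'graph_type': 'dependency',              # 'dependency', 'constituency', 'ie', 'node_emb', 'node_emb_refined'
    'gnn': 'graphsage',                      # 'gat', 'graphsage' or 'ggnn'
    'gnn_direction_option': 'bi_fuse',       # 'undirected', 'bi_sep' or 'bi_fuse'
    'gnn_num_layers': 2,
    'num_hidden': 300,
    'graphsage_aggreagte_type': 'lstm',      # key spelling follows the code above
    'graph_pooling': 'avg_pool',
    'max_pool_linear_proj': False,
    'word_dropout': 0.4,
    'rnn_dropout': 0.1,
    'gnn_dropout': 0.4,
    'no_fix_word_emb': False,
    'batch_size': 50,
    'num_workers': 0,
    'lr': 0.002,
    'lr_reduce_factor': 0.5,
    'lr_patience': 2,
    'patience': 10,
    'epochs': 500,
    'seed': 1234,
    'no_cuda': False,
    'gpu': 0,
}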

runner = ModelHandler(config)
val_acc = runner.train()
test_acc = runner.test()
out/trec/graphsage_bi_fuse_dependency_ckpt_1628651059.35833
Loading pre-built label mappings stored in ../data/trec/processed/dependency_graph/label.pt
Train size: 5452, Val size: 500, Test size: 500
[ Fix word embeddings ]
Epoch: [1 / 500] | Time: 14.28s | Loss: 1.1777 | Train Acc: 0.5249 | Val Acc: 0.7740
Saved model to out/trec/graphsage_bi_fuse_dependency_ckpt_1628651059.35833/params.saved
Epoch: [2 / 500] | Time: 13.17s | Loss: 0.6613 | Train Acc: 0.7596 | Val Acc: 0.8280
Saved model to out/trec/graphsage_bi_fuse_dependency_ckpt_1628651059.35833/params.saved
......