Batching

In practice, we usually need to merge a collection of small graphs into one large graph in which each original small graph is a connected component. In graph deep learning this operation is called batching. It lets a model process many graphs in a single forward pass instead of one graph at a time, which makes much better use of parallel hardware such as GPUs.
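To make the merge concrete, here is a minimal plain-Python sketch of the idea. It assumes each small graph is given as a node count plus an edge list, and the helper `batch_edge_lists` is hypothetical: it is not how `to_batch()` is implemented. It only shows the core trick, which is to shift each graph's node ids by the number of nodes already placed so that no edge links two different graphs, and to record which graph every node came from.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def batch_edge_lists(graphs: List[Tuple[int, List[Edge]]]):
    """Merge small graphs, given as (num_nodes, edges), into one disjoint-union graph.

    Hypothetical helper for illustration only; not part of the GraphData API.
    """
    merged_edges: List[Edge] = []
    batch_idx: List[int] = []  # batch_idx[v] = index of the graph that node v came from
    offset = 0                 # number of nodes already placed in the merged graph
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        # Shift node ids so this graph occupies ids [offset, offset + num_nodes)
        merged_edges.extend((src + offset, tgt + offset) for src, tgt in edges)
        batch_idx.extend([graph_id] * num_nodes)
        offset += num_nodes
    return offset, merged_edges, batch_idx


# Two tiny graphs: a single edge on 2 nodes, and a path on 3 nodes
num_nodes, edges, batch_idx = batch_edge_lists([(2, [(0, 1)]), (3, [(0, 1), (1, 2)])])
print(num_nodes)  # 5
print(edges)      # [(0, 1), (2, 3), (3, 4)]
print(batch_idx)  # [0, 0, 1, 1, 1]
```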

GraphData provides interfaces for batching and unbatching graphs during training and inference. The to_batch() function takes a list of GraphData instances and returns a single GraphData containing the merged large graph. Conversely, from_batch() decomposes a large graph produced by merging small graphs back into a list of GraphData instances.
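Un-batching is only possible if the merged graph remembers how many nodes each component had. The sketch below shows that bookkeeping in plain Python, assuming the per-graph node counts are kept alongside the merged edge list and that no edge crosses graph boundaries. `unbatch_edge_lists` is a hypothetical helper, not the implementation of `from_batch()`.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple

Edge = Tuple[int, int]


def unbatch_edge_lists(edges: List[Edge], nodes_per_graph: List[int]) -> List[List[Edge]]:
    """Split a merged edge list back into per-graph edge lists with local node ids.

    Hypothetical helper for illustration; assumes every edge stays inside one graph.
    """
    # starts[k] = first merged-graph node id that belongs to graph k
    starts = [0] + list(accumulate(nodes_per_graph))[:-1]
    per_graph: List[List[Edge]] = [[] for _ in nodes_per_graph]
    for src, tgt in edges:
        # The owning graph is the last one whose start id is <= src
        k = bisect_right(starts, src) - 1
        # Subtract the offset to turn global ids back into local ids
        per_graph[k].append((src - starts[k], tgt - starts[k]))
    return per_graph


print(unbatch_edge_lists([(0, 1), (2, 3), (3, 4)], [2, 3]))
# [[(0, 1)], [(0, 1), (1, 2)]]
```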

The following example batches five ring graphs into one graph and then splits them apart again:

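import torch
from graph4nlp.pytorch.data.data import GraphData, from_batch, to_batch

g_list = []             # the small graphs to be batched
batched_edges = []      # expected edge list of the merged graph
graph_edges_list = []   # edge list of each small graph, to check un-batching
# Build five ring graphs with 10 nodes each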
for i in range(5):
    g = GraphData()
    g.add_nodes(10)
    for j in range(10):
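        # Connect node j to node j + 1; node 9 wraps around to node 0
        g.add_edge(src=j, tgt=(j + 1) % 10)
        # In the batch, node j of graph i is renumbered to i * 10 + j
        batched_edges.append((i * 10 + j, i * 10 + ((j + 1) % 10)))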
    g.node_features['idx'] = torch.ones(10) * i
    g.edge_features['idx'] = torch.ones(10) * i
    graph_edges_list.append(g.get_all_edges())
    g_list.append(g)

# Merge the five graphs into one batched graph
batch = to_batch(g_list)

# Each node of the batch records the index of the graph it came from
target_batch_idx = [i for i in range(5) for _ in range(10)]

# Expected behaviors
assert batch.batch == target_batch_idx
assert batch.get_node_num() == 50
assert batch.get_all_edges() == batched_edges

# Un-batching: recover the original graphs from the batch
graph_list = from_batch(batch)

for i, g in enumerate(graph_list):
    # Expected behaviors
    assert g.get_all_edges() == graph_edges_list[i]
    assert g.get_node_num() == 10
    assert torch.all(torch.eq(g.node_features['idx'], torch.ones(10) * i))
    assert torch.all(torch.eq(g.edge_features['idx'], torch.ones(10) * i))
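The batch index list checked above (`batch.batch`) records, for every node of the merged graph, which original graph it came from. One common reason to keep it, shown here as an illustration rather than as part of the GraphData API, is computing per-graph quantities directly on the batched tensors. The sketch below averages node features per graph with `Tensor.index_add_` and `torch.bincount`. `mean_pool_by_graph` is a hypothetical helper, and it assumes the batch index has been converted to a LongTensor (for example with `torch.tensor(batch.batch)`).

```python
import torch


def mean_pool_by_graph(node_feats: torch.Tensor, batch_idx: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average node features within each graph of a batch (hypothetical helper)."""
    # Sum the features of all nodes that belong to the same graph
    sums = torch.zeros((num_graphs,) + tuple(node_feats.shape[1:]),
                       dtype=node_feats.dtype, device=node_feats.device)
    sums.index_add_(0, batch_idx, node_feats)
    # Count nodes per graph; clamp(min=1) avoids dividing by zero for empty graphs
    counts = torch.bincount(batch_idx, minlength=num_graphs).clamp(min=1)
    # Reshape counts so it broadcasts over any trailing feature dimensions
    counts = counts.view(-1, *([1] * (node_feats.dim() - 1))).to(node_feats.dtype)
    return sums / counts


# Five graphs of 10 nodes each, mirroring the example above:
# every node of graph i carries the feature value i
batch_idx = torch.arange(5).repeat_interleave(10)
node_feats = batch_idx.float()
print(mean_pool_by_graph(node_feats, batch_idx, num_graphs=5))  # tensor([0., 1., 2., 3., 4.])
```

Because the whole batch is handled by two vectorized calls, the cost does not grow with a Python loop over graphs, which is the efficiency argument for batching in the first place.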