This article mainly introduces the Cluster-GCN model, proposed in "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (Chiang et al., KDD 2019). Cluster-GCN proposes the following:
1. Use a graph node clustering algorithm to partition the nodes of a graph into multiple clusters. At each step, a few clusters, together with the edges among their nodes, form a subgraph, and training is performed on that subgraph (see the matrix sketch right after this list).
2. Because the nodes are grouped by a graph clustering algorithm, within-cluster edges far outnumber between-cluster edges. This raises embedding utilization (each computed representation is reused by many neighbors in the same batch) and therefore improves the training efficiency of the graph neural network.
3. Each batch is formed by randomly selecting several clusters, so the between-cluster edges among the selected clusters are not lost, and the class distribution within a batch does not deviate too strongly from that of the full graph.
4. Training on small subgraphs consumes little memory, so we can train deeper networks and thereby reach higher accuracy.
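To make points 1 and 2 concrete, here is the decomposition used in the Cluster-GCN paper (a sketch in the paper's notation): partitioning the node set into clusters $\mathcal{V}_1, \dots, \mathcal{V}_c$ splits the adjacency matrix into a within-cluster block-diagonal part and a between-cluster remainder,

$$
A = \bar{A} + \Delta, \qquad
\bar{A} = \begin{bmatrix} A_{11} & & \\ & \ddots & \\ & & A_{cc} \end{bmatrix},
$$

where $A_{tt}$ contains the edges inside cluster $\mathcal{V}_t$ and $\Delta = A - \bar{A}$ collects all between-cluster edges. Vanilla Cluster-GCN trains with $\bar{A}$ in place of $A$, so each layer's propagation decomposes into independent per-cluster multiplications.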
Although the plain Cluster-GCN method already achieves lower computational and memory complexity than other methods, it still has two potential problems:

1. After the graph is partitioned, some edges (the $\Delta$ part in Eq. (4) of the paper, i.e. the between-cluster edges in the decomposition above) are removed, and performance may suffer as a result.
2. Graph clustering algorithms tend to group similar nodes together, so the class distribution within a single cluster can differ from that of the original dataset, yielding a biased estimate of the gradient.

The code below walks through a PyTorch Geometric implementation on the Reddit dataset, starting with the imports and data loading:

import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv

# Download/load the Reddit node-classification dataset.
dataset = Reddit('dataset/Reddit')
data = dataset[0]
print(data)
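Two notes on this snippet, both my additions rather than part of the original post: on PyG >= 2.0 the loader classes were moved to a different module, and the printed Data object for Reddit should look roughly as sketched below (exact field order and edge count may vary by version):

# On PyG >= 2.0 the loaders live in torch_geometric.loader instead:
#   from torch_geometric.loader import ClusterData, ClusterLoader, NeighborSampler
# Expected (approximate) output of print(data):
#   Data(x=[232965, 602], edge_index=[2, 114615892], y=[232965],
#        train_mask=[232965], val_mask=[232965], test_mask=[232965])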
Next, partition the graph with METIS and build the training and inference loaders:

# Partition the graph into 1000 clusters; results are cached in save_dir.
cluster_data = ClusterData(data, num_parts=1000, recursive=False,
                           save_dir=dataset.processed_dir)
# Each training batch stitches 20 randomly chosen clusters back into one subgraph.
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                             num_workers=4)
# For evaluation, a layer-wise full-neighborhood sampler over the whole graph.
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024,
                                  shuffle=False, num_workers=4)
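As a quick sanity check (my addition, not part of the original post), one can pull a single batch from train_loader and inspect it; each batch is itself a Data object for the subgraph induced by the sampled clusters:

# Hypothetical inspection snippet: look at one mini-batch of clusters.
batch = next(iter(train_loader))
print(batch)                   # sub-Data with its own x, y, edge_index, masks
print(batch.train_mask.sum())  # number of training nodes in this batch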
The model below is a two-layer GraphSAGE network. Training consumes cluster batches; evaluation computes representations layer by layer over the full graph using the NeighborSampler defined above:

class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.convs = ModuleList(
            [SAGEConv(in_channels, 128),
             SAGEConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)

    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        for i, conv in enumerate(self.convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)
                if i != len(self.convs) - 1:
                    x = F.relu(x)
                xs.append(x.cpu())
                pbar.update(batch_size)
            x_all = torch.cat(xs, dim=0)

        pbar.close()
        return x_all


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


def train():
    model.train()
    total_loss = total_nodes = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        # Compute the loss only on the training nodes inside this subgraph.
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

        nodes = batch.train_mask.sum().item()
        total_loss += loss.item() * nodes
        total_nodes += nodes
    return total_loss / total_nodes


@torch.no_grad()
def test():  # Inference should be performed on the full graph.
    model.eval()
    out = model.inference(data.x)
    y_pred = out.argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = y_pred[mask].eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs


for epoch in range(1, 31):
    loss = train()
    if epoch % 5 == 0:
        train_acc, val_acc, test_acc = test()
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')
    else:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
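Point 4 above argues that cluster-wise training makes deeper models affordable. As an illustration (my sketch, not code from the original post; DeepNet, hidden_channels, and num_layers are hypothetical names), the two-layer Net could be generalized to arbitrary depth like this:

class DeepNet(torch.nn.Module):
    """Hypothetical depth-configurable variant of Net (illustration only)."""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        # First layer maps input features to the hidden size, last layer to
        # the number of classes; the middle layers keep the hidden size.
        self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)])
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)

# Usage (hypothetical):
# model = DeepNet(dataset.num_features, 128, dataset.num_classes,
#                 num_layers=4).to(device)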