博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
图神经网络task_05
阅读量:3904 次
发布时间:2019-05-23

本文共 3991 字,大约阅读时间需要 13 分钟。

超大图上的表征学习

本文主要参考

主要介绍的模型是Cluster-GCN。
在这里插入图片描述

模型主体思想

Cluster-GCN提出:

1.利用图节点聚类算法将一个图的节点划分为 个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
2.由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
3.每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题。
4.基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。

缓解邻域扩展问题

在这里插入图片描述

在这里插入图片描述

基于有效的聚类算法生成多个簇

在这里插入图片描述

分簇后的类别不平衡问题

尽管简单Cluster-GCN方法可以做到较其他方法更低的计算和内存复杂度,但它仍存在两个潜在问题:

1.图被分割后,一些边(公式(4)中的 部分)被移除,性能可能因此会受到影响。
2.图聚类算法倾向于将相似的节点聚集在一起。因此,单个簇中节点的类别分布可能与原始数据集不同,导致对梯度的估计有偏差。

解决方案(随机多簇方法)

在这里插入图片描述

算法流程

在这里插入图片描述

作业:尝试将数据集切分成不同数量的簇进行实验,然后观察结果并进行比较。

读取数据

import torchimport torch.nn.functional as Ffrom torch.nn import ModuleListfrom tqdm import tqdmfrom torch_geometric.datasets import Reddit, Reddit2from torch_geometric.data import ClusterData, ClusterLoader, NeighborSamplerfrom torch_geometric.nn import SAGEConvdataset = Reddit('dataset/Reddit')data = dataset[0]data

在这里插入图片描述

进行分簇等操作(分为1000簇)

cluster_data = ClusterData(data, num_parts=1000, recursive=False,                           save_dir=dataset.processed_dir)train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,                             num_workers=4)subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024,                                  shuffle=False, num_workers=4)
class Net(torch.nn.Module):    def __init__(self, in_channels, out_channels):        super(Net, self).__init__()        self.convs = ModuleList(            [SAGEConv(in_channels, 128),             SAGEConv(128, out_channels)])    def forward(self, x, edge_index):        for i, conv in enumerate(self.convs):            x = conv(x, edge_index)            if i != len(self.convs) - 1:                x = F.relu(x)                x = F.dropout(x, p=0.5, training=self.training)        return F.log_softmax(x, dim=-1)    def inference(self, x_all):        pbar = tqdm(total=x_all.size(0) * len(self.convs))        pbar.set_description('Evaluating')        # Compute representations of nodes layer by layer, using *all*        # available edges. This leads to faster computation in contrast to        # immediately computing the final representations of each batch.        for i, conv in enumerate(self.convs):            xs = []            for batch_size, n_id, adj in subgraph_loader:                edge_index, _, size = adj.to(device)                x = x_all[n_id].to(device)                x_target = x[:size[1]]                x = conv((x, x_target), edge_index)                if i != len(self.convs) - 1:                    x = F.relu(x)                xs.append(x.cpu())                pbar.update(batch_size)            x_all = torch.cat(xs, dim=0)        pbar.close()        return x_alldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = Net(dataset.num_features, dataset.num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=0.005)def train():    model.train()    total_loss = total_nodes = 0    for batch in train_loader:        batch = batch.to(device)        optimizer.zero_grad()        out = model(batch.x, batch.edge_index)        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])        loss.backward()        optimizer.step()        nodes = batch.train_mask.sum().item()        total_loss += loss.item() * nodes        total_nodes += nodes    return total_loss / total_nodes@torch.no_grad()def test():  # Inference should be performed on the full graph.    model.eval()    out = model.inference(data.x)    y_pred = out.argmax(dim=-1)    accs = []    for mask in [data.train_mask, data.val_mask, data.test_mask]:        correct = y_pred[mask].eq(data.y[mask]).sum().item()        accs.append(correct / mask.sum().item())    return accsfor epoch in range(1, 31):    loss = train()    if epoch % 5 == 0:        train_acc, val_acc, test_acc = test()        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '              f'Val: {val_acc:.4f}, test: {test_acc:.4f}')    else:        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

在这里插入图片描述

转载地址:http://kvten.baihongyu.com/

你可能感兴趣的文章
467. 环绕字符串中唯一的子字符串
查看>>
468. 验证IP地址
查看>>
474. 一和零
查看>>
486. 预测赢家
查看>>
494. 目标和
查看>>
520. 检测大写字母
查看>>
数据处理和训练模型的技巧
查看>>
vb 中如何做同步 异步?
查看>>
geturl
查看>>
李建忠,设计模式教程.笔记061220
查看>>
李建忠,设计模式教程.笔记061221
查看>>
关于sizeof
查看>>
windows 核心编程笔记.070301
查看>>
WINDOWS核心编程笔记 070303
查看>>
终于解决了交叉表左上角,每页都显示的问题.
查看>>
windows核心编程 070309
查看>>
哈,又解决水晶报表的一个难题
查看>>
VC Ini文件处理
查看>>
一直误解sql事务的用法.
查看>>
转:利用C#实现分布式数据库查询
查看>>