我想实现一个无监督的GNN,这样它就可以标记我的节点。我想定义一个损失函数,它可以描述节点值和它的邻居值之间的关系。但是,在我训练之后,损失曲线只是一个噪音,似乎网络在训练过程中什么也没学到,因为参数值没有改变。谢谢你的帮助!
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import networkx as nx
import numpy as np
from time import time
import random
from itertools import chain, islice
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
setup_seed(3)
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class GCN_Net(torch.nn.Module):
def __init__(self, num_features, num_classes, nerouns, dropout=0.1):
super(GCN_Net, self).__init__()
self.dropout = dropout
self.conv1 = GCNConv(num_features, nerouns)
self.conv2 = GCNConv(nerouns, 2*nerouns)
self.linear = torch.nn.Linear(2*nerouns, num_classes)
self.softmax = torch.nn.Softmax()
def forward(self, data):
h = self.conv1(data.x, data.edge_index)
h = torch.relu(h)
h = F.dropout(h, p=self.dropout)
h = self.conv2(h, data.edge_index)
h = self.linear(h)
h = self.softmax(h)
return h
def quardaic_loss(a,b):
return (a-b)**2
def loss_function(outputs, edges, func=quardaic_loss):
loss = []
for n1, n2 in edges:
loss.append((outputs[n1].item() - outputs[n2].item())**2)
return torch.sum(torch.tensor(loss), dtype=float)
def data_transformation(graph):
nodes = list(graph.nodes())
edge_idx = [[], []]
for node in nodes:
for neighbor in [n for n in graph.neighbors(node)]:
edge_idx[0].append(node)
edge_idx[0].append(neighbor)
edge_idx[1].append(neighbor)
edge_idx[1].append(node)
edge_index = torch.tensor(edge_idx, dtype=torch.long)
return edge_index
def create_net(nodes_num, gnn_hypers, opt_params, torch_device, torch_dtype):
num_features = gnn_hypers['num_features']
number_classes = gnn_hypers['number_classes']
dropout = gnn_hypers['dropout']
neurons = gnn_hypers['neurons']
embed = nn.Embedding(nodes_num, num_features)
embed = embed.type(torch_dtype).to(torch_device)
net = GCN_Net(num_features, number_classes, neurons, dropout)
net = net.type(torch_dtype).to(torch_device)
params = chain(net.parameters(), embed.parameters())
optimizer = torch.optim.Adam(params, **opt_params)
return net, embed, optimizer
def train(graph, net, embed, num_epoch,
optimizer, loss_function, device, max_state=15, tol=5, patience=4):
edges = list(graph.edges())
torch.manual_seed(666)
x = embed.weight
edge_index = data_transformation(graph)
data = Data(x=x, edge_index=edge_index.contiguous())
data = data.to(device)
prevloss = len(list(graph.nodes())) * max_state**2
best_loss = len(list(graph.nodes())) * max_state**2
start_time = time()
no_improve_count = 0
# print("Initial loss is {}".format(prevloss))
# print("Tranining starts...")
losses = []
for epoch in range(num_epoch):
out = net(data)
out = torch.argmax(out, 1) + 1
loss = loss_function(out, edges)
loss.requires_grad_(True)
loss_ = loss.detach().item()
losses.append(loss_)
best_loss = min(best_loss, loss_)
if abs(loss_ - prevloss) < tol or loss_ - prevloss > 0:
no_improve_count += 1
else:
no_improve_count = 0
if no_improve_count >= patience:
print("Early stopping at epoch {} with patience {}".format(epoch, patience))
break
if epoch % 100 == 0:
print("The loss after epoch {} is {}".format(epoch, loss_))
prevloss = loss_
optimizer.zero_grad()
loss.backward()
optimizer.step()
edge_index = data_transformation(graph)
data = Data(x=x, edge_index=edge_index.contiguous())
data = data.to(device)
print("The training time is {} for total {} epochs.".format(time() - start_time, epoch+1))
print("The final loss is {}".format(loss_))
print("The best loss is {}".format(best_loss))
return best_loss, losses
nodes_num = 100
g = nx.erdos_renyi_graph(nodes_num, p=0.1)
# print("The obs_0 is {}".format(observations[0]))
gnn_hypers = {'num_features': 10, 'number_classes': 10, "neurons": 500, 'dropout': 0.1}
opt_params = {'lr': 0.01}
torch_device = TORCH_DEVICE
torch_dtype = torch.float
num_epoch = 100
dim_embedding = 10
net, embed, optimizer = create_net(nodes_num, gnn_hypers, opt_params, torch_device, torch_dtype)
train(g, net, embed, num_epoch,
optimizer, loss_function, torch_device, 10)
字符串
1条答案
按热度按时间vm0i2vca1#
我认为问题在于:
第一个月
调用
.item()
返回一个普通的python浮点数。这会破坏autograd渐变链。