pytorch 在无监督的GNN中,为什么我的参数没有更新,为什么损失只是噪音

ljo96ir5  于 2024-01-09  发布在  其他
关注(0)|答案(1)|浏览(173)

我想实现一个无监督的GNN,这样它就可以标记我的节点。我想定义一个损失函数,它可以描述节点值和它的邻居值之间的关系。但是,在我训练之后,损失曲线只是一个噪音,似乎网络在训练过程中什么也没学到,因为参数值没有改变。谢谢你的帮助!

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

import networkx as nx
import numpy as np
from time import time
import random
from itertools import chain, islice

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

setup_seed(3)
    
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class GCN_Net(torch.nn.Module):
    def __init__(self, num_features, num_classes, nerouns, dropout=0.1):
        super(GCN_Net, self).__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_features, nerouns)
        self.conv2 = GCNConv(nerouns, 2*nerouns)

        self.linear = torch.nn.Linear(2*nerouns, num_classes)
        self.softmax = torch.nn.Softmax()
        
    def forward(self, data):
        h = self.conv1(data.x, data.edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=self.dropout)
        
        h = self.conv2(h, data.edge_index)

        h = self.linear(h)
        h = self.softmax(h)

        return h

def quardaic_loss(a,b):
    return (a-b)**2

def loss_function(outputs, edges, func=quardaic_loss):
    loss = []

    for n1, n2 in edges:
        loss.append((outputs[n1].item() - outputs[n2].item())**2)

    return torch.sum(torch.tensor(loss), dtype=float)

def data_transformation(graph):
    nodes = list(graph.nodes())
    edge_idx = [[], []]
    
    for node in nodes:
        for neighbor in [n for n in graph.neighbors(node)]:
            edge_idx[0].append(node)
            edge_idx[0].append(neighbor)
            edge_idx[1].append(neighbor)
            edge_idx[1].append(node)
                
    edge_index = torch.tensor(edge_idx, dtype=torch.long)
    return edge_index

def create_net(nodes_num, gnn_hypers, opt_params, torch_device, torch_dtype):

    num_features = gnn_hypers['num_features']
    number_classes = gnn_hypers['number_classes']
    dropout = gnn_hypers['dropout']
    neurons = gnn_hypers['neurons']
    
    embed = nn.Embedding(nodes_num, num_features)
    embed = embed.type(torch_dtype).to(torch_device)
    
    net = GCN_Net(num_features, number_classes, neurons, dropout)
    net = net.type(torch_dtype).to(torch_device)

    params = chain(net.parameters(), embed.parameters())
    optimizer = torch.optim.Adam(params, **opt_params)
    return net, embed, optimizer
    

    
def train(graph, net, embed, num_epoch, 
        optimizer, loss_function, device, max_state=15, tol=5, patience=4):
    
    edges = list(graph.edges())
    
    torch.manual_seed(666)
    
    x = embed.weight

    edge_index = data_transformation(graph)
    data = Data(x=x, edge_index=edge_index.contiguous())
    data = data.to(device)
    
    prevloss = len(list(graph.nodes())) * max_state**2
    best_loss = len(list(graph.nodes())) * max_state**2
    start_time = time()
    
    no_improve_count = 0
    # print("Initial loss is {}".format(prevloss))
    # print("Tranining starts...")
    losses = []
    for epoch in range(num_epoch):
        out = net(data)
        out = torch.argmax(out, 1) + 1
        loss = loss_function(out, edges)
        loss.requires_grad_(True)
        loss_ = loss.detach().item()
        losses.append(loss_)
        best_loss = min(best_loss, loss_)
        
        if abs(loss_ - prevloss) < tol or loss_ - prevloss > 0:
            no_improve_count += 1
        else:
            no_improve_count = 0    
        
        if no_improve_count >= patience:
            print("Early stopping at epoch {} with patience {}".format(epoch, patience))
            break
        
        if epoch % 100 == 0:
            print("The loss after epoch {} is {}".format(epoch, loss_))

        prevloss = loss_
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        edge_index = data_transformation(graph)
        data = Data(x=x, edge_index=edge_index.contiguous())
        data = data.to(device)
        
    print("The training time is {} for total {} epochs.".format(time() - start_time, epoch+1))
    print("The final loss is {}".format(loss_))
    print("The best loss is {}".format(best_loss))

    return best_loss, losses

nodes_num = 100
g = nx.erdos_renyi_graph(nodes_num, p=0.1)

# print("The obs_0 is {}".format(observations[0]))

gnn_hypers = {'num_features': 10, 'number_classes': 10, "neurons": 500, 'dropout': 0.1}
opt_params = {'lr': 0.01}
torch_device = TORCH_DEVICE
torch_dtype = torch.float
num_epoch = 100
dim_embedding = 10

net, embed, optimizer = create_net(nodes_num, gnn_hypers, opt_params, torch_device, torch_dtype)

train(g, net, embed, num_epoch, 
      optimizer, loss_function, torch_device, 10)

字符串

vm0i2vca

vm0i2vca1#

我认为问题在于:
第一个月
调用.item()返回一个普通的python浮点数。这会破坏autograd渐变链。

相关问题