在PyTorch中计算Fisher度量

jv4diomz  于 2023-05-07  发布在  其他
关注(0)|答案(1)|浏览(154)

给定一些简单的PyTorch模型,如何计算Fisher度量?
这里有一个(对实际用途无用的)平凡模型,它使用单个线性层来求解矩阵方程Ax=B,其中A是3x3矩阵,而x和b都是3x1列向量。给定A和B,x是多少?问题不重要。

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()

        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        out = self.linear(x)
        return out

# Define the training data
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

b = torch.tensor([[52.],
                  [124.],
                  [196.]])

# Define the model and the optimizer
model = Net(input_dim=9, output_dim=3)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(A.reshape(9))
    print(A@y_pred[:3])
    loss = nn.MSELoss(reduction='sum')(A@y_pred.view((3,1)), b)
    loss.backward()
    optimizer.step()

# Evaluate the model
with torch.no_grad():
    y_pred = model(A.reshape(9))
    print("Solution:\n", y_pred)

由此,我想计算模型的Fisher度量。我试图使用NNGeometry包,它需要一个数据加载器,所以我创建了另一个无用的琐碎片段,其中包含一个包含训练矩阵A的批处理:

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).reshape(1,9)
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# Create the DataLoader
dataset = TrivialDataset()
loader = DataLoader(dataset, batch_size=1)

最后,我尝试生成FIM:

from nngeometry.metrics import FIM
from nngeometry.object import PMatDense
fisher_metric = FIM(model, loader, n_output=1, variant='regression', representation=PMatDense, device='cpu')

但是得到一个错误:

RuntimeError: shape '[9, 1]' is invalid for input of size 3

我可以看到问题来自一个必须制作的视图,但NNGeometry肯定能够处理输入维数大于输出维数的模型(例如在分类中)?
我能绕过这个吗?是否有一个很好的替代NNGeometry?

xxhby3vn

xxhby3vn1#

NNGeometry库期望模型输出形状为**(batch_size,n_output)的Tensor,但您的模型输出形状为(n_output,)的Tensor。因此,需要更改前向传递以获得NNGeometry**库期望的正确形状。

class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()

        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        out = self.linear(x)
        return out.view(-1, self.linear.out_features) ### Change from your forward pass. 

# Train the model
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(A.view(1, -1))  
    print(A @ y_pred[:3])
    loss = nn.MSELoss(reduction='sum')(A @ y_pred.view((3, 1)), b)
    loss.backward()
    optimizer.step()

相关问题