给定一些简单的PyTorch模型,如何计算Fisher度量?
这里有一个(对实际用途无用的)平凡模型,它使用单个线性层来求解矩阵方程Ax=B,其中A是3x3矩阵,而x和b都是3x1列向量。给定A和B,x是多少?问题不重要。
import torch
import torch.nn as nn
import torch.optim as optim
class Net(nn.Module):
def __init__(self, input_dim, output_dim):
super(Net, self).__init__()
self.linear = nn.Linear(input_dim, output_dim, bias=False)
def forward(self, x):
out = self.linear(x)
return out
# Define the training data
A = torch.tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
b = torch.tensor([[52.],
[124.],
[196.]])
# Define the model and the optimizer
model = Net(input_dim=9, output_dim=3)
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Train the model
for epoch in range(2000):
optimizer.zero_grad()
y_pred = model(A.reshape(9))
print(A@y_pred[:3])
loss = nn.MSELoss(reduction='sum')(A@y_pred.view((3,1)), b)
loss.backward()
optimizer.step()
# Evaluate the model
with torch.no_grad():
y_pred = model(A.reshape(9))
print("Solution:\n", y_pred)
由此,我想计算模型的Fisher度量。我试图使用NNGeometry包,它需要一个数据加载器,所以我创建了另一个无用的琐碎片段,其中包含一个包含训练矩阵A的批处理:
class TrivialDataset(Dataset):
def __init__(self):
self.data = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).reshape(1,9)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# Create the DataLoader
dataset = TrivialDataset()
loader = DataLoader(dataset, batch_size=1)
最后,我尝试生成FIM:
from nngeometry.metrics import FIM
from nngeometry.object import PMatDense
fisher_metric = FIM(model, loader, n_output=1, variant='regression', representation=PMatDense, device='cpu')
但是得到一个错误:
RuntimeError: shape '[9, 1]' is invalid for input of size 3
我可以看到问题来自一个必须制作的视图,但NNGeometry肯定能够处理输入维数大于输出维数的模型(例如在分类中)?
我能绕过这个吗?是否有一个很好的替代NNGeometry?
1条答案
按热度按时间xxhby3vn1#
NNGeometry库期望模型输出形状为**(batch_size,n_output)的Tensor,但您的模型输出形状为(n_output,)的Tensor。因此,需要更改前向传递以获得NNGeometry**库期望的正确形状。