我定义了一个GAN模型,我想用FID分数来评估它。我有1个通道的图像,这是mnist数据集,但这种方法需要3个通道的图像。我该怎么做来解决这个问题?
6vl6ewon1#
在评估之前尝试将其分成3个通道。
import torch import torchvision from torcheval import metrics # Load the MNIST dataset mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True) # Convert the 1 channel images to 3 channel images mnist_dataset.data = mnist_dataset.data.unsqueeze(1) mnist_dataset.data = mnist_dataset.data.repeat(1, 3, 1, 1) # Calculate the FID score fid_score = metrics.FrechetInceptionDistance()(mnist_dataset.data) # Evaluate the FID score print('FID score:', fid_score)`enter code here`
字符串
1条答案
按热度按时间6vl6ewon1#
在评估之前尝试将其分成3个通道。
字符串