如何在pytorch中使用预训练模型进行预测

mnemlml8  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(220)

我正在尝试用pytorch翻译这个https://github.com/piiswrong/deep3d/blob/master/deep3d.ipynb

import os
import numpy as np
import cv2
import torch
from PIL import Image
from images2gif import writeGif

net = torch.jit.load("deep3d_v1.0_640x360_cpu.pt")
net.eval()

shape = (384, 160)
img = cv2.imread('demo.jpg')
raw_shape = (img.shape[1], img.shape[0])
img = cv2.resize(img, shape)

X = img.astype(np.float32).transpose((2,0,1))
X = X.reshape((1,)+X.shape)
test_iter = np.array(X)
Y = net.predict(test_iter)

right = np.clip(Y.squeeze().transpose((1,2,0)), 0, 255).astype(np.uint8)
right = Image.fromarray(cv2.cvtColor(right, cv2.COLOR_BGR2RGB))
left = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
writeGif('demo.gif', [left, right], duration=0.08)

我得到了AttributeError: 'RecursiveScriptModule' object has no attribute 'predict'
我也试过net(test_iter),得到了RuntimeError: forward() Expected a value of type 'Tensor' for argument 'imgs' but instead found type 'ndarray'.
我做错了什么?

ckx4rj1h

ckx4rj1h1#

我会坚持使用net(test_iter)来预测,

Y = net(torch.from_numpy(test_iter))

因为这会将numpy.ndarray转换为Tensor,请参阅here
同样类似的问题:Similar Questions 1

相关问题